Commit 120b463c authored by Matthew Yu, committed by Facebook GitHub Bot

support a layer that saves outputs

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/417

This diff adds a `CachedLayer` module that is meant to be used with dynamic mixin. The layer runs the original module's forward and clones the output into a dictionary provided by the user.

The main use case is distillation, where we dynamically mix this layer into the layers whose outputs the user needs in order to compute various losses.

See subsequent diffs for the integration with distillation.
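
A minimal usage sketch (mirroring the unit tests added in this diff; `AddOne` and the `"layer0"` label are just illustrative stand-ins):

    import torch
    import torch.nn as nn

    from d2go.modeling.distillation import CachedLayer
    from mobile_cv.common.misc.mixin import dynamic_mixin

    class AddOne(nn.Module):
        def forward(self, x):
            return x + 1

    model = AddOne()
    cache = {}

    # Mix CachedLayer into the existing module; forward now also clones the
    # output into `cache` under the given label.
    dynamic_mixin(model, CachedLayer, init_dict={"label": "layer0", "cache": cache})

    output = model(torch.randn(1))
    assert torch.equal(output, cache["layer0"])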

Reviewed By: Minione

Differential Revision: D40285573

fbshipit-source-id: 2058deff8b96f63aebd1e9b9933a5352b5197111
parent 0f27e90f
@@ -14,7 +14,7 @@
# distillation algorithms in configs: DISILLATION_ALAGORITHM, DISTILLATION_HELPER
from abc import abstractmethod
from typing import Dict, List, Union

import torch
import torch.nn as nn

@@ -376,3 +376,62 @@ def _validate_teacher_config(cfg: CN) -> None:
        raise ValueError(
            f"Unrecognized DISTILLATION.TEACHER.TYPE: {cfg.DISTILLATION.TEACHER.TYPE}"
        )

class CachedLayer(nn.Module):
    """Cached layer that records the output of a layer

    This is meant to be used with dynamic mixin. The layer overrides the
    forward of the original layer so that the inputs and outputs are
    unchanged, but the output is additionally saved to a dict that can be
    retrieved later.
    """

    def dynamic_mixin_init(
        self,
        label: str,
        cache: Dict[
            str, Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]]
        ],
    ):
        self.label = label
        self.cache = cache

    def remove_dynamic_mixin(self):
        del self.label
        del self.cache

    def forward(self, *args, **kwargs):
        """Run the original layer and save the output

        We clone the output to avoid the case where a subsequent module
        runs an in-place operation. However, this limits what the cache
        can support, since clone only works on tensors, so we need to
        check the type of the output.

        Supported output types:
            * Tensor
            * List[Tensor]
            * Dict[str, Tensor]
        """
        output = super().forward(*args, **kwargs)
        if isinstance(output, torch.Tensor):
            self.cache[self.label] = output.clone()
        elif isinstance(output, List):
            cloned_output = []
            for x in output:
                if isinstance(x, torch.Tensor):
                    cloned_output.append(x.clone())
                else:
                    raise ValueError(f"Unexpected type to save: {type(x)}")
            self.cache[self.label] = cloned_output
        elif isinstance(output, Dict):
            cloned_output = {}
            for k, v in output.items():
                if isinstance(v, torch.Tensor):
                    cloned_output[k] = v.clone()
                else:
                    raise ValueError(f"Unexpected type to save: {type(v)}")
            self.cache[self.label] = cloned_output
        else:
            raise ValueError(f"Unexpected type to save: {type(output)}")
        return output
@@ -15,6 +15,7 @@ from d2go.modeling.distillation import (
    _set_device,
    add_distillation_configs,
    BaseDistillationHelper,
    CachedLayer,
    DistillationModelingHook,
    ExampleDistillationHelper,
    LabelDistillation,

@@ -32,6 +33,7 @@ from d2go.utils.testing import helper
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.utils.file_io import PathManager
from mobile_cv.common.misc.file_utils import make_temp_directory
from mobile_cv.common.misc.mixin import dynamic_mixin

class DivideInputBy2(nn.Module):

@@ -53,6 +55,12 @@ class DivideInputDictBy2(nn.Module):
        return torch.stack(output)


class DivideInputBy2OutputDict(nn.Module):
    def forward(self, batched_inputs: List):
        """Divide all targets by 2 and return dict output"""
        return {i: x / 2.0 for i, x in enumerate(batched_inputs)}


class AddOne(nn.Module):
    def __init__(self):
        super().__init__()

@@ -231,6 +239,45 @@ class TestDistillation(unittest.TestCase):
        model = _set_device(model, device)
        self.assertEqual(model.device, device)

    def test_cached_layer_tensor(self):
        """Check cached layer saves layer output"""
        model = AddOne()
        cache = {}
        dynamic_mixin(
            model,
            CachedLayer,
            init_dict={"label": "test_layer", "cache": cache},
        )
        input = torch.randn(1)
        output = model(input)
        self.assertEqual(output, cache["test_layer"])

    def test_cached_layer_list(self):
        """Check cached layer saves list"""
        model = DivideInputBy2()
        cache = {}
        dynamic_mixin(
            model,
            CachedLayer,
            init_dict={"label": "test_layer", "cache": cache},
        )
        input = [torch.randn(1) for _ in range(2)]
        output = model(input)
        self.assertEqual(output, cache["test_layer"])

    def test_cached_layer_dict(self):
        """Check cached layer saves dict"""
        model = DivideInputBy2OutputDict()
        cache = {}
        dynamic_mixin(
            model,
            CachedLayer,
            init_dict={"label": "test_layer", "cache": cache},
        )
        input = [torch.randn(1) for _ in range(2)]
        output = model(input)
        self.assertEqual(output, cache["test_layer"])


class TestPseudoLabeler(unittest.TestCase):
    def test_noop(self):
...