Commit 40a6a453 authored by Siddharth Shah, committed by Facebook GitHub Bot

Support caching arbitrary nested data structures of Tensor

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/439

As title

Reviewed By: mattcyu1

Differential Revision: D41759804

fbshipit-source-id: 929efa960be570f0fe8543600e012d1bf037ab3b
parent dece58ba
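
Concretely, after this change a CachedLayer can snapshot nested outputs such as a dict of lists of tensors, not just flat tensors. A minimal sketch of the intended usage, mirroring the new tests below; the d2go.modeling.distillation import path is an assumption:

import torch
import torch.nn as nn
from mobile_cv.common.misc.mixin import dynamic_mixin
from d2go.modeling.distillation import CachedLayer  # import path assumed

class NestedOutput(nn.Module):
    """Toy module whose output is a nested structure, not a bare Tensor."""
    def forward(self, x):
        return {"feats": [x * 2, x * 3], "raw": x}

model = NestedOutput()
cache = {}
# CachedLayer is mixed into the instance; its forward() deep-clones the
# (possibly nested) output into cache["backbone"] and returns it unchanged.
dynamic_mixin(model, CachedLayer, init_dict={"label": "backbone", "cache": cache})

out = model(torch.randn(1))
assert torch.equal(out["feats"][0], cache["backbone"]["feats"][0])
assert out["raw"] is not cache["backbone"]["raw"]  # cloned, not aliased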
@@ -16,7 +16,7 @@
 import logging
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -34,6 +34,9 @@ from mobile_cv.common.misc.mixin import dynamic_mixin, remove_dynamic_mixin
 logger = logging.getLogger(__name__)
 
+ModelOutput = Union[None, torch.Tensor, Iterable["ModelOutput"]]
+
+
 def add_distillation_configs(_C: CN) -> None:
     """Add default parameters to config
@@ -613,9 +616,7 @@ class CachedLayer(nn.Module):
     def dynamic_mixin_init(
         self,
         label: str,
-        cache: Dict[
-            str, Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]]
-        ],
+        cache: Dict[str, ModelOutput],
     ):
         self.label = label
         self.cache = cache
@@ -632,44 +633,43 @@ class CachedLayer(nn.Module):
         can support as we can only run clone on a tensor so we need to
         check the type of the output.
 
-        Support of the output type is limited to:
-            * tensor
-            * List[tensor]
-            * Tuple[tensor]
-            * Dict[str, tensor]
+        Support of the output type is limited to None type and arbitrary nested
+        collections of List, Tuple and Dict of tensor.
         """
         output = super().forward(*args, **kwargs)
-        if isinstance(output, torch.Tensor):
-            self.cache[self.label] = output.clone()
+        self.cache[self.label] = CachedLayer._clone(output)
+        return output
+
+    @staticmethod
+    def _clone(output: ModelOutput) -> ModelOutput:
+        if output is None:
+            return None
+        elif isinstance(output, torch.Tensor):
+            return output.clone()
         elif isinstance(output, List) or isinstance(output, Tuple):
             cloned_output = []
             for x in output:
-                if isinstance(x, torch.Tensor):
-                    cloned_output.append(x.clone())
-                else:
-                    raise ValueError(f"Unexpected type to save: {type(x)}")
-            self.cache[self.label] = cloned_output
+                cloned_output.append(CachedLayer._clone(x))
+            if isinstance(output, Tuple):
+                return tuple(cloned_output)
+            return cloned_output
         elif isinstance(output, Dict):
             cloned_output = {}
             for k, v in output.items():
-                if isinstance(v, torch.Tensor):
-                    cloned_output[k] = v.clone()
-                else:
-                    raise ValueError(f"Unexpected type to save: {type(v)}")
-            self.cache[self.label] = cloned_output
+                cloned_output[k] = CachedLayer._clone(v)
+            return cloned_output
         else:
             raise ValueError(f"Unexpected type to save: {type(output)}")
-        return output
 
 
-def set_cache_dict(model: nn.Module, cache: Dict) -> None:
+def set_cache_dict(model: nn.Module, cache: ModelOutput) -> None:
     """Sets the cache in all CachedLayers to input cache"""
     for module in model.modules():
         if isinstance(module, CachedLayer):
             module.cache = cache
 
 
-def record_layers(model: nn.Module, layer_names: Set[str]) -> Dict[str, torch.Tensor]:
+def record_layers(model: nn.Module, layer_names: Set[str]) -> ModelOutput:
     """Save the outputs of layer_names in model
 
     Iterates over all named layers in model, applies cached layer to layers in
@@ -695,8 +695,8 @@ def unrecord_layers(model: nn.Module, layer_names: Set[str]) -> None:
 
 def compute_layer_losses(
     layer_losses: List[LayerLossMetadata],
-    layer0_cache: Dict[str, torch.Tensor],
-    layer1_cache: Dict[str, torch.Tensor],
+    layer0_cache: ModelOutput,
+    layer1_cache: ModelOutput,
 ) -> Dict[str, torch.Tensor]:
     """Compute loss over layers specified in layer_loss
...
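
A side note on the annotation changes just above, before the test hunks: set_cache_dict simply repoints every CachedLayer in a model at a caller-supplied dict, e.g. to swap in a fresh cache per step. A minimal sketch, with the same assumed import path as before:

import torch
import torch.nn as nn
from mobile_cv.common.misc.mixin import dynamic_mixin
from d2go.modeling.distillation import CachedLayer, set_cache_dict  # path assumed

layer = nn.Identity()
dynamic_mixin(layer, CachedLayer, init_dict={"label": "id", "cache": {}})

fresh_cache = {}
# Repoint every CachedLayer inside `layer` (here just itself) at fresh_cache.
set_cache_dict(layer, fresh_cache)

layer(torch.ones(3))
assert torch.equal(fresh_cache["id"], torch.ones(3))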
@@ -71,6 +71,18 @@ class DivideInputBy2OutputDict(nn.Module):
         return {i: x / 2.0 for i, x in enumerate(batched_inputs)}
 
 
+class TimesTable5OutputDict(nn.Module):
+    def forward(self, batched_inputs: List):
+        """Return first five entries of times table for each input with a dict output"""
+        return {i: [x * i for i in range(1, 6)] for i, x in enumerate(batched_inputs)}
+
+
+class ConstantStrOutput(nn.Module):
+    def forward(self, batched_inputs: List):
+        """Return some string"""
+        return "Testing!"
+
+
 class AddOne(nn.Module):
     def __init__(self):
         super().__init__()
@@ -394,6 +406,31 @@ class TestDistillation(unittest.TestCase):
         output = model(input)
         self.assertEqual(output, cache["test_layer"])
 
+    def test_cached_layer_arbitrary(self):
+        """Check cached layer saves arbitrary nested data structure"""
+        model = TimesTable5OutputDict()
+        cache = {}
+        dynamic_mixin(
+            model,
+            CachedLayer,
+            init_dict={"label": "test_layer", "cache": cache},
+        )
+        input = [torch.randn(1) for _ in range(2)]
+        output = model(input)
+        self.assertEqual(output, cache["test_layer"])
+
+    def test_cached_layer_unsupported(self):
+        """Check cached layer doesn't save unsupported data type like strings"""
+        model = ConstantStrOutput()
+        cache = {}
+        dynamic_mixin(
+            model,
+            CachedLayer,
+            init_dict={"label": "test_layer", "cache": cache},
+        )
+        input = [torch.randn(1) for _ in range(2)]
+        self.assertRaises(ValueError, model, input)
+
     def test_record_layers(self):
         """Check we can record specified layer"""
         model = AddLayers()
...
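
For readers skimming the diff, the recursion in CachedLayer._clone boils down to the following standalone restatement (illustrative only; the real logic lives in the class above):

import torch
from typing import Iterable, Union

ModelOutput = Union[None, torch.Tensor, Iterable["ModelOutput"]]

def clone_nested(output: ModelOutput) -> ModelOutput:
    """Deep-clone tensors inside None/Tensor/list/tuple/dict nests."""
    if output is None:
        return None
    if isinstance(output, torch.Tensor):
        return output.clone()
    if isinstance(output, (list, tuple)):
        cloned = [clone_nested(x) for x in output]
        return tuple(cloned) if isinstance(output, tuple) else cloned
    if isinstance(output, dict):
        return {k: clone_nested(v) for k, v in output.items()}
    raise ValueError(f"Unexpected type to save: {type(output)}")

t = torch.randn(2)
snap = clone_nested({"a": (t, [t, None])})
assert isinstance(snap["a"], tuple)  # tuples stay tuples
assert torch.equal(snap["a"][0], t) and snap["a"][0] is not t  # cloned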