Commit dece58ba authored by Matthew Yu's avatar Matthew Yu Committed by Facebook GitHub Bot
Browse files

support caching tuples

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/432

We support caching of tuples since they behave similarly to lists

Reviewed By: XiaoliangDai

Differential Revision: D41483876

fbshipit-source-id: 9d741074f8e2335ddd737ae3f1bdb288910f5564
parent 150db2d1
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import logging import logging
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Set, Union from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -635,12 +635,13 @@ class CachedLayer(nn.Module): ...@@ -635,12 +635,13 @@ class CachedLayer(nn.Module):
Support of the output type is limited to: Support of the output type is limited to:
* tensor * tensor
* List[tensor] * List[tensor]
* Tuple[tensor]
* Dict[str, tensor] * Dict[str, tensor]
""" """
output = super().forward(*args, **kwargs) output = super().forward(*args, **kwargs)
if isinstance(output, torch.Tensor): if isinstance(output, torch.Tensor):
self.cache[self.label] = output.clone() self.cache[self.label] = output.clone()
elif isinstance(output, List): elif isinstance(output, List) or isinstance(output, Tuple):
cloned_output = [] cloned_output = []
for x in output: for x in output:
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
......
...@@ -368,6 +368,19 @@ class TestDistillation(unittest.TestCase): ...@@ -368,6 +368,19 @@ class TestDistillation(unittest.TestCase):
output = model(input) output = model(input)
self.assertEqual(output, cache["test_layer"]) self.assertEqual(output, cache["test_layer"])
def test_cached_layer_tuple(self):
"""Check cached layer saves list"""
model = DivideInputBy2()
cache = {}
dynamic_mixin(
model,
CachedLayer,
init_dict={"label": "test_layer", "cache": cache},
)
input = (torch.randn(1) for _ in range(2))
output = model(input)
self.assertEqual(output, cache["test_layer"])
def test_cached_layer_dict(self): def test_cached_layer_dict(self):
"""Check cached layer saves dict""" """Check cached layer saves dict"""
model = DivideInputBy2OutputDict() model = DivideInputBy2OutputDict()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment