Commit dece58ba authored by Matthew Yu, committed by Facebook GitHub Bot

support caching tuples

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/432

We support caching of tuple outputs, since tuples behave similarly to lists; a minimal sketch follows the commit metadata below.

Reviewed By: XiaoliangDai

Differential Revision: D41483876

fbshipit-source-id: 9d741074f8e2335ddd737ae3f1bdb288910f5564
parent 150db2d1
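
To illustrate the intent of the summary, here is a minimal, self-contained Python sketch of the caching behavior. CachedWrapper and DivideBy2 are hypothetical stand-ins for d2go's CachedLayer and the test model; passing non-tensor elements through unchanged and preserving the container type are assumptions, not guarantees taken from the diff.

import torch
import torch.nn as nn


class DivideBy2(nn.Module):
    """Toy module whose output container type mirrors its input."""

    def forward(self, xs):
        return type(xs)(x / 2 for x in xs)


class CachedWrapper(nn.Module):
    """Hypothetical stand-in for CachedLayer: stores a clone of the
    wrapped module's output in `cache` under `label`."""

    def __init__(self, module: nn.Module, label: str, cache: dict):
        super().__init__()
        self.module = module
        self.label = label
        self.cache = cache

    def forward(self, *args, **kwargs):
        output = self.module(*args, **kwargs)
        if isinstance(output, torch.Tensor):
            self.cache[self.label] = output.clone()
        elif isinstance(output, (list, tuple)):
            # Tuples iterate exactly like lists, so one per-element
            # clone covers both; keep the original container type.
            self.cache[self.label] = type(output)(
                x.clone() if isinstance(x, torch.Tensor) else x
                for x in output
            )
        return output


cache = {}
layer = CachedWrapper(DivideBy2(), label="test_layer", cache=cache)
out = layer(tuple(torch.randn(1) for _ in range(2)))
assert isinstance(cache["test_layer"], tuple)

Cloning, rather than storing the live output, presumably protects the cached tensors from later in-place mutation by downstream code.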
@@ -16,7 +16,7 @@
 import logging
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Set, Union
+from typing import Callable, Dict, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -635,12 +635,13 @@ class CachedLayer(nn.Module):
         Support of the output type is limited to:
         * tensor
         * List[tensor]
+        * Tuple[tensor]
         * Dict[str, tensor]
         """
         output = super().forward(*args, **kwargs)
         if isinstance(output, torch.Tensor):
             self.cache[self.label] = output.clone()
-        elif isinstance(output, List):
+        elif isinstance(output, (List, Tuple)):
             cloned_output = []
             for x in output:
                 if isinstance(x, torch.Tensor):
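The hunk above is cut off inside the cloning loop. As a hedged sketch, the branch plausibly completes along these lines; _clone_sequence is a hypothetical helper name, and storing non-tensor elements as-is and preserving the container type are assumptions:

import torch
from typing import Union


def _clone_sequence(output: Union[list, tuple]) -> Union[list, tuple]:
    # Hypothetical completion of the loop shown in the hunk above.
    cloned_output = []
    for x in output:
        if isinstance(x, torch.Tensor):
            cloned_output.append(x.clone())
        else:
            # Assumption: non-tensor elements are stored as-is.
            cloned_output.append(x)
    # Assumption: the cached value keeps the caller's container type,
    # so a tuple output is cached as a tuple.
    return type(output)(cloned_output)


assert isinstance(_clone_sequence((torch.zeros(1), torch.ones(1))), tuple)
assert isinstance(_clone_sequence([torch.zeros(1)]), list)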
@@ -368,6 +368,19 @@ class TestDistillation(unittest.TestCase):
         output = model(input)
         self.assertEqual(output, cache["test_layer"])
 
+    def test_cached_layer_tuple(self):
+        """Check cached layer saves tuple"""
+        model = DivideInputBy2()
+        cache = {}
+        dynamic_mixin(
+            model,
+            CachedLayer,
+            init_dict={"label": "test_layer", "cache": cache},
+        )
+        input = tuple(torch.randn(1) for _ in range(2))
+        output = model(input)
+        self.assertEqual(output, cache["test_layer"])
+
     def test_cached_layer_dict(self):
         """Check cached layer saves dict"""
         model = DivideInputBy2OutputDict()