"vscode:/vscode.git/clone" did not exist on "39a3c77e0d4a22de189b02398cf2d003d299b4ae"
Commit 30ac5858 authored by Matthew Yu, committed by Facebook GitHub Bot

set cache in recorded layers

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/433

Distillation uses a module called `CachedLayer` to record the outputs of a layer into an internal dict. This dict is typically initialized by the object itself, and its values are overwritten every time the model runs.

However, sometimes we need outputs from more than one run of the layer (e.g., domain adaptation: we run the model on real data, then on synthetic data, and need both sets of outputs).

This diff adds a helper to externally set the cache dict of a model. In other words, we can call `set_cache_dict` on a model to change the dict used by every `CachedLayer` in it. This lets us run the model and record some outputs, then swap the cache dict and rerun the model to save a different set of outputs.
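A minimal sketch of that flow, assuming `record_layers` wraps the named submodules in `CachedLayer` (as the test below exercises). `Toy` and the random inputs are invented for illustration; only `record_layers` and `set_cache_dict` are the real d2go helpers:

```python
import torch
import torch.nn as nn
from d2go.modeling.distillation import record_layers, set_cache_dict

# Toy model purely for illustration; any nn.Module works.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer0 = nn.Linear(4, 4)

    def forward(self, x):
        return self.layer0(x)

model = Toy()

# Wrap "layer0" for recording; its outputs land in the returned dict.
real_cache = record_layers(model, {"layer0"})
model(torch.randn(1, 4))  # pass 1: e.g. real data

# Swap in a fresh dict so the second pass does not overwrite the first.
synthetic_cache = {}
set_cache_dict(model, synthetic_cache)
model(torch.randn(1, 4))  # pass 2: e.g. synthetic data

# real_cache["layer0"] and synthetic_cache["layer0"] now hold outputs
# from the two separate runs.
```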

Differential Revision: D40970577

fbshipit-source-id: 49cb851af49ae193d0c8ac9218e02fdaf4e6587b
parent 19c5392d
@@ -566,6 +566,13 @@ class CachedLayer(nn.Module):
         return output
 
 
+def set_cache_dict(model: nn.Module, cache: Dict) -> None:
+    """Sets the cache in all CachedLayers to input cache"""
+    for module in model.modules():
+        if isinstance(module, CachedLayer):
+            module.cache = cache
+
+
 def record_layers(model: nn.Module, layer_names: Set[str]) -> Dict[str, torch.Tensor]:
     """Save the outputs of layer_names in model
...
@@ -28,6 +28,7 @@ from d2go.modeling.distillation import (
     PseudoLabeler,
     record_layers,
     RelabelTargetInBatch,
+    set_cache_dict,
     unrecord_layers,
 )
 from d2go.registry.builtin import (
@@ -365,6 +366,20 @@ class TestDistillation(unittest.TestCase):
         self.assertEqual(output["add"], layer0_cache["l00"] + layer1_cache["l10"])
         self.assertEqual(output["div"], layer0_cache["l01"] / layer1_cache["l11"])
 
+    def test_set_cache_dict(self):
+        """Check we can swap the cache dict used when recording layers"""
+        model = AddLayers()
+        cache = record_layers(model, ["", "layer0", "layer1", "layer2"])
+        new_cache = {}
+        set_cache_dict(model, new_cache)
+        input = torch.Tensor([0])
+        output = model(input)
+        self.assertEqual(cache, {})
+        torch.testing.assert_close(new_cache["layer0"], torch.Tensor([1]))
+        torch.testing.assert_close(new_cache["layer1"], torch.Tensor([2]))
+        torch.testing.assert_close(new_cache["layer2"], torch.Tensor([3]))
+        torch.testing.assert_close(new_cache[""], output)
+
 class TestPseudoLabeler(unittest.TestCase):
     def test_noop(self):
...