Commit 53c4c2c1 authored by Matthew Yu's avatar Matthew Yu Committed by Facebook GitHub Bot
Browse files

add a helper to record layers in a model

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/418

This diff adds a function that can be used to add `CachedLayers` to a model. The function iterates over named modules and dynamically mixes `CachedLayer` into the target modules.

This diff adds a function to remove the cached layers.

Reviewed By: Minione

Differential Revision: D40285806

fbshipit-source-id: 3137d19927d8fb9ec924a77c9085aea29fe94d5e
parent 120b463c
......@@ -14,7 +14,8 @@
# distillation algorithms in configs: DISTILLATION_ALGORITHM, DISTILLATION_HELPER
from abc import abstractmethod
from typing import Dict, List, Union
from typing import Dict, List, Set, Union
import torch
import torch.nn as nn
......@@ -435,3 +436,27 @@ class CachedLayer(nn.Module):
else:
raise ValueError(f"Unexpected type to save: {type(output)}")
return output
def record_layers(model: nn.Module, layer_names: Set[str]) -> Dict[str, torch.Tensor]:
    """Instrument ``model`` so the outputs of ``layer_names`` are cached.

    Walks every named submodule of ``model`` and, for each one whose name is
    in ``layer_names``, dynamically mixes in ``CachedLayer`` so its forward
    output is recorded under that name.

    Args:
        model: model whose intermediate outputs should be captured
        layer_names: fully-qualified module names to instrument (the empty
            string "" refers to the model itself, per ``named_modules``)

    Returns:
        The shared dict that the instrumented layers will write their
        outputs into on each forward pass.
    """
    cache: Dict[str, torch.Tensor] = {}
    for layer_name, layer in model.named_modules():
        if layer_name not in layer_names:
            continue
        dynamic_mixin(
            layer,
            CachedLayer,
            init_dict={"label": layer_name, "cache": cache},
        )
    return cache
def unrecord_layers(model: nn.Module, layer_names: Set[str]) -> None:
    """Undo ``record_layers``: strip the ``CachedLayer`` mixin.

    Args:
        model: model previously instrumented via ``record_layers``
        layer_names: names of the submodules whose mixin should be removed;
            names not present in the model are silently ignored
    """
    for layer_name, layer in model.named_modules():
        if layer_name in layer_names:
            remove_dynamic_mixin(layer)
......@@ -21,7 +21,9 @@ from d2go.modeling.distillation import (
LabelDistillation,
NoopPseudoLabeler,
PseudoLabeler,
record_layers,
RelabelTargetInBatch,
unrecord_layers,
)
from d2go.registry.builtin import (
DISTILLATION_ALGORITHM_REGISTRY,
......@@ -33,7 +35,7 @@ from d2go.utils.testing import helper
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.utils.file_io import PathManager
from mobile_cv.common.misc.file_utils import make_temp_directory
from mobile_cv.common.misc.mixin import dynamic_mixin
from mobile_cv.common.misc.mixin import dynamic_mixin, remove_dynamic_mixin
class DivideInputBy2(nn.Module):
......@@ -74,6 +76,20 @@ class AddOne(nn.Module):
return self.weight.device
class AddLayers(nn.Module):
    """Toy model chaining three ``AddOne`` layers (output = input + 3).

    The layers are kept as distinct attributes (``layer0``..``layer2``) so
    tests can address them by name via ``named_modules``.
    """

    def __init__(self):
        super().__init__()
        self.layer0 = AddOne()
        self.layer1 = AddOne()
        self.layer2 = AddOne()

    def forward(self, x):
        out = x
        # Apply the layers in declaration order.
        for layer in (self.layer0, self.layer1, self.layer2):
            out = layer(out)
        return out
class TestLabeler(PseudoLabeler):
def __init__(self, teacher):
self.teacher = teacher
......@@ -278,6 +294,25 @@ class TestDistillation(unittest.TestCase):
output = model(input)
self.assertEqual(output, cache["test_layer"])
def test_record_layers(self):
    """Recorded layers (including the root "") capture their outputs."""
    model = AddLayers()
    cache = record_layers(model, ["", "layer0", "layer1", "layer2"])

    output = model(torch.Tensor([0]))

    # Each AddOne layer increments by one, so the cached values step 1, 2, 3.
    for i in range(3):
        self.assertEqual(cache[f"layer{i}"], torch.Tensor([i + 1]))
    # The root module "" caches the model's final output.
    self.assertEqual(cache[""], output)
def test_unrecord_layers(self):
    """Unrecording a layer removes the mixin state (its cache attr)."""
    model = AddLayers()
    _ = record_layers(model, ["", "layer0", "layer1", "layer2"])

    unrecord_layers(model, ["", "layer0"])

    # After removal the layer should no longer carry the CachedLayer state.
    self.assertFalse(hasattr(model.layer0, "cache"))
class TestPseudoLabeler(unittest.TestCase):
def test_noop(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment