"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f1d052c5b8a4401e0e60352ddfeaafbb203e5bbf"
Commit 53c4c2c1 authored by Matthew Yu's avatar Matthew Yu Committed by Facebook GitHub Bot
Browse files

add a helper to record layers in a model

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/418

This diff adds a function that can be used to add `CachedLayer`s to a model. The function iterates over named modules and dynamically mixes `CachedLayer` into the target modules.

This diff adds a function to remove the cached layers.

Reviewed By: Minione

Differential Revision: D40285806

fbshipit-source-id: 3137d19927d8fb9ec924a77c9085aea29fe94d5e
parent 120b463c
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
# distillation algorithms in configs: DISILLATION_ALAGORITHM, DISTILLATION_HELPER # distillation algorithms in configs: DISILLATION_ALAGORITHM, DISTILLATION_HELPER
from abc import abstractmethod from abc import abstractmethod
from typing import Dict, List, Union
from typing import Dict, List, Set, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -435,3 +436,27 @@ class CachedLayer(nn.Module): ...@@ -435,3 +436,27 @@ class CachedLayer(nn.Module):
else: else:
raise ValueError(f"Unexpected type to save: {type(output)}") raise ValueError(f"Unexpected type to save: {type(output)}")
return output return output
def record_layers(model: nn.Module, layer_names: Set[str]) -> Dict[str, torch.Tensor]:
    """Attach output-recording behavior to the named layers of *model*.

    Walks every named submodule of ``model`` and, for each one whose name
    appears in ``layer_names``, dynamically mixes :class:`CachedLayer` into
    the module so its forward outputs are written into a shared dict.

    Args:
        model: module whose submodules should be instrumented
        layer_names: qualified submodule names (as produced by
            ``nn.Module.named_modules``) to record; ``""`` selects the
            root module itself

    Returns:
        The shared cache dict; after a forward pass it maps each recorded
        layer name to that layer's output tensor.
    """
    shared_cache: Dict[str, torch.Tensor] = {}
    targets = {
        name: module
        for name, module in model.named_modules()
        if name in layer_names
    }
    for layer_name, submodule in targets.items():
        dynamic_mixin(
            submodule,
            CachedLayer,
            init_dict={"label": layer_name, "cache": shared_cache},
        )
    return shared_cache
def unrecord_layers(model: nn.Module, layer_names: Set[str]) -> None:
    """Strip the ``CachedLayer`` mixin from the named layers of *model*.

    Inverse of ``record_layers``: any submodule whose qualified name is in
    ``layer_names`` has its dynamic mixin removed, restoring the original
    class and dropping the cache attributes. Names not present in the model
    are silently ignored.
    """
    for qualified_name, submodule in model.named_modules():
        if qualified_name not in layer_names:
            continue
        remove_dynamic_mixin(submodule)
...@@ -21,7 +21,9 @@ from d2go.modeling.distillation import ( ...@@ -21,7 +21,9 @@ from d2go.modeling.distillation import (
LabelDistillation, LabelDistillation,
NoopPseudoLabeler, NoopPseudoLabeler,
PseudoLabeler, PseudoLabeler,
record_layers,
RelabelTargetInBatch, RelabelTargetInBatch,
unrecord_layers,
) )
from d2go.registry.builtin import ( from d2go.registry.builtin import (
DISTILLATION_ALGORITHM_REGISTRY, DISTILLATION_ALGORITHM_REGISTRY,
...@@ -33,7 +35,7 @@ from d2go.utils.testing import helper ...@@ -33,7 +35,7 @@ from d2go.utils.testing import helper
from detectron2.checkpoint import DetectionCheckpointer from detectron2.checkpoint import DetectionCheckpointer
from detectron2.utils.file_io import PathManager from detectron2.utils.file_io import PathManager
from mobile_cv.common.misc.file_utils import make_temp_directory from mobile_cv.common.misc.file_utils import make_temp_directory
from mobile_cv.common.misc.mixin import dynamic_mixin from mobile_cv.common.misc.mixin import dynamic_mixin, remove_dynamic_mixin
class DivideInputBy2(nn.Module): class DivideInputBy2(nn.Module):
...@@ -74,6 +76,20 @@ class AddOne(nn.Module): ...@@ -74,6 +76,20 @@ class AddOne(nn.Module):
return self.weight.device return self.weight.device
class AddLayers(nn.Module):
    """Toy model for layer-recording tests: three sequential AddOne layers.

    Given input x, produces x + 3, with each named layer contributing +1 —
    so intermediate outputs are easy to assert on per layer.
    """

    def __init__(self):
        super().__init__()
        self.layer0 = AddOne()
        self.layer1 = AddOne()
        self.layer2 = AddOne()

    def forward(self, x):
        # Apply the three increment layers in order.
        out = x
        for layer in (self.layer0, self.layer1, self.layer2):
            out = layer(out)
        return out
class TestLabeler(PseudoLabeler): class TestLabeler(PseudoLabeler):
def __init__(self, teacher): def __init__(self, teacher):
self.teacher = teacher self.teacher = teacher
...@@ -278,6 +294,25 @@ class TestDistillation(unittest.TestCase): ...@@ -278,6 +294,25 @@ class TestDistillation(unittest.TestCase):
output = model(input) output = model(input)
self.assertEqual(output, cache["test_layer"]) self.assertEqual(output, cache["test_layer"])
def test_record_layers(self):
    """Recording the root and all sublayers captures each layer's output."""
    net = AddLayers()
    recorded = record_layers(net, ["", "layer0", "layer1", "layer2"])

    x = torch.Tensor([0])
    y = net(x)

    # Each AddOne layer increments by one, so intermediates are 1, 2, 3.
    for layer_name, expected in [
        ("layer0", torch.Tensor([1])),
        ("layer1", torch.Tensor([2])),
        ("layer2", torch.Tensor([3])),
    ]:
        self.assertEqual(recorded[layer_name], expected)
    # The root module ("") records the final model output.
    self.assertEqual(recorded[""], y)
def test_unrecord_layers(self):
    """Removing recorded layers strips the cache attribute from the module."""
    net = AddLayers()
    _ = record_layers(net, ["", "layer0", "layer1", "layer2"])

    # Un-record a subset; the mixin (and its "cache" attribute) must be gone.
    unrecord_layers(net, ["", "layer0"])
    self.assertFalse(hasattr(net.layer0, "cache"))
class TestPseudoLabeler(unittest.TestCase): class TestPseudoLabeler(unittest.TestCase):
def test_noop(self): def test_noop(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment