Commit 150db2d1 authored by Matthew Yu, committed by Facebook GitHub Bot

algorithm

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/431

Add a generic domain adaptation algorithm (a toy sketch follows the list). This algorithm:
* gets domain0 data out of the dataloader
* runs domain0 data through the model and saves the target layer output
* gets domain1 data out of the dataloader
* runs domain1 data through the model and saves the target layer output
* runs the domain adaptation loss on the domain0 and domain1 outputs
* combines the losses using the model training iteration
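
As a toy sketch of these steps (hypothetical stand-in names, not the d2go API; the real implementation is `DomainAdaptation.forward` in the diff below, which caches an intermediate layer's output rather than the final output):

```python
import torch
from torch import nn

def domain_adaptation_step(model, batched_inputs, da_loss, combine_losses, iteration):
    # get domain0 data out of the dataloader output and run it through the model
    domain0_out = model(batched_inputs["domain0"])
    # get domain1 data out of the dataloader output and run it through the model
    domain1_out = model(batched_inputs["domain1"])
    # run the domain adaptation loss on the saved domain0/domain1 outputs
    da = da_loss(domain0_out, domain1_out)
    # combine losses using the model training iteration
    return combine_losses(domain0_out, domain1_out, da, iteration)

# usage with trivial stand-ins
model = nn.Linear(4, 4)
batch = {"domain0": torch.randn(2, 4), "domain1": torch.randn(2, 4)}
loss = domain_adaptation_step(
    model,
    batch,
    da_loss=nn.MSELoss(),
    combine_losses=lambda d0, d1, da, it: da,  # toy combiner: DA loss only
    iteration=0,
)
```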

This diff adds `get_preprocess_domain0_input` and `get_preprocess_domain1_input` to the distillation helper. These are functions the user can override to convert the dataloader output into something the model can consume (e.g., pull the domain0 or domain1 key out of a dataloader that returns a dict).
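
For example, a helper could route each domain like so (hypothetical `MyDAHelper`, mirroring `TestDAHelper` in the tests below; assumes a dataloader that yields `{"real": ..., "synthetic": ...}`):

```python
from typing import Callable

from d2go.modeling.distillation import (
    BaseDistillationHelper,
    DISTILLATION_HELPER_REGISTRY,
)

@DISTILLATION_HELPER_REGISTRY.register()
class MyDAHelper(BaseDistillationHelper):
    # domain0 = real images, domain1 = synthetic images
    def get_preprocess_domain0_input(self) -> Callable:
        return lambda batch: batch["real"]

    def get_preprocess_domain1_input(self) -> Callable:
        return lambda batch: batch["synthetic"]
```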

Differential Revision: D40970724

fbshipit-source-id: fff050fbe864654fa6cb0df927f6843855ec1c14
parent c4860c5b
@@ -209,6 +209,26 @@ class BaseDistillationHelper:
        """
        return lambda x: x

    def get_preprocess_domain0_input(self) -> Callable:
        """Return a function that allows user to modify the dataloader output
        before passing to the model

        The output of this function will be directly passed to the model.
        Example use cases include:
          * dataloader returns a dictionary of real and synthetic images. use
            this function to return only the real data (domain0) to the model
        """
        return lambda x: x

    def get_preprocess_domain1_input(self) -> Callable:
        """Same as get_preprocess_domain0_input but returns domain1 inputs

        Example:
          * dataloader returns a dictionary of real and synthetic images. use
            this function to return only synthetic data (domain1) to the model
        """
        return lambda x: x


@DISTILLATION_HELPER_REGISTRY.register()
class ExampleDistillationHelper(BaseDistillationHelper):
@@ -364,6 +384,72 @@ class KnowledgeDistillation(BaseDistillationAlgorithm):
        return losses

@DISTILLATION_ALGORITHM_REGISTRY.register()
class DomainAdaptation(BaseDistillationAlgorithm):
    """Domain adaptation applies loss over the inputs of domain0 and domain1"""

    def dynamic_mixin_init(self, distillation_helper: BaseDistillationHelper):
        super().dynamic_mixin_init(distillation_helper)
        self._domain0_preprocess_input = (
            self.distillation_helper.get_preprocess_domain0_input()
        )
        self._domain1_preprocess_input = (
            self.distillation_helper.get_preprocess_domain1_input()
        )
        ll = self.distillation_helper.get_layer_losses(self)
        self._layer_losses = register_layer_losses_and_to_device(ll, self)

        # we ignore the cache dict returned by record_layers as we need to
        # manually set the dict at every iteration in the forward
        self._domain0_cache = {}
        self._domain1_cache = {}

        # since domain adaptation uses the same model in both domains, we
        # only need to add CachedLayers once
        record_layers(self, [ll.layer0 for ll in self._layer_losses])
        self._combine_losses = self.distillation_helper.get_combine_losses()

    def remove_dynamic_mixin(self):
        super().remove_dynamic_mixin()
        unrecord_layers(self, [ll.layer0 for ll in self._layer_losses])
        del self._layer_losses
        del self._domain0_cache
        del self._domain1_cache
        del self._domain0_preprocess_input
        del self._domain1_preprocess_input
        del self._combine_losses

    def forward(self, batched_inputs: List):
        """Run domain0 input, domain1 input and compute losses"""
        domain0_input = self._domain0_preprocess_input(batched_inputs)
        if not self.training:
            return super().forward(domain0_input)

        # run domain0
        set_cache_dict(self, self._domain0_cache)
        domain0_losses = super().forward(domain0_input)

        # run domain1
        domain1_input = self._domain1_preprocess_input(batched_inputs)
        set_cache_dict(self, self._domain1_cache)
        domain1_losses = super().forward(domain1_input)

        # calculate losses
        domain_adaptation_losses = compute_layer_losses(
            self._layer_losses, self._domain0_cache, self._domain1_cache
        )

        # combine losses
        # note we currently assume that the loss combiner uses training iteration
        losses = self._combine_losses(
            domain0_losses,
            domain1_losses,
            domain_adaptation_losses,
            getattr(self, "_training_iteration", -1),
        )
        return losses

@MODELING_HOOK_REGISTRY.register()
class DistillationModelingHook(mh.ModelingHook):
    """Wrapper hook that allows us to apply different distillation algorithms
...
@@ -19,6 +19,7 @@ from d2go.modeling.distillation import (
    compute_layer_losses,
    DefaultLossCombiner,
    DistillationModelingHook,
    DomainAdaptation,
    ExampleDistillationHelper,
    get_default_kd_image_classification_layer_losses,
    KnowledgeDistillation,
@@ -167,6 +168,31 @@ class TestHelper(BaseDistillationHelper):
        }

class TestDAHelper(BaseDistillationHelper):
    def get_preprocess_domain0_input(self):
        return lambda x: x["real"]

    def get_preprocess_domain1_input(self):
        return lambda x: x["synthetic"]

    def get_layer_losses(self, model=None):
        return [
            LayerLossMetadata(
                loss=SimpleAdd(),
                name="add",
                layer0="layer0",
                layer1="layer0",
            )
        ]

    def get_combine_losses(self):
        return lambda d0, d1, da, ta: {
            "real": d0["output"] * 0.1,
            "synthetic": d1["output"] * 0.5,
            "add": da["add"] * 10.0,
        }


class Noop(nn.Module):
    def forward(self, x):
        return x
@@ -485,7 +511,11 @@ class TestDistillationAlgorithm(unittest.TestCase):

    def test_registry(self):
        """Check distillation teacher in registry"""
        for algorithm in [
            "LabelDistillation",
            "KnowledgeDistillation",
            "DomainAdaptation",
        ]:
            self.assertTrue(algorithm in DISTILLATION_ALGORITHM_REGISTRY)

    def test_label_distillation_inference(self):
@@ -565,6 +595,54 @@ class TestDistillationAlgorithm(unittest.TestCase):
        for module in model.modules():
            self.assertFalse(hasattr(module, "cache"))

    def test_da_inference(self):
        """Check inference defaults to student (and preprocessing)"""
        distillation_helper = TestDAHelper(cfg=CfgNode(), teacher=nn.Identity())
        model = AddLayers()
        dynamic_mixin(
            model,
            DomainAdaptation,
            init_dict={"distillation_helper": distillation_helper},
        )
        model.eval()
        input = {"real": torch.randn(1), "synthetic": torch.randn(1)}
        output = model(input)
        self.assertEqual(output, input["real"] + 3.0)

    def test_da_train(self):
        """Check train pass results in updated loss output"""
        distillation_helper = TestDAHelper(cfg=CfgNode(), teacher=nn.Identity())
        model = AddLayers()
        dynamic_mixin(
            model,
            DomainAdaptation,
            init_dict={"distillation_helper": distillation_helper},
        )
        model.train()
        input = {"real": torch.randn(1), "synthetic": torch.randn(1)}
        output = model(input)
        self.assertEqual(
            output,
            {
                "real": (input["real"] + 3.0) * 0.1,
                "synthetic": (input["synthetic"] + 3.0) * 0.5,
                "add": ((input["real"] + 1.0) + (input["synthetic"] + 1.0)) * 10.0,
            },
        )

    def test_da_remove_dynamic_mixin(self):
        """Check removing dynamic mixin removes cached layers"""
        distillation_helper = TestHelper(cfg=CfgNode(), teacher=nn.Identity())
        model = AddLayers()
        dynamic_mixin(
            model,
            DomainAdaptation,
            init_dict={"distillation_helper": distillation_helper},
        )
        remove_dynamic_mixin(model)
        for module in model.modules():
            self.assertFalse(hasattr(module, "cache"))

class TestDistillationModelingHook(unittest.TestCase):
    _build_teacher_ref = "d2go.modeling.distillation._build_teacher"
...