Commit c4860c5b authored by Matthew Yu, committed by Facebook GitHub Bot

support registering layer losses to model

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/430

We add losses in distillation by instantiating them in the distillation algorithm's `__init__` and then running them during the forward pass.

However, this has two issues:
* the losses are not registered as modules on the model, since we organize them as a plain list of `LayerLossMetadata`; this means features like AMP do not behave as expected (sketched below)
* the losses may not be on the same device as the rest of the model, since they can be created after the model has been moved to a new device
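
To make the first issue concrete, here is a minimal sketch (toy classes, not the d2go code) of how modules held in a plain Python list are invisible to `nn.Module` bookkeeping, so `model.to(...)` and AMP wrappers never see them:

import torch.nn as nn

class LearnedLoss(nn.Module):
    # Toy loss with parameters, so device placement matters.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(2, 2)

    def forward(self, x, y):
        return (self.proj(x) - y).pow(2).mean()

class Distiller(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer0 = nn.Linear(2, 2)
        self.losses = [LearnedLoss()]  # plain list: not tracked by nn.Module

model = Distiller()
# The loss is invisible to module traversal, so .to()/AMP skip it:
assert not any(isinstance(m, LearnedLoss) for m in model.modules())
# After add_module it is tracked and moves with the model:
model.add_module("loss0", model.losses[0])
assert any(isinstance(m, LearnedLoss) for m in model.modules())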

This diff fixes both issues with a helper function: `register_layer_losses_and_to_device` takes a `List[LayerLossMetadata]`, moves each loss to the same device as the model, and registers the losses as submodules of the model.
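
For illustration, a usage sketch (the loss class and layer names here are placeholders, and the helper assumes the model exposes a `device` attribute, as the test model below does):

ll = [
    LayerLossMetadata(
        loss=MyDistillLoss(),  # placeholder nn.Module
        name="kd_loss",
        layer0="layer1",
        layer1="layer1",
    ),
]
ll = register_layer_losses_and_to_device(ll, model)
# The loss is now a registered submodule, living on model.device:
assert model.kd_loss is ll[0].loss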

Differential Revision: D41296932

fbshipit-source-id: ae7ae0847bce1b5cc481d838b9cae69cea424f25
parent 909de50d
@@ -323,7 +323,8 @@ class KnowledgeDistillation(BaseDistillationAlgorithm):
         self._teacher_preprocess_input = (
             self.distillation_helper.get_preprocess_teacher_input()
         )
-        self._layer_losses = self.distillation_helper.get_layer_losses(self)
+        ll = self.distillation_helper.get_layer_losses(self)
+        self._layer_losses = register_layer_losses_and_to_device(ll, self)
         self._student_cache = record_layers(
             self, [ll.layer0 for ll in self._layer_losses]
         )
@@ -676,3 +677,22 @@ class DefaultLossCombiner:
                 raise ValueError(f"Unexpected weight in loss dict: {k}")
             output[k] = v * self.name_weight[k]
         return output
+
+
+def register_layer_losses_and_to_device(
+    layer_losses: List[LayerLossMetadata], model: nn.Module
+) -> List[LayerLossMetadata]:
+    """Register loss modules in LayerLossMetadata to the model and move them to its device"""
+    registered_losses = []
+    for ll in layer_losses:
+        loss_on_device = ll.loss.to(model.device)
+        model.add_module(ll.name, loss_on_device)
+        registered_losses.append(
+            LayerLossMetadata(
+                loss_on_device,
+                ll.name,
+                ll.layer0,
+                ll.layer1,
+            )
+        )
+    return registered_losses
@@ -27,6 +27,7 @@ from d2go.modeling.distillation import (
     NoopPseudoLabeler,
     PseudoLabeler,
     record_layers,
+    register_layer_losses_and_to_device,
     RelabelTargetInBatch,
     set_cache_dict,
     unrecord_layers,
@@ -97,6 +98,20 @@ class AddLayers(nn.Module):
             return x
         return {"output": x}
 
+    @property
+    def device(self):
+        return self.layer0.weight.device
+
+
+class SimpleAdd(nn.Module):
+    def forward(self, x, y):
+        return x + y
+
+
+class SimpleMul(nn.Module):
+    def forward(self, x, y):
+        return x * y
+
 
 class TestLabeler(PseudoLabeler):
     def __init__(self, teacher):
@@ -131,13 +146,13 @@ class TestHelper(BaseDistillationHelper):
     def get_layer_losses(self, model=None):
         return [
             LayerLossMetadata(
-                loss=lambda x, y: x + y,
+                loss=SimpleAdd(),
                 name="add",
                 layer0="layer0",
                 layer1="layer0",
             ),
             LayerLossMetadata(
-                loss=lambda x, y: x * y,
+                loss=SimpleMul(),
                 name="mul",
                 layer0="layer1",
                 layer1="layer1",
@@ -389,6 +404,37 @@ class TestDistillation(unittest.TestCase):
         torch.testing.assert_close(new_cache["layer2"], torch.Tensor([3]))
         torch.testing.assert_close(new_cache[""], output)
 
+    def test_register_layer_losses(self):
+        """Check losses can be registered to model"""
+        model = AddOne()
+        ll = [
+            LayerLossMetadata(
+                loss=SimpleAdd(),
+                name="mul",
+                layer0="layer1",
+                layer1="layer1",
+            ),
+        ]
+        registered_losses = register_layer_losses_and_to_device(ll, model)
+        self.assertTrue(hasattr(model, "mul"))
+        self.assertEqual(model.mul, registered_losses[0].loss)
+
+    @helper.skip_if_no_gpu
+    def test_register_layer_losses_and_to_device(self):
+        """Check losses can be registered to model and moved to its device"""
+        model = AddOne()
+        model = model.to("cuda")
+        ll = [
+            LayerLossMetadata(
+                loss=AddOne(),
+                name="mul",
+                layer0="layer1",
+                layer1="layer1",
+            ),
+        ]
+        register_layer_losses_and_to_device(ll, model)
+        self.assertEqual(model.mul.device, model.device)
+
 
 class TestPseudoLabeler(unittest.TestCase):
     def test_noop(self):
...