Commit 9ec4f2bf authored by Matthew Yu, committed by Facebook GitHub Bot

kd algorithm

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/420

Adds knowledge distillation as a generic algorithm that can be used by various projects.

During evaluation, the algorithm simply returns the result of the student model.

During training, the algorithm feeds the input into both the student and teacher models. The user provides a list of `LayerLossMetadata` entries that specify which layers to match and which losses to run on them. The algorithm uses dynamic mixin to record the outputs of the relevant layers and computes the losses after both models have run.
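A minimal sketch of such a list, assuming the student and teacher both expose a layer named `backbone` (the layer names and the MSE loss are illustrative, not part of this diff):

```python
import torch.nn as nn

from d2go.modeling.distillation import LayerLossMetadata

# Hypothetical pairing: match the student's "backbone" output to the
# teacher's "backbone" output with an MSE loss. layer0 names the student
# layer, layer1 the teacher layer.
layer_losses = [
    LayerLossMetadata(
        loss=nn.MSELoss(),
        name="kd_backbone_mse",
        layer0="backbone",
        layer1="backbone",
    ),
]
```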

We provide student and teacher preprocessing hooks as a placeholder until we support a more generic dataloader that can provide different inputs to the student and teacher (e.g., for now, if you want to give the teacher a larger input, the dataloader should return the large input and the student preprocessing can downsample it).
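As a sketch of that interim workflow (assuming the dataloader returns a plain NCHW tensor, which this diff does not require):

```python
import torch.nn.functional as F


def preprocess_teacher_input(batched_inputs):
    # The teacher consumes the full-resolution batch unchanged.
    return batched_inputs


def preprocess_student_input(batched_inputs):
    # Downsample the same batch for the student; the tensor layout and the
    # 0.5 scale factor are illustrative assumptions.
    return F.interpolate(batched_inputs, scale_factor=0.5, mode="bilinear")
```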

We add the following functions as part of the user-customizable distillation helper (a sketch of a custom helper follows the list):
* get_teacher => return a teacher that can be used directly by the KD algorithm
* get_layer_losses => return a list of `LayerLossMetadata` that provides the layers and losses
* get_preprocess_student_input => manipulate the output of the dataloader before passing to the student
* get_preprocess_teacher_input => manipulate the output of the dataloader before passing to the teacher
* get_combine_losses => since we may want to weight the student and distillation losses differently, return a function that can manipulate the `loss_dict`
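For example, a project could register its own helper like the sketch below (assuming `BaseDistillationHelper`, `DISTILLATION_HELPER_REGISTRY`, and `LayerLossMetadata` are importable from `d2go.modeling.distillation`, as in the test changes; the class name, layer names, and loss weights are hypothetical):

```python
import torch.nn as nn

from d2go.modeling.distillation import (
    BaseDistillationHelper,
    DISTILLATION_HELPER_REGISTRY,
    LayerLossMetadata,
)


@DISTILLATION_HELPER_REGISTRY.register()
class MyKDHelper(BaseDistillationHelper):
    """Hypothetical project helper."""

    def get_layer_losses(self, model=None):
        # layer0 is the student layer, layer1 the teacher layer.
        return [
            LayerLossMetadata(
                loss=nn.MSELoss(),
                name="kd_feat",
                layer0="backbone",
                layer1="backbone",
            )
        ]

    def get_combine_losses(self):
        def combine(loss_dict):
            # Down-weight the distillation term relative to the student losses.
            loss_dict["kd_feat"] = loss_dict["kd_feat"] * 0.5
            return loss_dict

        return combine
```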

Reviewed By: chihyaoma

Differential Revision: D40326412

fbshipit-source-id: 2fb0e818a7d5b120d62fb7aba314ff96cc7e10c5
parent e42112d6
@@ -15,7 +15,7 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Set, Union

import torch
import torch.nn as nn
@@ -150,6 +150,61 @@ class BaseDistillationHelper:
        """
        return NoopPseudoLabeler()

    def get_teacher(self) -> nn.Module:
        """Return a teacher that can be run by the algorithm"""
        return self.teacher

    def get_layer_losses(
        self, model: Optional[nn.Module] = None
    ) -> List[LayerLossMetadata]:
        """Return losses that are run on layers

        Layer parameters may depend on model parameters, so there is an option
        to pass in a model.
        """
        return []

    def get_preprocess_student_input(self) -> Callable:
        """Return a function that allows the user to modify the dataloader
        output before passing it to the student

        The output of this function will be passed directly to the student
        model. Example use cases include:
          * the dataloader returns a large image used by the teacher model but
            the student model needs a lower resolution version
          * the dataloader returns both labeled and unlabeled data and the
            student requires only the labeled data
        """
        return lambda x: x

    def get_preprocess_teacher_input(self) -> Callable:
        """Return a function that allows the user to modify the dataloader
        output before passing it to the teacher

        The output of this function will be passed directly to the teacher
        model.
        """
        return lambda x: x

    def get_combine_losses(self) -> Callable:
        """Return a function that takes a dictionary of losses as input and
        modifies the losses as required

        The default trainer sums the losses at the end, so typically this
        function is used to change the relative contribution of the losses.

        Example:
            def combine_losses(losses):
                alpha = 0.1
                losses["nll"] *= alpha
                losses["kd_loss"] *= 1 - alpha
                return losses

            student_losses = {"nll": ...}
            student_losses.update({"kd_loss": ...})
            losses = combine_losses(student_losses)
        """
        return lambda x: x


@DISTILLATION_HELPER_REGISTRY.register()
class ExampleDistillationHelper(BaseDistillationHelper):
@@ -244,6 +299,66 @@ class LabelDistillation(BaseDistillationAlgorithm):
        return super().forward(new_batched_inputs)


@DISTILLATION_ALGORITHM_REGISTRY.register()
class KnowledgeDistillation(BaseDistillationAlgorithm):
    """Knowledge distillation applies a loss over the outputs of the student
    and teacher models
    """

    def dynamic_mixin_init(self, distillation_helper: BaseDistillationHelper):
        """Note all variables use a leading underscore to avoid name conflicts
        with existing variable names in the model

        Consider adding a check to avoid variable name reuse
        """
        super().dynamic_mixin_init(distillation_helper)
        self._teacher = WrappedTeacher(self.distillation_helper.get_teacher())
        self._student_preprocess_input = (
            self.distillation_helper.get_preprocess_student_input()
        )
        self._teacher_preprocess_input = (
            self.distillation_helper.get_preprocess_teacher_input()
        )
        self._layer_losses = self.distillation_helper.get_layer_losses(self)
        self._student_cache = record_layers(
            self, [ll.layer0 for ll in self._layer_losses]
        )
        self._teacher_cache = record_layers(
            self._teacher.model, [ll.layer1 for ll in self._layer_losses]
        )
        self._combine_losses = self.distillation_helper.get_combine_losses()

    def remove_dynamic_mixin(self):
        super().remove_dynamic_mixin()
        unrecord_layers(self, [ll.layer0 for ll in self._layer_losses])
        unrecord_layers(self._teacher.model, [ll.layer1 for ll in self._layer_losses])
        del self._teacher
        del self._layer_losses
        del self._student_cache
        del self._teacher_cache
        del self._student_preprocess_input
        del self._teacher_preprocess_input
        del self._combine_losses

    def forward(self, batched_inputs: List):
        """Run the teacher, then the student, and compute the losses"""
        student_input = self._student_preprocess_input(batched_inputs)
        if not self.training:
            return super().forward(student_input)

        teacher_input = self._teacher_preprocess_input(batched_inputs)
        with torch.no_grad():
            self._teacher(teacher_input)

        student_losses = super().forward(student_input)
        distillation_losses = compute_layer_losses(
            self._layer_losses, self._student_cache, self._teacher_cache
        )
        student_losses.update(distillation_losses)
        losses = self._combine_losses(student_losses)
        return losses


@MODELING_HOOK_REGISTRY.register()
class DistillationModelingHook(mh.ModelingHook):
    """Wrapper hook that allows us to apply different distillation algorithms
@@ -489,3 +604,17 @@ def compute_layer_losses(
        losses[ll.name] = ll.loss(layer0_cache[ll.layer0], layer1_cache[ll.layer1])
    return losses


class WrappedTeacher:
    """Used to remove the teacher model from the student module list

    See: DistillationMiscTests.test_teacher_outside_updated_parameters to get
    more details on avoiding adding the teacher as a module
    """

    def __init__(self, model: nn.Module):
        self.model = model

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)
@@ -19,6 +19,7 @@ from d2go.modeling.distillation import (
    compute_layer_losses,
    DistillationModelingHook,
    ExampleDistillationHelper,
    KnowledgeDistillation,
    LabelDistillation,
    LayerLossMetadata,
    NoopPseudoLabeler,
@@ -89,7 +90,9 @@ class AddLayers(nn.Module):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        if not self.training:
            return x
        return {"output": x}


class TestLabeler(PseudoLabeler):
@@ -116,6 +119,35 @@ class TestHelper(BaseDistillationHelper):
        """Run teacher model on inputs"""
        return TestLabeler(self.teacher)

    def get_preprocess_student_input(self):
        return lambda x: x + 1

    def get_preprocess_teacher_input(self):
        return lambda x: x + 2

    def get_layer_losses(self, model=None):
        return [
            LayerLossMetadata(
                loss=lambda x, y: x + y,
                name="add",
                layer0="layer0",
                layer1="layer0",
            ),
            LayerLossMetadata(
                loss=lambda x, y: x * y,
                name="mul",
                layer0="layer1",
                layer1="layer1",
            ),
        ]

    def get_combine_losses(self):
        return lambda d: {
            "output": d["output"] * 0.1,
            "add": d["add"] * 0.5,
            "mul": d["mul"] * 10.0,
        }


class Noop(nn.Module):
    def forward(self, x):
@@ -363,8 +395,8 @@ class TestDistillationHelper(unittest.TestCase):
        pseudo_labeler = dh.get_pseudo_labeler()
        self.assertTrue(isinstance(pseudo_labeler, NoopPseudoLabeler))

    def test_example_distillation_helper(self):
        """Example distillation uses teacher to relabel targets"""
        teacher = Noop()
        dh = ExampleDistillationHelper(cfg=None, teacher=teacher)
        pseudo_labeler = dh.get_pseudo_labeler()
@@ -381,7 +413,8 @@ class TestDistillationAlgorithm(unittest.TestCase):
    def test_registry(self):
        """Check distillation teacher in registry"""
        for algorithm in ["LabelDistillation", "KnowledgeDistillation"]:
            self.assertTrue(algorithm in DISTILLATION_ALGORITHM_REGISTRY)

    def test_label_distillation_inference(self):
        """Check inference defaults to student
@@ -417,6 +450,49 @@ class TestDistillationAlgorithm(unittest.TestCase):
        sum(output).backward()
        torch.testing.assert_close(batched_inputs.grad, torch.Tensor([0.5, 0.5]))

    def test_kd_inference(self):
        """Check inference defaults to student (and preprocessing)"""
        distillation_helper = TestHelper(cfg=CfgNode(), teacher=AddLayers())
        model = AddLayers()
        dynamic_mixin(
            model,
            KnowledgeDistillation,
            init_dict={"distillation_helper": distillation_helper},
        )
        model.eval()
        input = torch.randn(1)
        output = model(input)
        torch.testing.assert_close(output, input + 4.0)

    def test_kd_train(self):
        """Check train pass results in updated loss output"""
        distillation_helper = TestHelper(cfg=CfgNode(), teacher=AddLayers())
        model = AddLayers()
        dynamic_mixin(
            model,
            KnowledgeDistillation,
            init_dict={"distillation_helper": distillation_helper},
        )
        model.train()
        input = torch.randn(1)
        output = model(input)
        torch.testing.assert_close(output["output"], (input + 4.0) * 0.1)
        torch.testing.assert_close(output["add"], ((input + 2.0) + (input + 3.0)) * 0.5)
        torch.testing.assert_close(output["mul"], (input + 3.0) * (input + 4.0) * 10.0)

    def test_kd_remove_dynamic_mixin(self):
        """Check removing dynamic mixin removes cached layers"""
        distillation_helper = TestHelper(cfg=CfgNode(), teacher=AddLayers())
        model = AddLayers()
        dynamic_mixin(
            model,
            KnowledgeDistillation,
            init_dict={"distillation_helper": distillation_helper},
        )
        remove_dynamic_mixin(model)
        for module in model.modules():
            self.assertFalse(hasattr(module, "cache"))


class TestDistillationModelingHook(unittest.TestCase):
    _build_teacher_ref = "d2go.modeling.distillation._build_teacher"
...