Commit 419974bb authored by Matthew Yu's avatar Matthew Yu Committed by Facebook GitHub Bot
Browse files

add default layer losses and loss combiner

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/421

Add some reasonable defaults when running knowledge distillation
* get_default_kd_image_classification_layer_losses => returns cross entropy loss on the output of the student classification layer and the teacher output (this is what the imagenet distillation uses)
* DefaultLossCombiner => simple function to multiply the losses by some weights

Unsure if these should go in `distillation.py` or a separate place (e.g., defaults or classification)

Reviewed By: chihyaoma

Differential Revision: D40330718

fbshipit-source-id: 5887566d88e3a96d01aca133c51041126b2692cc
parent 0ea6bc1b
...@@ -618,3 +618,41 @@ class WrappedTeacher: ...@@ -618,3 +618,41 @@ class WrappedTeacher:
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.model(*args, **kwargs) return self.model(*args, **kwargs)
def get_default_kd_image_classification_layer_losses() -> List[LayerLossMetadata]:
    """Build the standard layer-loss configuration for KD image classification.

    Returns a single cross-entropy loss wired between the student's
    classification layer and the teacher's output. Assumes the student model
    is ImageClassificationMetaArch and the teacher is the same arch (or a
    wrapped torchscript model exposing the same output layer name).
    """
    kd_loss = LayerLossMetadata(
        loss=nn.CrossEntropyLoss(),
        name="kd",
        layer0="classifier",
        layer1="",  # use empty layer name to indicate last layer
    )
    return [kd_loss]
class DefaultLossCombiner:
    """Combine distillation losses as a weighted sum.

    ``name_weight`` maps each expected loss name to the scalar weight applied
    to that loss. Every key of the incoming loss dict must be present in
    ``name_weight``; an unknown key raises ``ValueError``.

    Example:
        name_weight = {"nll": 0.1, "kd": 0.9}
    """

    def __init__(self, name_weight: Dict[str, float]):
        self.name_weight = name_weight

    def __call__(self, losses: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        weighted = {}
        for name, loss in losses.items():
            if name not in self.name_weight:
                raise ValueError(f"Unexpected weight in loss dict: {name}")
            weighted[name] = loss * self.name_weight[name]
        return weighted
...@@ -17,8 +17,10 @@ from d2go.modeling.distillation import ( ...@@ -17,8 +17,10 @@ from d2go.modeling.distillation import (
BaseDistillationHelper, BaseDistillationHelper,
CachedLayer, CachedLayer,
compute_layer_losses, compute_layer_losses,
DefaultLossCombiner,
DistillationModelingHook, DistillationModelingHook,
ExampleDistillationHelper, ExampleDistillationHelper,
get_default_kd_image_classification_layer_losses,
KnowledgeDistillation, KnowledgeDistillation,
LabelDistillation, LabelDistillation,
LayerLossMetadata, LayerLossMetadata,
...@@ -563,7 +565,7 @@ class TestDistillationModelingHook(unittest.TestCase): ...@@ -563,7 +565,7 @@ class TestDistillationModelingHook(unittest.TestCase):
torch.testing.assert_close(output, gt) torch.testing.assert_close(output, gt)
class DistillationMiscTests(unittest.TestCase): class TestDistillationMiscTests(unittest.TestCase):
def test_teacher_outside_updated_parameters(self): def test_teacher_outside_updated_parameters(self):
""" """
Check that teacher values are ignored when updating student Check that teacher values are ignored when updating student
...@@ -605,3 +607,20 @@ class DistillationMiscTests(unittest.TestCase): ...@@ -605,3 +607,20 @@ class DistillationMiscTests(unittest.TestCase):
cfg.MODEL.MODELING_HOOKS = ["DistillationModelingHook"] cfg.MODEL.MODELING_HOOKS = ["DistillationModelingHook"]
distilled_model = BaseRunner().build_model(cfg) distilled_model = BaseRunner().build_model(cfg)
self.assertEqual(len(list(distilled_model.parameters())), 1) self.assertEqual(len(list(distilled_model.parameters())), 1)
class TestDistillationDefaults(unittest.TestCase):
    """Tests for the default KD layer losses and the default loss combiner."""

    def test_kd_image_classification_layer_losses(self):
        """Check the default returns a list of LayerLossMetadata."""
        layer_losses = get_default_kd_image_classification_layer_losses()
        # isinstance against typing.List is deprecated (and rejected by newer
        # Pythons for subscripted forms); check the builtin type instead
        self.assertIsInstance(layer_losses, list)
        self.assertIsInstance(layer_losses[0], LayerLossMetadata)

    def test_default_loss_combiner(self):
        """Check the combiner multiplies each loss by its configured weight."""
        weights = {"a": torch.randn(1), "b": torch.randn(1)}
        combiner = DefaultLossCombiner(weights)
        input = {"a": 1.0, "b": 10.0}
        output = combiner(input)
        torch.testing.assert_close(output["a"], input["a"] * weights["a"])
        torch.testing.assert_close(output["b"], input["b"] * weights["b"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment