Commit 419974bb authored by Matthew Yu, committed by Facebook GitHub Bot

add default layer losses and loss combiner

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/421

Add some reasonable defaults for running knowledge distillation:
* get_default_kd_image_classification_layer_losses => returns a cross-entropy loss between the output of the student classification layer and the teacher output (this is what the ImageNet distillation uses)
* DefaultLossCombiner => a simple callable that multiplies each loss by a configured weight

Unsure whether these should live in `distillation.py` or a separate place (e.g., defaults or classification).
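A minimal usage sketch of how the two defaults fit together (hypothetical wiring, not part of this diff; the loss names and tensor values are assumptions):

```python
import torch

from d2go.modeling.distillation import (
    DefaultLossCombiner,
    get_default_kd_image_classification_layer_losses,
)

# Default per-layer KD losses: a single cross-entropy term named "kd"
# between the student "classifier" layer and the teacher's last layer.
layer_losses = get_default_kd_image_classification_layer_losses()

# Rescale the raw losses; the keys must match the names in the loss dict.
combiner = DefaultLossCombiner({"nll": 0.1, "kd": 0.9})

# Hypothetical raw losses produced by a training step.
raw_losses = {"nll": torch.tensor(2.3), "kd": torch.tensor(1.1)}
weighted = combiner(raw_losses)  # {"nll": ~0.23, "kd": ~0.99}
```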

Reviewed By: chihyaoma

Differential Revision: D40330718

fbshipit-source-id: 5887566d88e3a96d01aca133c51041126b2692cc
parent 0ea6bc1b
@@ -618,3 +618,41 @@ class WrappedTeacher:
def __call__(self, *args, **kwargs):
return self.model(*args, **kwargs)
def get_default_kd_image_classification_layer_losses() -> List[LayerLossMetadata]:
"""Return some typical values used in knowledge distillation
Assumes student model is ImageClassificationMetaArch and teacher model is the same
or a wrapped torchscript model with the same output layer name
"""
return [
LayerLossMetadata(
loss=nn.CrossEntropyLoss(),
name="kd",
layer0="classifier",
layer1="", # use empty layer name to indicate last layer
)
]
class DefaultLossCombiner:
"""Returns a weighted sum of the losses based on the name_weight
name_weight is a dictionary indicating the name of the loss and the
weight associated with that loss
Example:
name_weight = {"nll": 0.1, "kd": 0.9}
"""
def __init__(self, name_weight: Dict[str, float]):
self.name_weight = name_weight
def __call__(self, losses: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
output = {}
for k, v in losses.items():
if k not in self.name_weight:
raise ValueError(f"Unexpected weight in loss dict: {k}")
output[k] = v * self.name_weight[k]
return output
@@ -17,8 +17,10 @@ from d2go.modeling.distillation import (
BaseDistillationHelper,
CachedLayer,
compute_layer_losses,
DefaultLossCombiner,
DistillationModelingHook,
ExampleDistillationHelper,
get_default_kd_image_classification_layer_losses,
KnowledgeDistillation,
LabelDistillation,
LayerLossMetadata,
@@ -563,7 +565,7 @@ class TestDistillationModelingHook(unittest.TestCase):
torch.testing.assert_close(output, gt)
class DistillationMiscTests(unittest.TestCase):
class TestDistillationMiscTests(unittest.TestCase):
def test_teacher_outside_updated_parameters(self):
"""
Check that teacher values are ignored when updating student
@@ -605,3 +607,20 @@ class DistillationMiscTests(unittest.TestCase):
cfg.MODEL.MODELING_HOOKS = ["DistillationModelingHook"]
distilled_model = BaseRunner().build_model(cfg)
self.assertEqual(len(list(distilled_model.parameters())), 1)
class TestDistillationDefaults(unittest.TestCase):
def test_kd_image_classification_layer_losses(self):
"""Check the default returns a list of layerlossmetadata"""
layer_losses = get_default_kd_image_classification_layer_losses()
self.assertTrue(isinstance(layer_losses, List))
self.assertTrue(isinstance(layer_losses[0], LayerLossMetadata))
def test_default_loss_combiner(self):
"""Check combiner multiplies loss by weights"""
weights = {"a": torch.randn(1), "b": torch.randn(1)}
combiner = DefaultLossCombiner(weights)
input = {"a": 1.0, "b": 10.0}
output = combiner(input)
torch.testing.assert_close(output["a"], input["a"] * weights["a"])
torch.testing.assert_close(output["b"], input["b"] * weights["b"])