Commit 0316fed4 authored by Matthew Yu, committed by Facebook GitHub Bot

add class to keep track of loss metadata and function to compute losses

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/419

This diff adds a metadata class `LayerLossMetadata` to keep track of the losses we want to compute over layers. The class holds the loss module, the loss name, and the names of the two layers the loss is computed between.

This diff also adds a helper function, `compute_layer_losses`, that iterates over a list of `LayerLossMetadata` and returns a dict mapping each loss name to its computed value.
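A minimal usage sketch (the layer names, loss choice, and tensor shapes below are hypothetical; only `LayerLossMetadata` and `compute_layer_losses` come from this diff):

```python
import torch
import torch.nn as nn

from d2go.modeling.distillation import compute_layer_losses, LayerLossMetadata

# Describe one loss: compare the cached output of a student layer against the
# cached output of a teacher layer using MSE (loss choice is illustrative).
layer_losses = [
    LayerLossMetadata(
        loss=nn.MSELoss(),
        name="mse_backbone",
        layer0="student_backbone",  # hypothetical layer name
        layer1="teacher_backbone",  # hypothetical layer name
    ),
]

# Hypothetical caches of layer outputs recorded during a forward pass.
layer0_cache = {"student_backbone": torch.randn(2, 8)}
layer1_cache = {"teacher_backbone": torch.randn(2, 8)}

losses = compute_layer_losses(layer_losses, layer0_cache, layer1_cache)
# losses == {"mse_backbone": tensor(...)}
```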

Reviewed By: chihyaoma

Differential Revision: D40286564

fbshipit-source-id: b269dc63cc90a437ca279379d759c3106016327c
parent 53c4c2c1
@@ -14,7 +14,7 @@
# distillation algorithms in configs: DISTILLATION_ALGORITHM, DISTILLATION_HELPER
from abc import abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Set, Union
import torch
@@ -49,6 +49,14 @@ def add_distillation_configs(_C: CN) -> None:
    _C.DISTILLATION.TEACHER.OVERWRITE_OPTS = []
@dataclass
class LayerLossMetadata:
    loss: nn.Module
    name: str
    layer0: str
    layer1: str


class PseudoLabeler:
    @abstractmethod
    def label(self, x):
@@ -460,3 +468,24 @@ def unrecord_layers(model: nn.Module, layer_names: Set[str]) -> None:
    for name, module in model.named_modules():
        if name in layer_names:
            remove_dynamic_mixin(module)

def compute_layer_losses(
    layer_losses: List[LayerLossMetadata],
    layer0_cache: Dict[str, torch.Tensor],
    layer1_cache: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """Compute losses over the layers specified in layer_losses

    layer0_cache and layer1_cache should contain the tensors required to
    compute the losses specified in layer_losses
    """
    losses = {}
    for ll in layer_losses:
        if ll.layer0 not in layer0_cache:
            raise ValueError(f"Missing saved layer {ll.layer0} in layer0_cache")
        if ll.layer1 not in layer1_cache:
            raise ValueError(f"Missing saved layer {ll.layer1} in layer1_cache")
        losses[ll.name] = ll.loss(layer0_cache[ll.layer0], layer1_cache[ll.layer1])
    return losses
@@ -16,9 +16,11 @@ from d2go.modeling.distillation import (
    add_distillation_configs,
    BaseDistillationHelper,
    CachedLayer,
    compute_layer_losses,
    DistillationModelingHook,
    ExampleDistillationHelper,
    LabelDistillation,
    LayerLossMetadata,
    NoopPseudoLabeler,
    PseudoLabeler,
    record_layers,
@@ -313,6 +315,22 @@ class TestDistillation(unittest.TestCase):
        unrecord_layers(model, ["", "layer0"])
        self.assertFalse(hasattr(model.layer0, "cache"))
    def test_compute_layer_losses(self):
        """Check iterating over loss dicts"""
        layer_losses = [
            LayerLossMetadata(
                loss=lambda x, y: x + y, name="add", layer0="l00", layer1="l10"
            ),
            LayerLossMetadata(
                loss=lambda x, y: x / y, name="div", layer0="l01", layer1="l11"
            ),
        ]
        layer0_cache = {"l00": torch.randn(1), "l01": torch.randn(1)}
        layer1_cache = {"l10": torch.randn(1), "l11": torch.randn(1)}
        output = compute_layer_losses(layer_losses, layer0_cache, layer1_cache)
        self.assertEqual(output["add"], layer0_cache["l00"] + layer1_cache["l10"])
        self.assertEqual(output["div"], layer0_cache["l01"] / layer1_cache["l11"])


class TestPseudoLabeler(unittest.TestCase):
    def test_noop(self):
......