"examples/vscode:/vscode.git/clone" did not exist on "1c91f460d3e534ed549bf600820d7cc31a0981ff"
Commit 0316fed4 authored by Matthew Yu, committed by Facebook GitHub Bot

add class to keep track of loss metadata and a function to compute losses

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/419

This diff adds a metadata class `LayerLossMetadata` to help keep track of the losses we want to compute over layers. Each entry contains the loss module, the loss name, and the names of the two layers whose outputs are compared.
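
For illustration, a minimal construction of one such entry (the MSE loss and the layer names below are made up for this example, not part of this diff):

from torch import nn
from d2go.modeling.distillation import LayerLossMetadata

# One entry: compare the output of "student_backbone" (layer0) against
# "teacher_backbone" (layer1) with an MSE loss, reported under the name "mse_backbone".
ll = LayerLossMetadata(
    loss=nn.MSELoss(),
    name="mse_backbone",
    layer0="student_backbone",
    layer1="teacher_backbone",
)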

This diff also adds a helper function `compute_layer_losses` that iterates over a list of `LayerLossMetadata` and returns a dict mapping each loss name to its computed value.
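
A sketch of how the helper would be called, reusing the entry `ll` from the sketch above and assuming the caches hold activations keyed by the layer names (the tensors here are placeholders):

import torch
from d2go.modeling.distillation import compute_layer_losses

# Hypothetical caches of intermediate outputs; in practice these are populated
# by layers recorded on the student (layer0) and teacher (layer1) models.
layer0_cache = {"student_backbone": torch.randn(2, 8)}
layer1_cache = {"teacher_backbone": torch.randn(2, 8)}

# Returns a dict keyed by loss name, e.g. {"mse_backbone": tensor(...)}.
losses = compute_layer_losses([ll], layer0_cache, layer1_cache)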

Reviewed By: chihyaoma

Differential Revision: D40286564

fbshipit-source-id: b269dc63cc90a437ca279379d759c3106016327c
parent 53c4c2c1
@@ -14,7 +14,7 @@
# distillation algorithms in configs: DISTILLATION_ALGORITHM, DISTILLATION_HELPER
from abc import abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Set, Union
import torch
import torch.nn as nn
@@ -49,6 +49,14 @@ def add_distillation_configs(_C: CN) -> None:
_C.DISTILLATION.TEACHER.OVERWRITE_OPTS = []
@dataclass
class LayerLossMetadata:
    loss: nn.Module
    name: str
    layer0: str
    layer1: str
class PseudoLabeler:
    @abstractmethod
    def label(self, x):
@@ -460,3 +468,24 @@ def unrecord_layers(model: nn.Module, layer_names: Set[str]) -> None:
    for name, module in model.named_modules():
        if name in layer_names:
            remove_dynamic_mixin(module)
def compute_layer_losses(
    layer_losses: List[LayerLossMetadata],
    layer0_cache: Dict[str, torch.Tensor],
    layer1_cache: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """Compute loss over layers specified in layer_losses

    layer0_cache and layer1_cache should contain the data required to compute
    the losses specified in layer_losses
    """
    losses = {}
    for ll in layer_losses:
        if ll.layer0 not in layer0_cache:
            raise ValueError(f"Missing saved layer {ll.layer0} in layer0_cache")
        if ll.layer1 not in layer1_cache:
            raise ValueError(f"Missing saved layer {ll.layer1} in layer1_cache")
        losses[ll.name] = ll.loss(layer0_cache[ll.layer0], layer1_cache[ll.layer1])
    return losses
@@ -16,9 +16,11 @@ from d2go.modeling.distillation import (
    add_distillation_configs,
    BaseDistillationHelper,
    CachedLayer,
    compute_layer_losses,
    DistillationModelingHook,
    ExampleDistillationHelper,
    LabelDistillation,
    LayerLossMetadata,
    NoopPseudoLabeler,
    PseudoLabeler,
    record_layers,
@@ -313,6 +315,22 @@ class TestDistillation(unittest.TestCase):
        unrecord_layers(model, ["", "layer0"])
        self.assertFalse(hasattr(model.layer0, "cache"))
    def test_compute_layer_losses(self):
        """Check iterating over loss dicts"""
        layer_losses = [
            LayerLossMetadata(
                loss=lambda x, y: x + y, name="add", layer0="l00", layer1="l10"
            ),
            LayerLossMetadata(
                loss=lambda x, y: x / y, name="div", layer0="l01", layer1="l11"
            ),
        ]
        layer0_cache = {"l00": torch.randn(1), "l01": torch.randn(1)}
        layer1_cache = {"l10": torch.randn(1), "l11": torch.randn(1)}
        output = compute_layer_losses(layer_losses, layer0_cache, layer1_cache)
        self.assertEqual(output["add"], layer0_cache["l00"] + layer1_cache["l10"])
        self.assertEqual(output["div"], layer0_cache["l01"] / layer1_cache["l11"])
class TestPseudoLabeler(unittest.TestCase):
    def test_noop(self):