Commit 0f27e90f authored by Matthew Yu, committed by Facebook GitHub Bot
Browse files

update teacher to support models where device is a property

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/416

Distillation assumes that the teacher model has a "device" attribute. On some models this attribute is actually a property (e.g., GeneralizedRCNN), but there is no guarantee that it exists at all. We add a helper function that moves the model to the requested device and adds the attribute when it is missing.

Reviewed By: chihyaoma

Differential Revision: D40283954

fbshipit-source-id: 42921653eac8a79499e22edac29aa6aeac016e8a
parent c9211c19
...@@ -335,9 +335,24 @@ def _build_teacher(cfg) -> nn.Module: ...@@ -335,9 +335,24 @@ def _build_teacher(cfg) -> nn.Module:
# move teacher to same device as student unless specified # move teacher to same device as student unless specified
device = torch.device(cfg.DISTILLATION.TEACHER.DEVICE or cfg.MODEL.DEVICE) device = torch.device(cfg.DISTILLATION.TEACHER.DEVICE or cfg.MODEL.DEVICE)
model = _set_device(model, device)
model.eval()
return model
def _set_device(model: nn.Module, device: torch.device) -> nn.Module:
"""Set the device of the model
Some D2Go models have device as a property of the model (e.g., GeneralizedRCNN)
whereas others are missing this attribute which is assumed by distillation
to exist (e.g., we may call teacher.device to move inputs)
This helper function guarantees that the model.device attribute exists
and runs model.to(device)
"""
model = model.to(device) model = model.to(device)
if not hasattr(model, "device"):
model.device = device model.device = device
model.eval()
return model return model
......
...@@ -12,6 +12,7 @@ from d2go.config import CfgNode ...@@ -12,6 +12,7 @@ from d2go.config import CfgNode
from d2go.modeling import modeling_hook as mh from d2go.modeling import modeling_hook as mh
from d2go.modeling.distillation import ( from d2go.modeling.distillation import (
_build_teacher, _build_teacher,
_set_device,
add_distillation_configs, add_distillation_configs,
BaseDistillationHelper, BaseDistillationHelper,
DistillationModelingHook, DistillationModelingHook,
...@@ -60,6 +61,10 @@ class AddOne(nn.Module): ...@@ -60,6 +61,10 @@ class AddOne(nn.Module):
def forward(self, x):
    """Return ``x`` plus this module's ``weight``."""
    return x + self.weight
@property
def device(self):
    """Device reported by this module's ``weight``."""
    weight = self.weight
    return weight.device
class TestLabeler(PseudoLabeler): class TestLabeler(PseudoLabeler):
def __init__(self, teacher): def __init__(self, teacher):
...@@ -210,6 +215,22 @@ class TestDistillation(unittest.TestCase): ...@@ -210,6 +215,22 @@ class TestDistillation(unittest.TestCase):
model = _build_teacher(cfg) model = _build_teacher(cfg)
self.assertEqual(gt_model.weight, model.weight) self.assertEqual(gt_model.weight, model.weight)
def test_set_device(self):
    """_set_device returns a model whose ``device`` equals the requested device"""
    target = torch.device("cpu")

    # Noop starts without a device attribute, so the helper has to add it
    bare_model = Noop()
    self.assertFalse(hasattr(bare_model, "device"))
    bare_model = _set_device(bare_model, target)
    self.assertEqual(bare_model.device, target)

    # AddOne already exposes device as a property
    prop_model = _set_device(AddOne(), target)
    self.assertEqual(prop_model.device, target)
class TestPseudoLabeler(unittest.TestCase): class TestPseudoLabeler(unittest.TestCase):
def test_noop(self): def test_noop(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment