Commit 909de50d authored by Matthew Yu's avatar Matthew Yu Committed by Facebook GitHub Bot
Browse files

support ignoring teacher

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/429

Add a teacher type called `no_teacher` which can be specified by the user in the case they ignore the teacher (e.g., domain adaptation). Building the teacher just returns a noop (`nn.Identity`)

Differential Revision: D40971788

fbshipit-source-id: fc49ac44224c92806a7be253eefb8454305814eb
parent c2d7dbab
...@@ -459,6 +459,8 @@ def _build_teacher(cfg) -> nn.Module: ...@@ -459,6 +459,8 @@ def _build_teacher(cfg) -> nn.Module:
) )
runner = import_runner(cfg.DISTILLATION.TEACHER.RUNNER_NAME)() runner = import_runner(cfg.DISTILLATION.TEACHER.RUNNER_NAME)()
model = runner.build_model(teacher_cfg, eval_only=True) model = runner.build_model(teacher_cfg, eval_only=True)
elif cfg.DISTILLATION.TEACHER.TYPE == "no_teacher":
model = nn.Identity()
else: else:
raise ValueError(f"Unexpected teacher type: {cfg.DISTILLATION.TEACHER.TYPE}") raise ValueError(f"Unexpected teacher type: {cfg.DISTILLATION.TEACHER.TYPE}")
...@@ -492,6 +494,10 @@ def _validate_teacher_config(cfg: CN) -> None: ...@@ -492,6 +494,10 @@ def _validate_teacher_config(cfg: CN) -> None:
* torchscript_filename * torchscript_filename
If config, needs: If config, needs:
* config_fname * config_fname
Bypass allowed if setting teacher.type = "no_teacher". This can be
useful in cases where we only have the student model
(e.g., domain adaptation)
""" """
if cfg.DISTILLATION.TEACHER.TYPE == "torchscript": if cfg.DISTILLATION.TEACHER.TYPE == "torchscript":
assert ( assert (
...@@ -501,6 +507,8 @@ def _validate_teacher_config(cfg: CN) -> None: ...@@ -501,6 +507,8 @@ def _validate_teacher_config(cfg: CN) -> None:
assert ( assert (
cfg.DISTILLATION.TEACHER.CONFIG_FNAME cfg.DISTILLATION.TEACHER.CONFIG_FNAME
), "Trying to load D2Go teacher model without config" ), "Trying to load D2Go teacher model without config"
elif cfg.DISTILLATION.TEACHER.TYPE == "no_teacher":
pass
else: else:
raise ValueError( raise ValueError(
f"Unrecognized DISTILLATION.TEACHER.TYPE: {cfg.DISTILLATION.TEACHER.TYPE}" f"Unrecognized DISTILLATION.TEACHER.TYPE: {cfg.DISTILLATION.TEACHER.TYPE}"
......
...@@ -253,6 +253,15 @@ class TestDistillation(unittest.TestCase): ...@@ -253,6 +253,15 @@ class TestDistillation(unittest.TestCase):
model = _build_teacher(cfg) model = _build_teacher(cfg)
self.assertEqual(gt_model.weight, model.weight) self.assertEqual(gt_model.weight, model.weight)
def test_build_teacher_none(self):
"""Check that we can ignore building the teacher"""
# build model
cfg = _get_default_cfg()
cfg.MODEL.META_ARCHITECTURE = "TestMetaArchAddRand"
cfg.DISTILLATION.TEACHER.TYPE = "no_teacher"
model = _build_teacher(cfg)
self.assertTrue(isinstance(model, nn.Module))
def test_override_teacher_config_gpu_on_cpu(self): def test_override_teacher_config_gpu_on_cpu(self):
"""Teacher cuda model can be run on cpu if specified in config""" """Teacher cuda model can be run on cpu if specified in config"""
# build model where teacher is specified on gpu but user overrides cpu # build model where teacher is specified on gpu but user overrides cpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment