Commit dc176d58 authored by Matthew Yu, committed by Facebook GitHub Bot

support pytorch checkpoint as teacher model using config

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/371

In a previous iteration of this diff, we were specifying the teacher model in the same config as the student model, something like:
```
# config.py
MODEL:
  FBNET_V2:
  ...
DISTILLATION:
  TEACHER:
    MODEL:
      FBNET_V2:
      ...
      WEIGHTS: /path/to/teacher/weights
...
```

This leads to some oddities in the code: for example, we have to maintain a default config that adds all of the required teacher-model keys under `DISTILLATION.TEACHER`.
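
For illustration, a hedged sketch of roughly what that defaulting entailed (the helper name and exact keys below are hypothetical, not the actual d2go code):
```
# Hypothetical sketch of the old approach: the distillation defaults had to
# re-declare the teacher's model keys so yacs would accept teacher overrides
# in the same file as the student config.
from detectron2.config import CfgNode as CN


def _add_teacher_model_defaults(_C: CN) -> None:  # hypothetical helper
    _C.DISTILLATION.TEACHER.MODEL = CN()
    # every architecture the teacher might use needs its defaults duplicated here
    _C.DISTILLATION.TEACHER.MODEL.FBNET_V2 = CN(new_allowed=True)
    _C.DISTILLATION.TEACHER.MODEL.WEIGHTS = ""
```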

In this diff, we instead let the user supply a teacher config (and, optionally, a runner_name and overwrite opts) and use the specified runner to build the model:
```
# new_config.py
MODEL:
  FBNET_V2:
...
DISTILLATION:
  TEACHER:
    CONFIG_FNAME: /path/to/teacher/config
    RUNNER_NAME:
...
```

This should make it very easy to specify the teacher, since the user can potentially just reuse the trained_config generated by d2go.
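
Concretely, assuming the teacher was trained with D2Go and dumped its config, the student config can simply point at that file; a minimal sketch (the paths are placeholders, and OVERWRITE_OPTS is assumed to follow the usual key/value opts-list format):
```
# Minimal sketch: reuse a previously trained teacher's dumped config.
# The paths below are placeholders, not real files.
cfg.DISTILLATION.TEACHER.TYPE = "config"
cfg.DISTILLATION.TEACHER.CONFIG_FNAME = "/path/to/teacher_output/config.yaml"
cfg.DISTILLATION.TEACHER.RUNNER_NAME = "d2go.runner.GeneralizedRCNNRunner"
# optional key/value pairs applied on top of the teacher config,
# e.g. forcing the teacher onto CPU regardless of the student device
cfg.DISTILLATION.TEACHER.OVERWRITE_OPTS = ["MODEL.DEVICE", "cpu"]
```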

Reviewed By: newstzpz

Differential Revision: D37640041

fbshipit-source-id: 088a636c96f98279c9a04e32d1674f703451aec3
parent fce2a186
```
@@ -30,13 +30,22 @@ from mobile_cv.common.misc.mixin import dynamic_mixin, remove_dynamic_mixin
 def add_distillation_configs(_C: CN) -> None:
-    """Add default parameters to config"""
+    """Add default parameters to config
+
+    The TEACHER.CONFIG field allows us to build a PyTorch model using an
+    existing config. We can build any model that is normally supported by
+    D2Go (e.g., FBNet) because we just use the same config
+    """
     _C.DISTILLATION = CN()
     _C.DISTILLATION.ALGORITHM = "LabelDistillation"
     _C.DISTILLATION.HELPER = "BaseDistillationHelper"
     _C.DISTILLATION.TEACHER = CN()
     _C.DISTILLATION.TEACHER.TORCHSCRIPT_FNAME = ""
     _C.DISTILLATION.TEACHER.DEVICE = ""
+    _C.DISTILLATION.TEACHER.TYPE = "torchscript"
+    _C.DISTILLATION.TEACHER.CONFIG_FNAME = ""
+    _C.DISTILLATION.TEACHER.RUNNER_NAME = "d2go.runner.GeneralizedRCNNRunner"
+    _C.DISTILLATION.TEACHER.OVERWRITE_OPTS = []


 class PseudoLabeler:
```
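
For context, a rough sketch of what resolving DISTILLATION.TEACHER.CONFIG_FNAME into a full teacher cfg amounts to; this is an approximation, not the actual create_cfg_from_cli implementation:
```
# Rough sketch (assumption): start from the runner's defaults, load the dumped
# teacher config, then apply the optional overrides.
from d2go.runner import import_runner


def _resolve_teacher_cfg_sketch(config_fname, overwrite_opts, runner_name):
    runner = import_runner(runner_name)()
    teacher_cfg = runner.get_default_cfg()       # runner-specific defaults
    teacher_cfg.merge_from_file(config_fname)    # the dumped teacher config
    teacher_cfg.merge_from_list(overwrite_opts)  # optional key/value overrides
    return teacher_cfg
```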
```
@@ -293,22 +302,54 @@ class DistillationModelingHook(mh.ModelingHook):
         return model


-def _build_teacher(cfg):
+def _build_teacher(cfg) -> nn.Module:
     """Create teacher using config settings

-    Only supports torchscript
+    Supports torchscript or creating pytorch model using config.
     """
-    assert (
-        cfg.DISTILLATION.TEACHER.TORCHSCRIPT_FNAME
-    ), "Only supports teacher loaded as torchscript"
-
-    torchscript_fname = cfg.DISTILLATION.TEACHER.TORCHSCRIPT_FNAME
-    with PathManager.open(torchscript_fname, "rb") as f:
-        ts = torch.jit.load(f)
+    _validate_teacher_config(cfg)
+    if cfg.DISTILLATION.TEACHER.TYPE == "torchscript":
+        with PathManager.open(cfg.DISTILLATION.TEACHER.TORCHSCRIPT_FNAME, "rb") as f:
+            model = torch.jit.load(f)
+    elif cfg.DISTILLATION.TEACHER.TYPE == "config":
+        from d2go.runner import import_runner
+        from d2go.setup import create_cfg_from_cli
+
+        teacher_cfg = create_cfg_from_cli(
+            cfg.DISTILLATION.TEACHER.CONFIG_FNAME,
+            cfg.DISTILLATION.TEACHER.OVERWRITE_OPTS,
+            cfg.DISTILLATION.TEACHER.RUNNER_NAME,
+        )
+        runner = import_runner(cfg.DISTILLATION.TEACHER.RUNNER_NAME)()
+        model = runner.build_model(teacher_cfg, eval_only=True)
+    else:
+        raise ValueError(f"Unexpected teacher type: {cfg.DISTILLATION.TEACHER.TYPE}")

     # move teacher to same device as student unless specified
     device = torch.device(cfg.DISTILLATION.TEACHER.DEVICE or cfg.MODEL.DEVICE)
-    ts = ts.to(device)
-    ts.device = device
-    ts.eval()
-    return ts
+    model = model.to(device)
+    model.device = device
+    model.eval()
+    return model
+
+
+def _validate_teacher_config(cfg: CN) -> None:
+    """We support torchscript or PyTorch checkpoint as teacher models
+
+    If torchscript, need:
+      * torchscript_filename
+    If config, needs:
+      * config_fname
+    """
+    if cfg.DISTILLATION.TEACHER.TYPE == "torchscript":
+        assert (
+            cfg.DISTILLATION.TEACHER.TORCHSCRIPT_FNAME
+        ), "Trying to load torchscript model without fname"
+    elif cfg.DISTILLATION.TEACHER.TYPE == "config":
+        assert (
+            cfg.DISTILLATION.TEACHER.CONFIG_FNAME
+        ), "Trying to load D2Go teacher model without config"
+    else:
+        raise ValueError(
+            f"Unrecognized DISTILLATION.TEACHER.TYPE: {cfg.DISTILLATION.TEACHER.TYPE}"
+        )
```
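
For reference, a hedged usage sketch of the torchscript path, which is unchanged in spirit (the file name is a placeholder):
```
# Usage sketch: the torchscript teacher path still only needs a file name.
cfg.DISTILLATION.TEACHER.TYPE = "torchscript"  # the default
cfg.DISTILLATION.TEACHER.TORCHSCRIPT_FNAME = "/path/to/teacher.torchscript"
cfg.DISTILLATION.TEACHER.DEVICE = ""  # empty string falls back to cfg.MODEL.DEVICE
teacher = _build_teacher(cfg)  # eval-mode module on the chosen device
```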
```
@@ -24,8 +24,12 @@ from d2go.modeling.distillation import (
 from d2go.registry.builtin import (
     DISTILLATION_ALGORITHM_REGISTRY,
     DISTILLATION_HELPER_REGISTRY,
+    META_ARCH_REGISTRY,
 )
+from d2go.runner.default_runner import BaseRunner
 from d2go.utils.testing import helper
+from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.utils.file_io import PathManager
 from mobile_cv.common.misc.file_utils import make_temp_directory
@@ -65,6 +69,16 @@ class TestLabeler(PseudoLabeler):
         return self.teacher(x)


+@META_ARCH_REGISTRY.register()
+class TestMetaArchAddRand(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.weight = nn.Parameter(torch.rand(1))
+
+    def forward(self, x):
+        return x + self.weight
+
+
 @DISTILLATION_HELPER_REGISTRY.register()
 class TestHelper(BaseDistillationHelper):
     def get_pseudo_labeler(self):
@@ -97,6 +111,7 @@ def _get_default_cfg():
     cfg.MODEL.DEVICE = "cpu"
     cfg.MODEL.META_ARCHITECTURE = "TestArch"
     add_distillation_configs(cfg)
+    # model_ema.add_model_ema_configs(cfg)
     cfg.DISTILLATION.ALGORITHM = "LabelDistillation"
     cfg.DISTILLATION.HELPER = "BaseDistillationHelper"
     cfg.DISTILLATION.TEACHER.TORCHSCRIPT_FNAME = ""
@@ -111,8 +126,11 @@ class TestDistillation(unittest.TestCase):
         add_distillation_configs(cfg)
         self.assertTrue(isinstance(cfg.DISTILLATION.TEACHER, CfgNode))

-    def test_build_teacher(self):
-        """Check can build teacher using config"""
+        # check teacher model config is clone of student model
+        self.assertEqual(cfg.DISTILLATION.TEACHER.CONFIG_FNAME, "")
+
+    def test_build_teacher_torchscript(self):
+        """Check can build teacher using torchscript fname in config"""
         # create torchscript
         model = DivideInputBy2()
         traced_model = torch.jit.trace(model, torch.randn(5))
@@ -130,7 +148,7 @@ class TestDistillation(unittest.TestCase):
         torch.testing.assert_close(torch.Tensor(output), gt)

     @helper.skip_if_no_gpu
-    def test_build_teacher_gpu(self):
+    def test_build_teacher_torchscript_gpu(self):
         """Check teacher moved to cuda"""
         model = AddOne()
         traced_model = torch.jit.trace(model, torch.randn(5))
@@ -148,6 +166,27 @@ class TestDistillation(unittest.TestCase):
         output = teacher(batched_inputs)
         torch.testing.assert_close(torch.Tensor(output), gt)

+    def test_build_teacher_config(self):
+        """Check build pytorch model using config"""
+        # build model
+        cfg = _get_default_cfg()
+        cfg.MODEL.META_ARCHITECTURE = "TestMetaArchAddRand"
+        gt_model = BaseRunner().build_model(cfg)
+        with make_temp_directory("tmp") as output_dir:
+            # save model
+            checkpointer = DetectionCheckpointer(gt_model, save_dir=output_dir)
+            checkpointer.save("checkpoint")
+            cfg.MODEL.WEIGHTS = f"{output_dir}/checkpoint.pth"
+            config_fname = f"{output_dir}/config.yaml"
+            with PathManager.open(config_fname, "w") as f:
+                f.write(cfg.dump())
+
+            # load model and compare to gt
+            cfg.DISTILLATION.TEACHER.TYPE = "config"
+            cfg.DISTILLATION.TEACHER.CONFIG_FNAME = config_fname
+            model = _build_teacher(cfg)
+            self.assertEqual(gt_model.weight, model.weight)
+

 class TestPseudoLabeler(unittest.TestCase):
     def test_noop(self):
```
```
@@ -302,3 +341,47 @@ class TestDistillationModelingHook(unittest.TestCase):
         model.train()
         output = model(batched_inputs)
         torch.testing.assert_close(output, gt)
+
+
+class DistillationMiscTests(unittest.TestCase):
+    def test_teacher_outside_updated_parameters(self):
+        """
+        Check that teacher values are ignored when updating student
+
+        The teacher can often be referenced in the mixed in model. A common
+        example is when the teacher is an attributed of the distillation
+        helper.
+          => DistillationModel.distillation_helper.teacher
+
+        This raises the question of whether the teacher model will be affected
+        by calls to the mixed in model:
+          DisillationModel.train() => does teacher switch to training?
+          setup_qat(DistillationModel) => will fuse occur on the teacher modules?
+
+        The answer to these questions should be no as we want the teacher to remain static
+        during training (unless specified). This is the case as long as teacher is an
+        attribute of a non-module class (e.g., distillation_helper). This is because
+        modules are registered in PyTorch as part of __setattr__. __setattr__ only checks
+        if the value is a module or parameter. If the value is an object
+        (e.g., distillation_helper) which contains modules, these modules are ignored.
+        https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.register_parameter
+
+        This unittest builds the teacher model and checks that only the student
+        parameter is registered.
+        """
+        cfg = _get_default_cfg()
+        cfg.MODEL.META_ARCHITECTURE = "TestMetaArchAddRand"
+        prebuilt_teacher = BaseRunner().build_model(cfg)
+        with make_temp_directory("tmp") as output_dir:
+            checkpointer = DetectionCheckpointer(prebuilt_teacher, save_dir=output_dir)
+            checkpointer.save("checkpoint")
+            cfg.MODEL.WEIGHTS = f"{output_dir}/checkpoint.pth"
+            config_fname = f"{output_dir}/config.yaml"
+            with PathManager.open(config_fname, "w") as f:
+                f.write(cfg.dump())
+
+            cfg.DISTILLATION.TEACHER.TYPE = "config"
+            cfg.DISTILLATION.TEACHER.CONFIG_FNAME = config_fname
+            cfg.DISTILLATION.HELPER = "TestHelper"
+            cfg.MODEL.MODELING_HOOKS = ["DistillationModelingHook"]
+            distilled_model = BaseRunner().build_model(cfg)
+            self.assertEqual(len(list(distilled_model.parameters())), 1)
```
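
The registration behavior the docstring above relies on can be reproduced in a tiny standalone sketch (illustrative only, not part of the commit):
```
import torch
import torch.nn as nn


class Holder:
    """Plain object (not an nn.Module) that happens to hold a module."""

    def __init__(self, module: nn.Module):
        self.module = module


class Student(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(1))  # registered via __setattr__
        self.helper = Holder(nn.Linear(2, 2))      # plain attribute: its params stay hidden


student = Student()
assert len(list(student.parameters())) == 1  # only the student's own weight
```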