Commit 091213f0 authored by Peizhao Zhang's avatar Peizhao Zhang Committed by Facebook GitHub Bot
Browse files

Support unapplying modeling hooks.

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/208

Support unapplying modeling hooks.

Reviewed By: tglik

Differential Revision: D35540649

fbshipit-source-id: 60cc5e214282e30b39fc98ba4d58dad2fc6ea086
parent aea25d23
...@@ -20,8 +20,7 @@ def build_model(cfg): ...@@ -20,8 +20,7 @@ def build_model(cfg):
# MODELING_HOOKS key # MODELING_HOOKS key
if hasattr(cfg.MODEL, "MODELING_HOOKS"): if hasattr(cfg.MODEL, "MODELING_HOOKS"):
hook_names = cfg.MODEL.MODELING_HOOKS hook_names = cfg.MODEL.MODELING_HOOKS
mhooks = mh.build_modeling_hooks(cfg, hook_names) model = mh.build_and_apply_modeling_hooks(model, cfg, hook_names)
model = mh.apply_modeling_hooks(model, mhooks)
_log_api_usage("modeling.meta_arch." + meta_arch) _log_api_usage("modeling.meta_arch." + meta_arch)
return model return model
...@@ -40,26 +40,65 @@ class ModelingHook(object): ...@@ -40,26 +40,65 @@ class ModelingHook(object):
@abstractmethod @abstractmethod
def unapply(self, model: torch.nn.Module) -> torch.nn.Module: def unapply(self, model: torch.nn.Module) -> torch.nn.Module:
"""This function will be called when the users called model.get_exportable_model() """This function will be called when the users called model.unapply_modeling_hooks()
after training. The main use case of the function is to remove the changes after training. The main use case of the function is to remove the changes
applied to the model in `apply`. The hooks will be called in reverse order applied to the model in `apply`. The hooks will be called in reverse order
as follow: as follow:
model.get_exportable_model() == model_hook_N.unapply(model) -> model.unapply_modeling_hooks() == model_hook_N.unapply(model) ->
model_hook_N-1.unapply(model) -> ... -> model_hook_1.unapply(model) model_hook_N-1.unapply(model) -> ... -> model_hook_1.unapply(model)
""" """
pass pass
def build_modeling_hooks(cfg, hook_names: List[str]) -> List[ModelingHook]: def _build_modeling_hooks(cfg, hook_names: List[str]) -> List[ModelingHook]:
"""Build the hooks from cfg""" """Build the hooks from cfg"""
ret = [MODELING_HOOK_REGISTRY.get(name)(cfg) for name in hook_names] ret = [MODELING_HOOK_REGISTRY.get(name)(cfg) for name in hook_names]
return ret return ret
def apply_modeling_hooks( def _unapply_modeling_hook(
model: torch.nn.Module, hooks: List[ModelingHook] model: torch.nn.Module, hooks: List[ModelingHook]
) -> torch.nn.Module: ) -> torch.nn.Module:
"""Apply hooks on the model""" """Call unapply on the hooks in reversed order"""
for hook in reversed(hooks):
model = hook.unapply(model)
return model
def _apply_modeling_hooks(
model: torch.nn.Module, hooks: List[ModelingHook]
) -> torch.nn.Module:
"""Apply hooks on the model, users could call model.unapply_modeling_hooks()
to return the model that removes all the hooks
"""
if len(hooks) == 0:
return model
for hook in hooks: for hook in hooks:
model = hook.apply(model) model = hook.apply(model)
assert not hasattr(model, "_modeling_hooks")
model._modeling_hooks = hooks
def _unapply_modeling_hooks(self):
assert hasattr(self, "_modeling_hooks")
model = _unapply_modeling_hook(self, self._modeling_hooks)
return model
# add a function that could be used to unapply the modeling hooks
assert not hasattr(model, "unapply_modeling_hooks")
model.unapply_modeling_hooks = _unapply_modeling_hooks.__get__(model)
return model
def build_and_apply_modeling_hooks(
model: torch.nn.Module, cfg, hook_names: List[str]
) -> torch.nn.Module:
"""Build modeling hooks from cfg and apply hooks on the model. Users could
call model.unapply_modeling_hooks() to return the model that removes all
the hooks.
"""
hooks = _build_modeling_hooks(cfg, hook_names)
model = _apply_modeling_hooks(model, hooks)
return model return model
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import unittest import unittest
import d2go.runner.default_runner as default_runner import d2go.runner.default_runner as default_runner
...@@ -86,6 +87,12 @@ class TestModelingHook(unittest.TestCase): ...@@ -86,6 +87,12 @@ class TestModelingHook(unittest.TestCase):
model = build_model(cfg) model = build_model(cfg)
self.assertEqual(model(2), 10) self.assertEqual(model(2), 10)
self.assertTrue(hasattr(model, "_modeling_hooks"))
self.assertTrue(hasattr(model, "unapply_modeling_hooks"))
orig_model = model.unapply_modeling_hooks()
self.assertIsInstance(orig_model, TestArch)
self.assertEqual(orig_model(2), 4)
def test_modeling_hook_runner(self): def test_modeling_hook_runner(self):
"""Create model with modeling hook from runner""" """Create model with modeling hook from runner"""
runner = default_runner.Detectron2GoRunner() runner = default_runner.Detectron2GoRunner()
...@@ -95,4 +102,30 @@ class TestModelingHook(unittest.TestCase): ...@@ -95,4 +102,30 @@ class TestModelingHook(unittest.TestCase):
cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"] cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
model = runner.build_model(cfg) model = runner.build_model(cfg)
self.assertEqual(model(2), 10) self.assertEqual(model(2), 10)
self.assertTrue(hasattr(model, "_modeling_hooks"))
self.assertTrue(hasattr(model, "unapply_modeling_hooks"))
orig_model = model.unapply_modeling_hooks()
self.assertIsInstance(orig_model, TestArch)
self.assertEqual(orig_model(2), 4)
default_runner._close_all_tbx_writers() default_runner._close_all_tbx_writers()
def test_modeling_hook_copy(self):
"""Create model with modeling hook, the model could be copied"""
cfg = CfgNode()
cfg.MODEL = CfgNode()
cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.META_ARCHITECTURE = "TestArch"
cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
model = build_model(cfg)
self.assertEqual(model(2), 10)
model_copy = copy.deepcopy(model)
orig_model = model.unapply_modeling_hooks()
self.assertIsInstance(orig_model, TestArch)
self.assertEqual(orig_model(2), 4)
orig_model_copy = model_copy.unapply_modeling_hooks()
self.assertEqual(orig_model_copy(2), 4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment