Commit 091213f0 authored by Peizhao Zhang's avatar Peizhao Zhang Committed by Facebook GitHub Bot
Browse files

Support unapplying modeling hooks.

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/208

Support unapplying modeling hooks.

Reviewed By: tglik

Differential Revision: D35540649

fbshipit-source-id: 60cc5e214282e30b39fc98ba4d58dad2fc6ea086
parent aea25d23
......@@ -20,8 +20,7 @@ def build_model(cfg):
# MODELING_HOOKS key
if hasattr(cfg.MODEL, "MODELING_HOOKS"):
hook_names = cfg.MODEL.MODELING_HOOKS
mhooks = mh.build_modeling_hooks(cfg, hook_names)
model = mh.apply_modeling_hooks(model, mhooks)
model = mh.build_and_apply_modeling_hooks(model, cfg, hook_names)
_log_api_usage("modeling.meta_arch." + meta_arch)
return model
......@@ -40,26 +40,65 @@ class ModelingHook(object):
@abstractmethod
def unapply(self, model: torch.nn.Module) -> torch.nn.Module:
"""This function will be called when the users called model.get_exportable_model()
"""This function will be called when the users called model.unapply_modeling_hooks()
after training. The main use case of the function is to remove the changes
applied to the model in `apply`. The hooks will be called in reverse order
as follow:
model.get_exportable_model() == model_hook_N.unapply(model) ->
model.unapply_modeling_hooks() == model_hook_N.unapply(model) ->
model_hook_N-1.unapply(model) -> ... -> model_hook_1.unapply(model)
"""
pass
def build_modeling_hooks(cfg, hook_names: List[str]) -> List[ModelingHook]:
def _build_modeling_hooks(cfg, hook_names: List[str]) -> List[ModelingHook]:
"""Build the hooks from cfg"""
ret = [MODELING_HOOK_REGISTRY.get(name)(cfg) for name in hook_names]
return ret
def apply_modeling_hooks(
def _unapply_modeling_hook(
model: torch.nn.Module, hooks: List[ModelingHook]
) -> torch.nn.Module:
"""Apply hooks on the model"""
"""Call unapply on the hooks in reversed order"""
for hook in reversed(hooks):
model = hook.unapply(model)
return model
def _apply_modeling_hooks(
model: torch.nn.Module, hooks: List[ModelingHook]
) -> torch.nn.Module:
"""Apply hooks on the model, users could call model.unapply_modeling_hooks()
to return the model that removes all the hooks
"""
if len(hooks) == 0:
return model
for hook in hooks:
model = hook.apply(model)
assert not hasattr(model, "_modeling_hooks")
model._modeling_hooks = hooks
def _unapply_modeling_hooks(self):
assert hasattr(self, "_modeling_hooks")
model = _unapply_modeling_hook(self, self._modeling_hooks)
return model
# add a function that could be used to unapply the modeling hooks
assert not hasattr(model, "unapply_modeling_hooks")
model.unapply_modeling_hooks = _unapply_modeling_hooks.__get__(model)
return model
def build_and_apply_modeling_hooks(
model: torch.nn.Module, cfg, hook_names: List[str]
) -> torch.nn.Module:
"""Build modeling hooks from cfg and apply hooks on the model. Users could
call model.unapply_modeling_hooks() to return the model that removes all
the hooks.
"""
hooks = _build_modeling_hooks(cfg, hook_names)
model = _apply_modeling_hooks(model, hooks)
return model
......@@ -2,6 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import unittest
import d2go.runner.default_runner as default_runner
......@@ -86,6 +87,12 @@ class TestModelingHook(unittest.TestCase):
model = build_model(cfg)
self.assertEqual(model(2), 10)
self.assertTrue(hasattr(model, "_modeling_hooks"))
self.assertTrue(hasattr(model, "unapply_modeling_hooks"))
orig_model = model.unapply_modeling_hooks()
self.assertIsInstance(orig_model, TestArch)
self.assertEqual(orig_model(2), 4)
def test_modeling_hook_runner(self):
"""Create model with modeling hook from runner"""
runner = default_runner.Detectron2GoRunner()
......@@ -95,4 +102,30 @@ class TestModelingHook(unittest.TestCase):
cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
model = runner.build_model(cfg)
self.assertEqual(model(2), 10)
self.assertTrue(hasattr(model, "_modeling_hooks"))
self.assertTrue(hasattr(model, "unapply_modeling_hooks"))
orig_model = model.unapply_modeling_hooks()
self.assertIsInstance(orig_model, TestArch)
self.assertEqual(orig_model(2), 4)
default_runner._close_all_tbx_writers()
def test_modeling_hook_copy(self):
"""Create model with modeling hook, the model could be copied"""
cfg = CfgNode()
cfg.MODEL = CfgNode()
cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.META_ARCHITECTURE = "TestArch"
cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
model = build_model(cfg)
self.assertEqual(model(2), 10)
model_copy = copy.deepcopy(model)
orig_model = model.unapply_modeling_hooks()
self.assertIsInstance(orig_model, TestArch)
self.assertEqual(orig_model(2), 4)
orig_model_copy = model_copy.unapply_modeling_hooks()
self.assertEqual(orig_model_copy(2), 4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment