Commit aea25d23 authored by Peizhao Zhang's avatar Peizhao Zhang Committed by Facebook GitHub Bot
Browse files

Apply modeling hook when building the model from cfg.

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/207

Apply modeling hook when building the model from cfg.

Differential Revision: D35535571

fbshipit-source-id: e80dd3912911e49c6ed60477f3ba52f74a220dec
parent 2d328adb
......@@ -2,6 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from d2go.modeling.meta_arch import modeling_hook as mh
from d2go.utils.misc import _log_api_usage
from detectron2.modeling import build_model as d2_build_model
......@@ -13,5 +14,14 @@ def build_model(cfg):
"""
meta_arch = cfg.MODEL.META_ARCHITECTURE
model = d2_build_model(cfg)
# apply modeling hooks
# some custom projects bypass d2go's default config so may not have the
# MODELING_HOOKS key
if hasattr(cfg.MODEL, "MODELING_HOOKS"):
hook_names = cfg.MODEL.MODELING_HOOKS
mhooks = mh.build_modeling_hooks(cfg, hook_names)
model = mh.apply_modeling_hooks(model, mhooks)
_log_api_usage("modeling.meta_arch." + meta_arch)
return model
......@@ -2,8 +2,18 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from abc import abstractmethod
from typing import List
import torch
from detectron2.utils.registry import Registry
MODELING_HOOK_REGISTRY = Registry("MODELING_HOOK") # noqa F401 isort:skip
MODELING_HOOK_REGISTRY.__doc__ = """
Registry for modeling hook.
The registered object will be called with `obj(cfg)`
and expected to return a `ModelingHook` object.
"""
class ModelingHook(object):
......@@ -38,3 +48,18 @@ class ModelingHook(object):
model_hook_N-1.unapply(model) -> ... -> model_hook_1.unapply(model)
"""
pass
def build_modeling_hooks(cfg, hook_names: List[str]) -> List[ModelingHook]:
"""Build the hooks from cfg"""
ret = [MODELING_HOOK_REGISTRY.get(name)(cfg) for name in hook_names]
return ret
def apply_modeling_hooks(
model: torch.nn.Module, hooks: List[ModelingHook]
) -> torch.nn.Module:
"""Apply hooks on the model"""
for hook in hooks:
model = hook.apply(model)
return model
......@@ -89,4 +89,9 @@ def get_default_cfg(_C):
"default_scale_d2_configs",
"default_scale_quantization_configs",
]
# Modeling hooks
# List of modeling hook names
_C.MODEL.MODELING_HOOKS = []
return _C
......@@ -4,12 +4,17 @@
import unittest
import d2go.runner.default_runner as default_runner
import torch
from d2go.config import CfgNode
from d2go.modeling import build_model
from d2go.modeling.meta_arch import modeling_hook as mh
from detectron2.modeling import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class TestArch(torch.nn.Module):
def __init__(self):
def __init__(self, cfg):
super().__init__()
def forward(self, x):
......@@ -17,8 +22,8 @@ class TestArch(torch.nn.Module):
# create a wrapper of the model that add 1 to the output
class Wrapper(torch.nn.Module):
def __init__(self, model: TestArch):
class PlusOneWrapper(torch.nn.Module):
def __init__(self, model: torch.nn.Module):
super().__init__()
self.model = model
......@@ -26,23 +31,68 @@ class Wrapper(torch.nn.Module):
return self.model(x) + 1
@mh.MODELING_HOOK_REGISTRY.register()
class PlusOneHook(mh.ModelingHook):
def __init__(self, cfg):
super().__init__(cfg)
def apply(self, model: torch.nn.Module) -> torch.nn.Module:
return Wrapper(model)
return PlusOneWrapper(model)
def unapply(self, model: torch.nn.Module) -> torch.nn.Module:
assert isinstance(model, Wrapper)
assert isinstance(model, PlusOneWrapper)
return model.model
# create a wrapper of the model that add 1 to the output
class TimesTwoWrapper(torch.nn.Module):
def __init__(self, model: torch.nn.Module):
super().__init__()
self.model = model
def forward(self, x):
return self.model(x) * 2
@mh.MODELING_HOOK_REGISTRY.register()
class TimesTwoHook(mh.ModelingHook):
def __init__(self, cfg):
super().__init__(cfg)
def apply(self, model: torch.nn.Module) -> torch.nn.Module:
return TimesTwoWrapper(model)
def unapply(self, model: torch.nn.Module) -> torch.nn.Module:
assert isinstance(model, TimesTwoWrapper)
return model.model
class TestModelingHook(unittest.TestCase):
def test_modeling_hook_simple(self):
model = TestArch()
model = TestArch(None)
hook = PlusOneHook(None)
model_with_hook = hook.apply(model)
self.assertEqual(model_with_hook(2), 5)
original_model = hook.unapply(model_with_hook)
self.assertEqual(model, original_model)
def test_modeling_hook_cfg(self):
"""Create model with modeling hook using build_model"""
cfg = CfgNode()
cfg.MODEL = CfgNode()
cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.META_ARCHITECTURE = "TestArch"
cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
model = build_model(cfg)
self.assertEqual(model(2), 10)
def test_modeling_hook_runner(self):
"""Create model with modeling hook from runner"""
runner = default_runner.Detectron2GoRunner()
cfg = runner.get_default_cfg()
cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.META_ARCHITECTURE = "TestArch"
cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
model = runner.build_model(cfg)
self.assertEqual(model(2), 10)
default_runner._close_all_tbx_writers()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment