Commit aea25d23 authored by Peizhao Zhang's avatar Peizhao Zhang Committed by Facebook GitHub Bot
Browse files

Apply modeling hook when building the model from cfg.

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/207

Apply modeling hook when building the model from cfg.

Differential Revision: D35535571

fbshipit-source-id: e80dd3912911e49c6ed60477f3ba52f74a220dec
parent 2d328adb
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from d2go.modeling.meta_arch import modeling_hook as mh
from d2go.utils.misc import _log_api_usage from d2go.utils.misc import _log_api_usage
from detectron2.modeling import build_model as d2_build_model from detectron2.modeling import build_model as d2_build_model
...@@ -13,5 +14,14 @@ def build_model(cfg): ...@@ -13,5 +14,14 @@ def build_model(cfg):
""" """
meta_arch = cfg.MODEL.META_ARCHITECTURE meta_arch = cfg.MODEL.META_ARCHITECTURE
model = d2_build_model(cfg) model = d2_build_model(cfg)
# apply modeling hooks
# some custom projects bypass d2go's default config so may not have the
# MODELING_HOOKS key
if hasattr(cfg.MODEL, "MODELING_HOOKS"):
hook_names = cfg.MODEL.MODELING_HOOKS
mhooks = mh.build_modeling_hooks(cfg, hook_names)
model = mh.apply_modeling_hooks(model, mhooks)
_log_api_usage("modeling.meta_arch." + meta_arch) _log_api_usage("modeling.meta_arch." + meta_arch)
return model return model
...@@ -2,8 +2,18 @@ ...@@ -2,8 +2,18 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from abc import abstractmethod from abc import abstractmethod
from typing import List
import torch import torch
from detectron2.utils.registry import Registry
MODELING_HOOK_REGISTRY = Registry("MODELING_HOOK") # noqa F401 isort:skip
MODELING_HOOK_REGISTRY.__doc__ = """
Registry for modeling hook.
The registered object will be called with `obj(cfg)`
and expected to return a `ModelingHook` object.
"""
class ModelingHook(object): class ModelingHook(object):
...@@ -38,3 +48,18 @@ class ModelingHook(object): ...@@ -38,3 +48,18 @@ class ModelingHook(object):
model_hook_N-1.unapply(model) -> ... -> model_hook_1.unapply(model) model_hook_N-1.unapply(model) -> ... -> model_hook_1.unapply(model)
""" """
pass pass
def build_modeling_hooks(cfg, hook_names: List[str]) -> List[ModelingHook]:
"""Build the hooks from cfg"""
ret = [MODELING_HOOK_REGISTRY.get(name)(cfg) for name in hook_names]
return ret
def apply_modeling_hooks(
model: torch.nn.Module, hooks: List[ModelingHook]
) -> torch.nn.Module:
"""Apply hooks on the model"""
for hook in hooks:
model = hook.apply(model)
return model
...@@ -89,4 +89,9 @@ def get_default_cfg(_C): ...@@ -89,4 +89,9 @@ def get_default_cfg(_C):
"default_scale_d2_configs", "default_scale_d2_configs",
"default_scale_quantization_configs", "default_scale_quantization_configs",
] ]
# Modeling hooks
# List of modeling hook names
_C.MODEL.MODELING_HOOKS = []
return _C return _C
...@@ -4,12 +4,17 @@ ...@@ -4,12 +4,17 @@
import unittest import unittest
import d2go.runner.default_runner as default_runner
import torch import torch
from d2go.config import CfgNode
from d2go.modeling import build_model
from d2go.modeling.meta_arch import modeling_hook as mh from d2go.modeling.meta_arch import modeling_hook as mh
from detectron2.modeling import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class TestArch(torch.nn.Module): class TestArch(torch.nn.Module):
def __init__(self): def __init__(self, cfg):
super().__init__() super().__init__()
def forward(self, x): def forward(self, x):
...@@ -17,8 +22,8 @@ class TestArch(torch.nn.Module): ...@@ -17,8 +22,8 @@ class TestArch(torch.nn.Module):
# create a wrapper of the model that add 1 to the output # create a wrapper of the model that add 1 to the output
class Wrapper(torch.nn.Module): class PlusOneWrapper(torch.nn.Module):
def __init__(self, model: TestArch): def __init__(self, model: torch.nn.Module):
super().__init__() super().__init__()
self.model = model self.model = model
...@@ -26,23 +31,68 @@ class Wrapper(torch.nn.Module): ...@@ -26,23 +31,68 @@ class Wrapper(torch.nn.Module):
return self.model(x) + 1 return self.model(x) + 1
@mh.MODELING_HOOK_REGISTRY.register()
class PlusOneHook(mh.ModelingHook): class PlusOneHook(mh.ModelingHook):
def __init__(self, cfg): def __init__(self, cfg):
super().__init__(cfg) super().__init__(cfg)
def apply(self, model: torch.nn.Module) -> torch.nn.Module: def apply(self, model: torch.nn.Module) -> torch.nn.Module:
return Wrapper(model) return PlusOneWrapper(model)
def unapply(self, model: torch.nn.Module) -> torch.nn.Module: def unapply(self, model: torch.nn.Module) -> torch.nn.Module:
assert isinstance(model, Wrapper) assert isinstance(model, PlusOneWrapper)
return model.model
# create a wrapper of the model that add 1 to the output
class TimesTwoWrapper(torch.nn.Module):
def __init__(self, model: torch.nn.Module):
super().__init__()
self.model = model
def forward(self, x):
return self.model(x) * 2
@mh.MODELING_HOOK_REGISTRY.register()
class TimesTwoHook(mh.ModelingHook):
def __init__(self, cfg):
super().__init__(cfg)
def apply(self, model: torch.nn.Module) -> torch.nn.Module:
return TimesTwoWrapper(model)
def unapply(self, model: torch.nn.Module) -> torch.nn.Module:
assert isinstance(model, TimesTwoWrapper)
return model.model return model.model
class TestModelingHook(unittest.TestCase): class TestModelingHook(unittest.TestCase):
def test_modeling_hook_simple(self): def test_modeling_hook_simple(self):
model = TestArch() model = TestArch(None)
hook = PlusOneHook(None) hook = PlusOneHook(None)
model_with_hook = hook.apply(model) model_with_hook = hook.apply(model)
self.assertEqual(model_with_hook(2), 5) self.assertEqual(model_with_hook(2), 5)
original_model = hook.unapply(model_with_hook) original_model = hook.unapply(model_with_hook)
self.assertEqual(model, original_model) self.assertEqual(model, original_model)
def test_modeling_hook_cfg(self):
"""Create model with modeling hook using build_model"""
cfg = CfgNode()
cfg.MODEL = CfgNode()
cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.META_ARCHITECTURE = "TestArch"
cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
model = build_model(cfg)
self.assertEqual(model(2), 10)
def test_modeling_hook_runner(self):
"""Create model with modeling hook from runner"""
runner = default_runner.Detectron2GoRunner()
cfg = runner.get_default_cfg()
cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.META_ARCHITECTURE = "TestArch"
cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
model = runner.build_model(cfg)
self.assertEqual(model(2), 10)
default_runner._close_all_tbx_writers()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment