Commit 482fdc8a authored by Mircea Cimpoi, committed by Facebook GitHub Bot
Browse files

EMA parity / change build_d2go_model

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/344

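We need access to the modeling hooks for EMA, e.g. when building the trainer. To expose them, `build_d2go_model` now returns a `D2GoModelBuildResult` (the model plus the list of modeling hooks applied to it) instead of the bare model.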

Reviewed By: wat3rBro

Differential Revision: D37997773

fbshipit-source-id: bf4372cd310605fa35aa70f0604b084b047001d8
parent 7910ab16
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from dataclasses import dataclass
from typing import List
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -11,6 +14,20 @@ from d2go.utils.misc import _log_api_usage ...@@ -11,6 +14,20 @@ from d2go.utils.misc import _log_api_usage
from detectron2.modeling import META_ARCH_REGISTRY as D2_META_ARCH_REGISTRY from detectron2.modeling import META_ARCH_REGISTRY as D2_META_ARCH_REGISTRY
@dataclass
class D2GoModelBuildResult:
    """Bundle of everything produced by ``build_d2go_model``.

    Carries the built model together with the list of modeling hooks
    associated with it. Additional fields (e.g. a state_dict) can be added
    to this class later on.
    """

    # The model with the modeling hooks already applied. When a modeling
    # hook (e.g. EMA) is not enabled in the config, the hook is a no-op,
    # i.e. it returns the original model.
    model: nn.Module
    # The modeling hook objects, in list form.
    modeling_hooks: List[mh.ModelingHook]
def build_meta_arch(cfg): def build_meta_arch(cfg):
""" """
Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``. Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
...@@ -34,14 +51,17 @@ def build_meta_arch(cfg): ...@@ -34,14 +51,17 @@ def build_meta_arch(cfg):
return model return model
def build_d2go_model(cfg: CfgNode) -> nn.Module: def build_d2go_model(
cfg: CfgNode,
) -> D2GoModelBuildResult:
model = build_meta_arch(cfg) model = build_meta_arch(cfg)
modeling_hooks: List[mh.ModelingHook] = []
# apply modeling hooks # apply modeling hooks
# some custom projects bypass d2go's default config so may not have the # some custom projects bypass d2go's default config so may not have the
# MODELING_HOOKS key # MODELING_HOOKS key
if hasattr(cfg.MODEL, "MODELING_HOOKS"): if hasattr(cfg.MODEL, "MODELING_HOOKS"):
hook_names = cfg.MODEL.MODELING_HOOKS hook_names = cfg.MODEL.MODELING_HOOKS
model = mh.build_and_apply_modeling_hooks(model, cfg, hook_names) model, modeling_hooks = mh.build_and_apply_modeling_hooks(
model, cfg, hook_names
return model )
return D2GoModelBuildResult(model=model, modeling_hooks=modeling_hooks)
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from abc import abstractmethod from abc import abstractmethod
from typing import List from typing import List, Tuple
import torch import torch
from d2go.registry.builtin import MODELING_HOOK_REGISTRY from d2go.registry.builtin import MODELING_HOOK_REGISTRY
...@@ -85,7 +85,7 @@ def _apply_modeling_hooks( ...@@ -85,7 +85,7 @@ def _apply_modeling_hooks(
def build_and_apply_modeling_hooks( def build_and_apply_modeling_hooks(
model: torch.nn.Module, cfg, hook_names: List[str] model: torch.nn.Module, cfg, hook_names: List[str]
) -> torch.nn.Module: ) -> Tuple[torch.nn.Module, List[ModelingHook]]:
"""Build modeling hooks from cfg and apply hooks on the model. Users could """Build modeling hooks from cfg and apply hooks on the model. Users could
call model.unapply_modeling_hooks() to return the model that removes all call model.unapply_modeling_hooks() to return the model that removes all
the hooks. the hooks.
...@@ -93,4 +93,4 @@ def build_and_apply_modeling_hooks( ...@@ -93,4 +93,4 @@ def build_and_apply_modeling_hooks(
hooks = _build_modeling_hooks(cfg, hook_names) hooks = _build_modeling_hooks(cfg, hook_names)
model = _apply_modeling_hooks(model, hooks) model = _apply_modeling_hooks(model, hooks)
return model return model, hooks
...@@ -58,6 +58,7 @@ from detectron2.solver import build_lr_scheduler as d2_build_lr_scheduler ...@@ -58,6 +58,7 @@ from detectron2.solver import build_lr_scheduler as d2_build_lr_scheduler
from detectron2.utils.events import CommonMetricPrinter, JSONWriter from detectron2.utils.events import CommonMetricPrinter, JSONWriter
from mobile_cv.common.misc.oss_utils import fb_overwritable from mobile_cv.common.misc.oss_utils import fb_overwritable
from mobile_cv.predictor.api import PredictorWrapper from mobile_cv.predictor.api import PredictorWrapper
from torch import nn
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -163,9 +164,9 @@ class BaseRunner(object): ...@@ -163,9 +164,9 @@ class BaseRunner(object):
def get_default_cfg(cls): def get_default_cfg(cls):
return get_base_runner_default_cfg(CfgNode()) return get_base_runner_default_cfg(CfgNode())
def build_model(self, cfg, eval_only=False): def build_model(self, cfg, eval_only=False) -> nn.Module:
# cfg may need to be reused to build trace model again, thus clone # cfg may need to be reused to build trace model again, thus clone
model = build_d2go_model(cfg.clone()) model = build_d2go_model(cfg.clone()).model
if eval_only: if eval_only:
checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR) checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
...@@ -206,7 +207,7 @@ class Detectron2GoRunner(BaseRunner): ...@@ -206,7 +207,7 @@ class Detectron2GoRunner(BaseRunner):
# build_model might modify the cfg, thus clone # build_model might modify the cfg, thus clone
cfg = cfg.clone() cfg = cfg.clone()
model = build_d2go_model(cfg) model = build_d2go_model(cfg).model
model_ema.may_build_model_ema(cfg, model) model_ema.may_build_model_ema(cfg, model)
if cfg.MODEL.FROZEN_LAYER_REG_EXP: if cfg.MODEL.FROZEN_LAYER_REG_EXP:
......
...@@ -4,12 +4,13 @@ ...@@ -4,12 +4,13 @@
import copy import copy
import unittest import unittest
from typing import List
import d2go.runner.default_runner as default_runner import d2go.runner.default_runner as default_runner
import torch import torch
from d2go.config import CfgNode from d2go.config import CfgNode
from d2go.modeling import modeling_hook as mh from d2go.modeling import modeling_hook as mh
from d2go.modeling.api import build_d2go_model from d2go.modeling.api import build_d2go_model, D2GoModelBuildResult
from d2go.registry.builtin import META_ARCH_REGISTRY, MODELING_HOOK_REGISTRY from d2go.registry.builtin import META_ARCH_REGISTRY, MODELING_HOOK_REGISTRY
...@@ -84,8 +85,13 @@ class TestModelingHook(unittest.TestCase): ...@@ -84,8 +85,13 @@ class TestModelingHook(unittest.TestCase):
cfg.MODEL.DEVICE = "cpu" cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.META_ARCHITECTURE = "TestArch" cfg.MODEL.META_ARCHITECTURE = "TestArch"
cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"] cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
model = build_d2go_model(cfg)
model_info: D2GoModelBuildResult = build_d2go_model(cfg)
model: torch.nn.Module = model_info.model
modeling_hooks: List[mh.ModelingHook] = model_info.modeling_hooks
self.assertEqual(model(2), 10) self.assertEqual(model(2), 10)
self.assertEqual(len(modeling_hooks), 2)
self.assertTrue(hasattr(model, "_modeling_hooks")) self.assertTrue(hasattr(model, "_modeling_hooks"))
self.assertTrue(hasattr(model, "unapply_modeling_hooks")) self.assertTrue(hasattr(model, "unapply_modeling_hooks"))
...@@ -118,8 +124,13 @@ class TestModelingHook(unittest.TestCase): ...@@ -118,8 +124,13 @@ class TestModelingHook(unittest.TestCase):
cfg.MODEL.DEVICE = "cpu" cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.META_ARCHITECTURE = "TestArch" cfg.MODEL.META_ARCHITECTURE = "TestArch"
cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"] cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
model = build_d2go_model(cfg)
model_info: D2GoModelBuildResult = build_d2go_model(cfg)
model: torch.nn.Module = model_info.model
modeling_hooks: List[mh.ModelingHook] = model_info.modeling_hooks
self.assertEqual(model(2), 10) self.assertEqual(model(2), 10)
self.assertEqual(len(modeling_hooks), 2)
model_copy = copy.deepcopy(model) model_copy = copy.deepcopy(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment