Commit 482fdc8a authored by Mircea Cimpoi, committed by Facebook GitHub Bot
Browse files

EMA parity / change build_d2go_model

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/344

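We need access to the modeling hooks in EMA (e.g. when building the trainer), so `build_d2go_model` now returns a `D2GoModelBuildResult` that carries both the model and the list of modeling hooks, instead of returning only the model.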

Reviewed By: wat3rBro

Differential Revision: D37997773

fbshipit-source-id: bf4372cd310605fa35aa70f0604b084b047001d8
parent 7910ab16
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from dataclasses import dataclass
from typing import List
import torch
import torch.nn as nn
......@@ -11,6 +14,20 @@ from d2go.utils.misc import _log_api_usage
from detectron2.modeling import META_ARCH_REGISTRY as D2_META_ARCH_REGISTRY
@dataclass
class D2GoModelBuildResult:
    """Container for the output of ``build_d2go_model``.

    It bundles the built model with the list of modeling hooks, so callers
    can reach the hooks as well as the model. It can be extended with
    further fields later, e.g. a state_dict.

    Attributes:
        model: The model with the modeling hooks applied.
        modeling_hooks: The modeling hooks, stored as a list (not a
            key-value mapping).
    """

    # Model with the modeling hooks applied.
    # If modeling hooks (e.g. EMA) are not enabled in the config, the
    # modeling hook is a no-op (e.g. the original model is returned).
    model: nn.Module
    # Modeling hooks as a list of hook objects.
    modeling_hooks: List[mh.ModelingHook]
def build_meta_arch(cfg):
"""
Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
......@@ -34,14 +51,17 @@ def build_meta_arch(cfg):
return model
# Builds the meta-architecture for ``cfg`` via build_meta_arch and, when
# ``cfg.MODEL`` has a MODELING_HOOKS key, builds and applies the hooks named
# there. The final return wraps the model and the hooks in a
# D2GoModelBuildResult.
# NOTE(review): this listing looks like a diff rendered without +/- markers, so
# the earlier signature and the earlier hook call / ``return model`` appear
# right next to their replacements — confirm which lines are current before
# editing.
def build_d2go_model(cfg: CfgNode) -> nn.Module:
def build_d2go_model(
cfg: CfgNode,
) -> D2GoModelBuildResult:
model = build_meta_arch(cfg)
# Stays an empty list when the config has no MODELING_HOOKS key.
modeling_hooks: List[mh.ModelingHook] = []
# apply modeling hooks
# some custom projects bypass d2go's default config so may not have the
# MODELING_HOOKS key
if hasattr(cfg.MODEL, "MODELING_HOOKS"):
hook_names = cfg.MODEL.MODELING_HOOKS
model = mh.build_and_apply_modeling_hooks(model, cfg, hook_names)
return model
# The call result is unpacked into the hooked model and the hooks.
model, modeling_hooks = mh.build_and_apply_modeling_hooks(
model, cfg, hook_names
)
return D2GoModelBuildResult(model=model, modeling_hooks=modeling_hooks)
......@@ -2,7 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from abc import abstractmethod
from typing import List
from typing import List, Tuple
import torch
from d2go.registry.builtin import MODELING_HOOK_REGISTRY
......@@ -85,7 +85,7 @@ def _apply_modeling_hooks(
def build_and_apply_modeling_hooks(
model: torch.nn.Module, cfg, hook_names: List[str]
) -> torch.nn.Module:
) -> Tuple[torch.nn.Module, List[ModelingHook]]:
"""Build modeling hooks from cfg and apply hooks on the model. Users could
call model.unapply_modeling_hooks() to return the model that removes all
the hooks.
......@@ -93,4 +93,4 @@ def build_and_apply_modeling_hooks(
hooks = _build_modeling_hooks(cfg, hook_names)
model = _apply_modeling_hooks(model, hooks)
return model
return model, hooks
......@@ -58,6 +58,7 @@ from detectron2.solver import build_lr_scheduler as d2_build_lr_scheduler
from detectron2.utils.events import CommonMetricPrinter, JSONWriter
from mobile_cv.common.misc.oss_utils import fb_overwritable
from mobile_cv.predictor.api import PredictorWrapper
from torch import nn
logger = logging.getLogger(__name__)
......@@ -163,9 +164,9 @@ class BaseRunner(object):
def get_default_cfg(cls):
return get_base_runner_default_cfg(CfgNode())
def build_model(self, cfg, eval_only=False):
def build_model(self, cfg, eval_only=False) -> nn.Module:
# cfg may need to be reused to build trace model again, thus clone
model = build_d2go_model(cfg.clone())
model = build_d2go_model(cfg.clone()).model
if eval_only:
checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
......@@ -206,7 +207,7 @@ class Detectron2GoRunner(BaseRunner):
# build_model might modify the cfg, thus clone
cfg = cfg.clone()
model = build_d2go_model(cfg)
model = build_d2go_model(cfg).model
model_ema.may_build_model_ema(cfg, model)
if cfg.MODEL.FROZEN_LAYER_REG_EXP:
......
......@@ -4,12 +4,13 @@
import copy
import unittest
from typing import List
import d2go.runner.default_runner as default_runner
import torch
from d2go.config import CfgNode
from d2go.modeling import modeling_hook as mh
from d2go.modeling.api import build_d2go_model
from d2go.modeling.api import build_d2go_model, D2GoModelBuildResult
from d2go.registry.builtin import META_ARCH_REGISTRY, MODELING_HOOK_REGISTRY
......@@ -84,8 +85,13 @@ class TestModelingHook(unittest.TestCase):
cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.META_ARCHITECTURE = "TestArch"
cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
model = build_d2go_model(cfg)
model_info: D2GoModelBuildResult = build_d2go_model(cfg)
model: torch.nn.Module = model_info.model
modeling_hooks: List[mh.ModelingHook] = model_info.modeling_hooks
self.assertEqual(model(2), 10)
self.assertEqual(len(modeling_hooks), 2)
self.assertTrue(hasattr(model, "_modeling_hooks"))
self.assertTrue(hasattr(model, "unapply_modeling_hooks"))
......@@ -118,8 +124,13 @@ class TestModelingHook(unittest.TestCase):
cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.META_ARCHITECTURE = "TestArch"
cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
model = build_d2go_model(cfg)
model_info: D2GoModelBuildResult = build_d2go_model(cfg)
model: torch.nn.Module = model_info.model
modeling_hooks: List[mh.ModelingHook] = model_info.modeling_hooks
self.assertEqual(model(2), 10)
self.assertEqual(len(modeling_hooks), 2)
model_copy = copy.deepcopy(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment