Commit 4208a791 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

formalize build_d2go_model API

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/318

Reviewed By: mcimpoi

Differential Revision: D37501246

fbshipit-source-id: 6dbe5dcbaf7454f451d4a3bb3fa2d856cc87d5cc
parent 668b7ac2
......@@ -4,10 +4,3 @@
# NOTE: making necessary imports to register with Registery
from . import backbone, meta_arch, modeldef # noqa # noqa # noqa
# namespace forwarding
from .meta_arch.build import build_model
__all__ = [
"build_model",
]
......@@ -2,14 +2,16 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
import torch.nn as nn
from d2go.modeling.meta_arch import modeling_hook as mh
from d2go.config import CfgNode
from d2go.modeling import modeling_hook as mh
from d2go.registry.builtin import META_ARCH_REGISTRY
from d2go.utils.misc import _log_api_usage
from detectron2.modeling import META_ARCH_REGISTRY as D2_META_ARCH_REGISTRY
def build_model(cfg):
def build_meta_arch(cfg):
"""
Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
Note that it does not load any weights from ``cfg``.
......@@ -28,6 +30,13 @@ def build_model(cfg):
model = META_ARCH_REGISTRY.get(meta_arch)(cfg)
model.to(torch.device(cfg.MODEL.DEVICE))
_log_api_usage("modeling.meta_arch." + meta_arch)
return model
def build_d2go_model(cfg: CfgNode) -> nn.Module:
model = build_meta_arch(cfg)
# apply modeling hooks
# some custom projects bypass d2go's default config so may not have the
# MODELING_HOOKS key
......@@ -35,5 +44,4 @@ def build_model(cfg):
hook_names = cfg.MODEL.MODELING_HOOKS
model = mh.build_and_apply_modeling_hooks(model, cfg, hook_names)
_log_api_usage("modeling.meta_arch." + meta_arch)
return model
......@@ -19,7 +19,7 @@ from typing import List
import torch
import torch.nn as nn
from d2go.config import CfgNode as CN
from d2go.modeling.meta_arch import modeling_hook as mh
from d2go.modeling import modeling_hook as mh
from d2go.registry.builtin import (
DISTILLATION_ALGORITHM_REGISTRY,
DISTILLATION_HELPER_REGISTRY,
......
......@@ -22,7 +22,8 @@ from d2go.data.utils import (
maybe_subsample_n_images,
update_cfg_if_using_adhoc_dataset,
)
from d2go.modeling import build_model, kmeans_anchors, model_ema
from d2go.modeling import kmeans_anchors, model_ema
from d2go.modeling.api import build_d2go_model
from d2go.modeling.model_freezing_utils import freeze_matched_bn, set_requires_grad
from d2go.optimizer import build_optimizer_mapper
from d2go.quantization.modeling import QATCheckpointer, QATHook, setup_qat_model
......@@ -164,7 +165,7 @@ class BaseRunner(object):
def build_model(self, cfg, eval_only=False):
# cfg may need to be reused to build trace model again, thus clone
model = build_model(cfg.clone())
model = build_d2go_model(cfg.clone())
if eval_only:
checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
......@@ -205,7 +206,7 @@ class Detectron2GoRunner(BaseRunner):
# build_model might modify the cfg, thus clone
cfg = cfg.clone()
model = build_model(cfg)
model = build_d2go_model(cfg)
model_ema.may_build_model_ema(cfg, model)
if cfg.MODEL.FROZEN_LAYER_REG_EXP:
......
......@@ -14,7 +14,7 @@ from d2go.config import CfgNode
from d2go.data.build import build_d2go_train_loader
from d2go.data.datasets import inject_coco_datasets, register_dynamic_datasets
from d2go.data.utils import update_cfg_if_using_adhoc_dataset
from d2go.modeling import build_model
from d2go.modeling.api import build_meta_arch
from d2go.modeling.model_freezing_utils import set_requires_grad
from d2go.optimizer import build_optimizer_mapper
from d2go.quantization.modeling import (
......@@ -128,7 +128,7 @@ class DefaultTask(pl.LightningModule):
self.dataset_evaluators[ModelTag.EMA] = []
def _build_model(self) -> torch.nn.Module:
model = build_model(self.cfg)
model = build_meta_arch(self.cfg)
if self.cfg.MODEL.FROZEN_LAYER_REG_EXP:
set_requires_grad(model, self.cfg.MODEL.FROZEN_LAYER_REG_EXP, value=False)
......
......@@ -9,6 +9,7 @@ import numpy as np
import torch
import torch.nn as nn
from d2go.config import CfgNode
from d2go.modeling import modeling_hook as mh
from d2go.modeling.distillation import (
_build_teacher,
add_distillation_configs,
......@@ -20,7 +21,6 @@ from d2go.modeling.distillation import (
PseudoLabeler,
RelabelTargetInBatch,
)
from d2go.modeling.meta_arch import modeling_hook as mh
from d2go.registry.builtin import (
DISTILLATION_ALGORITHM_REGISTRY,
DISTILLATION_HELPER_REGISTRY,
......
......@@ -8,8 +8,8 @@ import unittest
import d2go.runner.default_runner as default_runner
import torch
from d2go.config import CfgNode
from d2go.modeling import build_model
from d2go.modeling.meta_arch import modeling_hook as mh
from d2go.modeling import modeling_hook as mh
from d2go.modeling.api import build_d2go_model
from d2go.registry.builtin import META_ARCH_REGISTRY
......@@ -84,7 +84,7 @@ class TestModelingHook(unittest.TestCase):
cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.META_ARCHITECTURE = "TestArch"
cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
model = build_model(cfg)
model = build_d2go_model(cfg)
self.assertEqual(model(2), 10)
self.assertTrue(hasattr(model, "_modeling_hooks"))
......@@ -118,7 +118,7 @@ class TestModelingHook(unittest.TestCase):
cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.META_ARCHITECTURE = "TestArch"
cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
model = build_model(cfg)
model = build_d2go_model(cfg)
self.assertEqual(model(2), 10)
model_copy = copy.deepcopy(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment