Commit 4208a791 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

formalize build_d2go_model API

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/318

Reviewed By: mcimpoi

Differential Revision: D37501246

fbshipit-source-id: 6dbe5dcbaf7454f451d4a3bb3fa2d856cc87d5cc
parent 668b7ac2
...@@ -4,10 +4,3 @@ ...@@ -4,10 +4,3 @@
# NOTE: making necessary imports to register with Registery # NOTE: making necessary imports to register with Registery
from . import backbone, meta_arch, modeldef # noqa # noqa # noqa from . import backbone, meta_arch, modeldef # noqa # noqa # noqa
# namespace forwarding
from .meta_arch.build import build_model
__all__ = [
"build_model",
]
...@@ -2,14 +2,16 @@ ...@@ -2,14 +2,16 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch import torch
import torch.nn as nn
from d2go.modeling.meta_arch import modeling_hook as mh from d2go.config import CfgNode
from d2go.modeling import modeling_hook as mh
from d2go.registry.builtin import META_ARCH_REGISTRY from d2go.registry.builtin import META_ARCH_REGISTRY
from d2go.utils.misc import _log_api_usage from d2go.utils.misc import _log_api_usage
from detectron2.modeling import META_ARCH_REGISTRY as D2_META_ARCH_REGISTRY from detectron2.modeling import META_ARCH_REGISTRY as D2_META_ARCH_REGISTRY
def build_model(cfg): def build_meta_arch(cfg):
""" """
Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``. Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
Note that it does not load any weights from ``cfg``. Note that it does not load any weights from ``cfg``.
...@@ -28,6 +30,13 @@ def build_model(cfg): ...@@ -28,6 +30,13 @@ def build_model(cfg):
model = META_ARCH_REGISTRY.get(meta_arch)(cfg) model = META_ARCH_REGISTRY.get(meta_arch)(cfg)
model.to(torch.device(cfg.MODEL.DEVICE)) model.to(torch.device(cfg.MODEL.DEVICE))
_log_api_usage("modeling.meta_arch." + meta_arch)
return model
def build_d2go_model(cfg: CfgNode) -> nn.Module:
model = build_meta_arch(cfg)
# apply modeling hooks # apply modeling hooks
# some custom projects bypass d2go's default config so may not have the # some custom projects bypass d2go's default config so may not have the
# MODELING_HOOKS key # MODELING_HOOKS key
...@@ -35,5 +44,4 @@ def build_model(cfg): ...@@ -35,5 +44,4 @@ def build_model(cfg):
hook_names = cfg.MODEL.MODELING_HOOKS hook_names = cfg.MODEL.MODELING_HOOKS
model = mh.build_and_apply_modeling_hooks(model, cfg, hook_names) model = mh.build_and_apply_modeling_hooks(model, cfg, hook_names)
_log_api_usage("modeling.meta_arch." + meta_arch)
return model return model
...@@ -19,7 +19,7 @@ from typing import List ...@@ -19,7 +19,7 @@ from typing import List
import torch import torch
import torch.nn as nn import torch.nn as nn
from d2go.config import CfgNode as CN from d2go.config import CfgNode as CN
from d2go.modeling.meta_arch import modeling_hook as mh from d2go.modeling import modeling_hook as mh
from d2go.registry.builtin import ( from d2go.registry.builtin import (
DISTILLATION_ALGORITHM_REGISTRY, DISTILLATION_ALGORITHM_REGISTRY,
DISTILLATION_HELPER_REGISTRY, DISTILLATION_HELPER_REGISTRY,
......
...@@ -22,7 +22,8 @@ from d2go.data.utils import ( ...@@ -22,7 +22,8 @@ from d2go.data.utils import (
maybe_subsample_n_images, maybe_subsample_n_images,
update_cfg_if_using_adhoc_dataset, update_cfg_if_using_adhoc_dataset,
) )
from d2go.modeling import build_model, kmeans_anchors, model_ema from d2go.modeling import kmeans_anchors, model_ema
from d2go.modeling.api import build_d2go_model
from d2go.modeling.model_freezing_utils import freeze_matched_bn, set_requires_grad from d2go.modeling.model_freezing_utils import freeze_matched_bn, set_requires_grad
from d2go.optimizer import build_optimizer_mapper from d2go.optimizer import build_optimizer_mapper
from d2go.quantization.modeling import QATCheckpointer, QATHook, setup_qat_model from d2go.quantization.modeling import QATCheckpointer, QATHook, setup_qat_model
...@@ -164,7 +165,7 @@ class BaseRunner(object): ...@@ -164,7 +165,7 @@ class BaseRunner(object):
def build_model(self, cfg, eval_only=False): def build_model(self, cfg, eval_only=False):
# cfg may need to be reused to build trace model again, thus clone # cfg may need to be reused to build trace model again, thus clone
model = build_model(cfg.clone()) model = build_d2go_model(cfg.clone())
if eval_only: if eval_only:
checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR) checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
...@@ -205,7 +206,7 @@ class Detectron2GoRunner(BaseRunner): ...@@ -205,7 +206,7 @@ class Detectron2GoRunner(BaseRunner):
# build_model might modify the cfg, thus clone # build_model might modify the cfg, thus clone
cfg = cfg.clone() cfg = cfg.clone()
model = build_model(cfg) model = build_d2go_model(cfg)
model_ema.may_build_model_ema(cfg, model) model_ema.may_build_model_ema(cfg, model)
if cfg.MODEL.FROZEN_LAYER_REG_EXP: if cfg.MODEL.FROZEN_LAYER_REG_EXP:
......
...@@ -14,7 +14,7 @@ from d2go.config import CfgNode ...@@ -14,7 +14,7 @@ from d2go.config import CfgNode
from d2go.data.build import build_d2go_train_loader from d2go.data.build import build_d2go_train_loader
from d2go.data.datasets import inject_coco_datasets, register_dynamic_datasets from d2go.data.datasets import inject_coco_datasets, register_dynamic_datasets
from d2go.data.utils import update_cfg_if_using_adhoc_dataset from d2go.data.utils import update_cfg_if_using_adhoc_dataset
from d2go.modeling import build_model from d2go.modeling.api import build_meta_arch
from d2go.modeling.model_freezing_utils import set_requires_grad from d2go.modeling.model_freezing_utils import set_requires_grad
from d2go.optimizer import build_optimizer_mapper from d2go.optimizer import build_optimizer_mapper
from d2go.quantization.modeling import ( from d2go.quantization.modeling import (
...@@ -128,7 +128,7 @@ class DefaultTask(pl.LightningModule): ...@@ -128,7 +128,7 @@ class DefaultTask(pl.LightningModule):
self.dataset_evaluators[ModelTag.EMA] = [] self.dataset_evaluators[ModelTag.EMA] = []
def _build_model(self) -> torch.nn.Module: def _build_model(self) -> torch.nn.Module:
model = build_model(self.cfg) model = build_meta_arch(self.cfg)
if self.cfg.MODEL.FROZEN_LAYER_REG_EXP: if self.cfg.MODEL.FROZEN_LAYER_REG_EXP:
set_requires_grad(model, self.cfg.MODEL.FROZEN_LAYER_REG_EXP, value=False) set_requires_grad(model, self.cfg.MODEL.FROZEN_LAYER_REG_EXP, value=False)
......
...@@ -9,6 +9,7 @@ import numpy as np ...@@ -9,6 +9,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from d2go.config import CfgNode from d2go.config import CfgNode
from d2go.modeling import modeling_hook as mh
from d2go.modeling.distillation import ( from d2go.modeling.distillation import (
_build_teacher, _build_teacher,
add_distillation_configs, add_distillation_configs,
...@@ -20,7 +21,6 @@ from d2go.modeling.distillation import ( ...@@ -20,7 +21,6 @@ from d2go.modeling.distillation import (
PseudoLabeler, PseudoLabeler,
RelabelTargetInBatch, RelabelTargetInBatch,
) )
from d2go.modeling.meta_arch import modeling_hook as mh
from d2go.registry.builtin import ( from d2go.registry.builtin import (
DISTILLATION_ALGORITHM_REGISTRY, DISTILLATION_ALGORITHM_REGISTRY,
DISTILLATION_HELPER_REGISTRY, DISTILLATION_HELPER_REGISTRY,
......
...@@ -8,8 +8,8 @@ import unittest ...@@ -8,8 +8,8 @@ import unittest
import d2go.runner.default_runner as default_runner import d2go.runner.default_runner as default_runner
import torch import torch
from d2go.config import CfgNode from d2go.config import CfgNode
from d2go.modeling import build_model from d2go.modeling import modeling_hook as mh
from d2go.modeling.meta_arch import modeling_hook as mh from d2go.modeling.api import build_d2go_model
from d2go.registry.builtin import META_ARCH_REGISTRY from d2go.registry.builtin import META_ARCH_REGISTRY
...@@ -84,7 +84,7 @@ class TestModelingHook(unittest.TestCase): ...@@ -84,7 +84,7 @@ class TestModelingHook(unittest.TestCase):
cfg.MODEL.DEVICE = "cpu" cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.META_ARCHITECTURE = "TestArch" cfg.MODEL.META_ARCHITECTURE = "TestArch"
cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"] cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
model = build_model(cfg) model = build_d2go_model(cfg)
self.assertEqual(model(2), 10) self.assertEqual(model(2), 10)
self.assertTrue(hasattr(model, "_modeling_hooks")) self.assertTrue(hasattr(model, "_modeling_hooks"))
...@@ -118,7 +118,7 @@ class TestModelingHook(unittest.TestCase): ...@@ -118,7 +118,7 @@ class TestModelingHook(unittest.TestCase):
cfg.MODEL.DEVICE = "cpu" cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.META_ARCHITECTURE = "TestArch" cfg.MODEL.META_ARCHITECTURE = "TestArch"
cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"] cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
model = build_model(cfg) model = build_d2go_model(cfg)
self.assertEqual(model(2), 10) self.assertEqual(model(2), 10)
model_copy = copy.deepcopy(model) model_copy = copy.deepcopy(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment