"vscode:/vscode.git/clone" did not exist on "70fc197be274f053fa6b2226f70cbec9e88110c6"
Commit 3cf04b34 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

use d2go's own registry for meta-arch

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/294

Reviewed By: newstzpz

Differential Revision: D37083291

fbshipit-source-id: 2dcf7952021f3b905d2558228862e83afe839675
parent 403a0711
......@@ -7,7 +7,12 @@ from functools import lru_cache
from d2go.modeling.meta_arch.rcnn import GeneralizedRCNNPatch
from d2go.modeling.meta_arch.semantic_seg import SemanticSegmentorPatch
from detectron2.modeling import GeneralizedRCNN, SemanticSegmentor
from d2go.registry.builtin import META_ARCH_REGISTRY
from detectron2.modeling import (
GeneralizedRCNN,
META_ARCH_REGISTRY as D2_META_ARCH_REGISTRY,
SemanticSegmentor,
)
logger = logging.getLogger(__name__)
......@@ -15,6 +20,9 @@ logger = logging.getLogger(__name__)
@lru_cache() # only call once
def patch_d2_meta_arch():
"""
Register meta-archietectures that are registered in D2's registry, also convert D2's
meta-arch into D2Go's meta-arch.
D2Go requires interfaces like prepare_for_export/prepare_for_quant from meta-arch in
order to do export/quant, this function applies the monkey patch to the original
D2's meta-archs.
......@@ -35,3 +43,7 @@ def patch_d2_meta_arch():
_apply_patch(GeneralizedRCNN, GeneralizedRCNNPatch)
_apply_patch(SemanticSegmentor, SemanticSegmentorPatch)
# TODO: patch other meta-archs defined in D2
for name, meta_arch_class in D2_META_ARCH_REGISTRY:
logger.info(f"Re-register the D2 meta-arch in D2Go: {meta_arch_class}")
META_ARCH_REGISTRY.register(name, meta_arch_class)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
from d2go.modeling.meta_arch import modeling_hook as mh
from d2go.registry.builtin import META_ARCH_REGISTRY
from d2go.utils.misc import _log_api_usage
from detectron2.modeling import build_model as d2_build_model
from detectron2.modeling import META_ARCH_REGISTRY as D2_META_ARCH_REGISTRY
def build_model(cfg):
......@@ -12,8 +14,19 @@ def build_model(cfg):
Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
Note that it does not load any weights from ``cfg``.
"""
# initialize the meta-arch and cast to the device
meta_arch = cfg.MODEL.META_ARCHITECTURE
model = d2_build_model(cfg)
# NOTE: during transition we also check if meta_arch is registered as D2 MetaArch
# TODO: remove this check after Sep 2022.
if meta_arch not in META_ARCH_REGISTRY and meta_arch in D2_META_ARCH_REGISTRY:
raise KeyError(
f"Can't find '{meta_arch}' in D2Go's META_ARCH_REGISTRY, although it is in"
f" D2's META_ARCH_REGISTRY, now D2Go uses its own registry, please register"
f" it in D2Go's META_ARCH_REGISTRY."
)
model = META_ARCH_REGISTRY.get(meta_arch)(cfg)
model.to(torch.device(cfg.MODEL.DEVICE))
# apply modeling hooks
# some custom projects bypass d2go's default config so may not have the
......
......@@ -19,6 +19,9 @@ DEMO_REGISTRY = Registry("DEMO")
# Registry for config updater
CONFIG_UPDATER_REGISTRY = Registry("CONFIG_UPDATER")
# Registry for meta-arch, registered nn.Module should follow D2Go's meta-arch API
META_ARCH_REGISTRY = Registry("META_ARCH")
# Distillation algorithms
DISTILLATION_ALGORITHM_REGISTRY = Registry("DISTILLATION_ALGORITHM")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment