Commit d1518ff6 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

remove patch_d2_meta_arch

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/305

One benefit of having separate registries for D2 (https://github.com/facebookresearch/d2go/commit/87374efb134e539090e0b5c476809dc35bf6aedb) and D2 (https://github.com/facebookresearch/d2go/commit/87374efb134e539090e0b5c476809dc35bf6aedb)Go's meta-arch is that there's no need to patch original D2 (https://github.com/facebookresearch/d2go/commit/87374efb134e539090e0b5c476809dc35bf6aedb)'s meta arch because we can just register new meta arch in D2 (https://github.com/facebookresearch/d2go/commit/87374efb134e539090e0b5c476809dc35bf6aedb)Go directly. This diff removes the `patch_d2_meta_arch` and makes things simpler.

Reviewed By: mcimpoi

Differential Revision: D37246483

fbshipit-source-id: c8b7adef1fa7a5ff2f89c376c7e3b39bec8f19ee
parent b57fde40
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from functools import lru_cache
from d2go.modeling.meta_arch.rcnn import GeneralizedRCNNPatch
from d2go.modeling.meta_arch.semantic_seg import SemanticSegmentorPatch
from d2go.registry.builtin import META_ARCH_REGISTRY
from detectron2.modeling import (
GeneralizedRCNN,
META_ARCH_REGISTRY as D2_META_ARCH_REGISTRY,
SemanticSegmentor,
)
logger = logging.getLogger(__name__)
@lru_cache() # only call once
def patch_d2_meta_arch():
"""
Register meta-archietectures that are registered in D2's registry, also convert D2's
meta-arch into D2Go's meta-arch.
D2Go requires interfaces like prepare_for_export/prepare_for_quant from meta-arch in
order to do export/quant, this function applies the monkey patch to the original
D2's meta-archs.
"""
def _check_and_set(cls_obj, method_name, method_func):
if hasattr(cls_obj, method_name):
assert getattr(cls_obj, method_name) == method_func
else:
setattr(cls_obj, method_name, method_func)
def _apply_patch(dst_cls, src_cls):
assert hasattr(src_cls, "METHODS_TO_PATCH")
for method_name in src_cls.METHODS_TO_PATCH:
assert hasattr(src_cls, method_name)
_check_and_set(dst_cls, method_name, getattr(src_cls, method_name))
_apply_patch(GeneralizedRCNN, GeneralizedRCNNPatch)
_apply_patch(SemanticSegmentor, SemanticSegmentorPatch)
# TODO: patch other meta-archs defined in D2
for name, meta_arch_class in D2_META_ARCH_REGISTRY:
logger.info(f"Re-register the D2 meta-arch in D2Go: {meta_arch_class}")
META_ARCH_REGISTRY.register(name, meta_arch_class)
...@@ -3,4 +3,4 @@ ...@@ -3,4 +3,4 @@
# NOTE: making necessary imports to register with Registry # NOTE: making necessary imports to register with Registry
# @fb-only: from . import fb # isort:skip # noqa # @fb-only: from . import fb # isort:skip # noqa
from . import fcos # noqa from . import fcos, panoptic_fpn, rcnn, retinanet, semantic_seg # noqa
...@@ -23,6 +23,7 @@ def add_fcos_configs(cfg): ...@@ -23,6 +23,7 @@ def add_fcos_configs(cfg):
cfg.MODEL.FCOS.FOCAL_LOSS_GAMMA = 2.0 cfg.MODEL.FCOS.FOCAL_LOSS_GAMMA = 2.0
# Re-register D2's meta-arch in D2Go with updated APIs
@META_ARCH_REGISTRY.register() @META_ARCH_REGISTRY.register()
class FCOS(d2_FCOS): class FCOS(d2_FCOS):
""" """
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from d2go.registry.builtin import META_ARCH_REGISTRY
from detectron2.modeling import PanopticFPN as _PanopticFPN
# Re-register D2's meta-arch in D2Go with updated APIs
@META_ARCH_REGISTRY.register()
class PanopticFPN(_PanopticFPN):
def prepare_for_export(self, cfg, inputs, predictor_type):
raise NotImplementedError
...@@ -13,7 +13,11 @@ from d2go.config import CfgNode ...@@ -13,7 +13,11 @@ from d2go.config import CfgNode
from d2go.config.utils import flatten_config_dict from d2go.config.utils import flatten_config_dict
from d2go.export.api import PredictorExportConfig from d2go.export.api import PredictorExportConfig
from d2go.quantization.qconfig import set_backend_and_create_qconfig from d2go.quantization.qconfig import set_backend_and_create_qconfig
from detectron2.modeling import GeneralizedRCNN from d2go.registry.builtin import META_ARCH_REGISTRY
from detectron2.modeling import (
GeneralizedRCNN as _GeneralizedRCNN,
ProposalNetwork as _ProposalNetwork,
)
from detectron2.modeling.backbone.fpn import FPN from detectron2.modeling.backbone.fpn import FPN
from detectron2.modeling.postprocessing import detector_postprocess from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.projects.point_rend import PointRendMaskHead from detectron2.projects.point_rend import PointRendMaskHead
...@@ -40,14 +44,9 @@ RCNN_PREPARE_FOR_QUANT_REGISTRY = Registry("RCNN_PREPARE_FOR_QUANT") ...@@ -40,14 +44,9 @@ RCNN_PREPARE_FOR_QUANT_REGISTRY = Registry("RCNN_PREPARE_FOR_QUANT")
RCNN_PREPARE_FOR_QUANT_CONVERT_REGISTRY = Registry("RCNN_PREPARE_FOR_QUANT_CONVERT") RCNN_PREPARE_FOR_QUANT_CONVERT_REGISTRY = Registry("RCNN_PREPARE_FOR_QUANT_CONVERT")
class GeneralizedRCNNPatch: # Re-register D2's meta-arch in D2Go with updated APIs
METHODS_TO_PATCH = [ @META_ARCH_REGISTRY.register()
"prepare_for_export", class GeneralizedRCNN(_GeneralizedRCNN):
"prepare_for_quant",
"prepare_for_quant_convert",
"_cast_model_to_device",
]
def prepare_for_export(self, cfg, *args, **kwargs): def prepare_for_export(self, cfg, *args, **kwargs):
func = RCNN_PREPARE_FOR_EXPORT_REGISTRY.get(cfg.RCNN_PREPARE_FOR_EXPORT) func = RCNN_PREPARE_FOR_EXPORT_REGISTRY.get(cfg.RCNN_PREPARE_FOR_EXPORT)
return func(self, cfg, *args, **kwargs) return func(self, cfg, *args, **kwargs)
...@@ -66,6 +65,12 @@ class GeneralizedRCNNPatch: ...@@ -66,6 +65,12 @@ class GeneralizedRCNNPatch:
return _cast_detection_model(self, device) return _cast_detection_model(self, device)
# Re-register D2's meta-arch in D2Go with updated APIs
@META_ARCH_REGISTRY.register()
class ProposalNetwork(_ProposalNetwork):
pass
@RCNN_PREPARE_FOR_EXPORT_REGISTRY.register() @RCNN_PREPARE_FOR_EXPORT_REGISTRY.register()
def default_rcnn_prepare_for_export(self, cfg, inputs, predictor_type): def default_rcnn_prepare_for_export(self, cfg, inputs, predictor_type):
pytorch_model = self pytorch_model = self
...@@ -499,8 +504,6 @@ class D2RCNNInferenceWrapper(nn.Module): ...@@ -499,8 +504,6 @@ class D2RCNNInferenceWrapper(nn.Module):
# TODO: model.to(device) might not work for detection meta-arch, this function is the # TODO: model.to(device) might not work for detection meta-arch, this function is the
# workaround, in general, we might need a meta-arch API for this if needed. # workaround, in general, we might need a meta-arch API for this if needed.
def _cast_detection_model(model, device): def _cast_detection_model(model, device):
from d2go.registry.builtin import META_ARCH_REGISTRY
# check model is an instance of one of the meta arch # check model is an instance of one of the meta arch
from detectron2.export.caffe2_modeling import Caffe2MetaArch from detectron2.export.caffe2_modeling import Caffe2MetaArch
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from d2go.registry.builtin import META_ARCH_REGISTRY
from detectron2.modeling import RetinaNet as _RetinaNet
# Re-register D2's meta-arch in D2Go with updated APIs
@META_ARCH_REGISTRY.register()
class RetinaNet(_RetinaNet):
def prepare_for_export(self, cfg, inputs, predictor_type):
raise NotImplementedError
...@@ -6,16 +6,16 @@ from typing import Any, Dict, List ...@@ -6,16 +6,16 @@ from typing import Any, Dict, List
import torch import torch
import torch.nn as nn import torch.nn as nn
from d2go.export.api import PredictorExportConfig from d2go.export.api import PredictorExportConfig
from d2go.registry.builtin import META_ARCH_REGISTRY
from detectron2.modeling import SemanticSegmentor as _SemanticSegmentor
from detectron2.modeling.postprocessing import sem_seg_postprocess from detectron2.modeling.postprocessing import sem_seg_postprocess
from detectron2.structures import ImageList from detectron2.structures import ImageList
from mobile_cv.predictor.api import FuncInfo from mobile_cv.predictor.api import FuncInfo
class SemanticSegmentorPatch: # Re-register D2's meta-arch in D2Go with updated APIs
METHODS_TO_PATCH = [ @META_ARCH_REGISTRY.register()
"prepare_for_export", class SemanticSegmentor(_SemanticSegmentor):
]
def prepare_for_export(self, cfg, inputs, predictor_type): def prepare_for_export(self, cfg, inputs, predictor_type):
preprocess_info = FuncInfo.gen_func_info( preprocess_info = FuncInfo.gen_func_info(
PreprocessFunc, PreprocessFunc,
......
...@@ -22,7 +22,6 @@ from d2go.data.utils import ( ...@@ -22,7 +22,6 @@ from d2go.data.utils import (
maybe_subsample_n_images, maybe_subsample_n_images,
update_cfg_if_using_adhoc_dataset, update_cfg_if_using_adhoc_dataset,
) )
from d2go.export.d2_meta_arch import patch_d2_meta_arch
from d2go.modeling import build_model, kmeans_anchors, model_ema from d2go.modeling import build_model, kmeans_anchors, model_ema
from d2go.modeling.model_freezing_utils import freeze_matched_bn, set_requires_grad from d2go.modeling.model_freezing_utils import freeze_matched_bn, set_requires_grad
from d2go.optimizer import build_optimizer_mapper from d2go.optimizer import build_optimizer_mapper
...@@ -209,7 +208,6 @@ class Detectron2GoRunner(BaseRunner): ...@@ -209,7 +208,6 @@ class Detectron2GoRunner(BaseRunner):
inject_coco_datasets(cfg) inject_coco_datasets(cfg)
register_dynamic_datasets(cfg) register_dynamic_datasets(cfg)
update_cfg_if_using_adhoc_dataset(cfg) update_cfg_if_using_adhoc_dataset(cfg)
patch_d2_meta_arch()
@classmethod @classmethod
def get_default_cfg(cls): def get_default_cfg(cls):
......
...@@ -14,7 +14,6 @@ from d2go.config import CfgNode ...@@ -14,7 +14,6 @@ from d2go.config import CfgNode
from d2go.data.build import build_d2go_train_loader from d2go.data.build import build_d2go_train_loader
from d2go.data.datasets import inject_coco_datasets, register_dynamic_datasets from d2go.data.datasets import inject_coco_datasets, register_dynamic_datasets
from d2go.data.utils import update_cfg_if_using_adhoc_dataset from d2go.data.utils import update_cfg_if_using_adhoc_dataset
from d2go.export.d2_meta_arch import patch_d2_meta_arch
from d2go.modeling import build_model from d2go.modeling import build_model
from d2go.modeling.model_freezing_utils import set_requires_grad from d2go.modeling.model_freezing_utils import set_requires_grad
from d2go.optimizer import build_optimizer_mapper from d2go.optimizer import build_optimizer_mapper
...@@ -338,7 +337,6 @@ class DefaultTask(pl.LightningModule): ...@@ -338,7 +337,6 @@ class DefaultTask(pl.LightningModule):
inject_coco_datasets(cfg) inject_coco_datasets(cfg)
register_dynamic_datasets(cfg) register_dynamic_datasets(cfg)
update_cfg_if_using_adhoc_dataset(cfg) update_cfg_if_using_adhoc_dataset(cfg)
patch_d2_meta_arch()
@classmethod @classmethod
def build_model(cls, cfg: CfgNode, eval_only=False): def build_model(cls, cfg: CfgNode, eval_only=False):
......
...@@ -9,7 +9,6 @@ from typing import Optional ...@@ -9,7 +9,6 @@ from typing import Optional
import d2go.data.transforms.box_utils as bu import d2go.data.transforms.box_utils as bu
import torch import torch
from d2go.export.d2_meta_arch import patch_d2_meta_arch
from d2go.export.exporter import convert_and_export_predictor from d2go.export.exporter import convert_and_export_predictor
from d2go.runner import GeneralizedRCNNRunner from d2go.runner import GeneralizedRCNNRunner
from d2go.utils.testing.data_loader_helper import ( from d2go.utils.testing.data_loader_helper import (
...@@ -254,11 +253,6 @@ class RCNNBaseTestCases: ...@@ -254,11 +253,6 @@ class RCNNBaseTestCases:
class TemplateTestCase(unittest.TestCase): # TODO: maybe subclass from TestMetaArch class TemplateTestCase(unittest.TestCase): # TODO: maybe subclass from TestMetaArch
def setUp(self): def setUp(self):
# Add APIs to D2's meta arch, this is usually called in runner's setup,
# however in unittest it needs to be called sperarately.
# TODO: maybe we should apply this by default
patch_d2_meta_arch()
self.setup_test_dir() self.setup_test_dir()
assert hasattr(self, "test_dir") assert hasattr(self, "test_dir")
......
...@@ -7,7 +7,6 @@ import os ...@@ -7,7 +7,6 @@ import os
import unittest import unittest
import torch import torch
from d2go.export.d2_meta_arch import patch_d2_meta_arch
from d2go.export.exporter import convert_and_export_predictor from d2go.export.exporter import convert_and_export_predictor
from d2go.runner import GeneralizedRCNNRunner from d2go.runner import GeneralizedRCNNRunner
from d2go.utils.testing.data_loader_helper import ( from d2go.utils.testing.data_loader_helper import (
...@@ -17,10 +16,6 @@ from d2go.utils.testing.rcnn_helper import get_quick_test_config_opts, RCNNBaseT ...@@ -17,10 +16,6 @@ from d2go.utils.testing.rcnn_helper import get_quick_test_config_opts, RCNNBaseT
from mobile_cv.common.misc.file_utils import make_temp_directory from mobile_cv.common.misc.file_utils import make_temp_directory
from mobile_cv.common.misc.oss_utils import is_oss from mobile_cv.common.misc.oss_utils import is_oss
# Add APIs to D2's meta arch, this is usually called in runner's setup, however in
# unittest it needs to be called sperarately. (maybe we should apply this by default)
patch_d2_meta_arch()
def _maybe_skip_test(self, predictor_type): def _maybe_skip_test(self, predictor_type):
if is_oss() and "@c2_ops" in predictor_type: if is_oss() and "@c2_ops" in predictor_type:
......
...@@ -8,15 +8,10 @@ import tempfile ...@@ -8,15 +8,10 @@ import tempfile
import unittest import unittest
import torch import torch
from d2go.export.d2_meta_arch import patch_d2_meta_arch
from d2go.export.exporter import convert_and_export_predictor from d2go.export.exporter import convert_and_export_predictor
from d2go.runner import Detectron2GoRunner from d2go.runner import Detectron2GoRunner
from mobile_cv.predictor.api import create_predictor from mobile_cv.predictor.api import create_predictor
# Add APIs to D2's meta arch, this is usually called in runner's setup, however in
# unittest it needs to be called sperarately. (maybe we should apply this by default)
patch_d2_meta_arch()
def _get_batch(height, width, is_train): def _get_batch(height, width, is_train):
def _get_frame(): def _get_frame():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment