Commit 17fc9270 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

update to use D2Go's META_ARCH_REGISTRY and build_model

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/298

Reviewed By: tglik, newstzpz

Differential Revision: D37152248

fbshipit-source-id: 58a6899c5f6465f36961a2ebf60a64f20509cec2
parent 3cf04b34
from d2go.config import CfgNode as CN from d2go.config import CfgNode as CN
from d2go.registry.builtin import META_ARCH_REGISTRY
from detectron2.config import configurable from detectron2.config import configurable
from detectron2.modeling.backbone import build_backbone from detectron2.modeling.backbone import build_backbone
from detectron2.modeling.meta_arch import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.fcos import FCOS as d2_FCOS, FCOSHead from detectron2.modeling.meta_arch.fcos import FCOS as d2_FCOS, FCOSHead
......
...@@ -499,9 +499,10 @@ class D2RCNNInferenceWrapper(nn.Module): ...@@ -499,9 +499,10 @@ class D2RCNNInferenceWrapper(nn.Module):
# TODO: model.to(device) might not work for detection meta-arch, this function is the # TODO: model.to(device) might not work for detection meta-arch, this function is the
# workaround, in general, we might need a meta-arch API for this if needed. # workaround, in general, we might need a meta-arch API for this if needed.
def _cast_detection_model(model, device): def _cast_detection_model(model, device):
from d2go.registry.builtin import META_ARCH_REGISTRY
# check model is an instance of one of the meta arch # check model is an instance of one of the meta arch
from detectron2.export.caffe2_modeling import Caffe2MetaArch from detectron2.export.caffe2_modeling import Caffe2MetaArch
from detectron2.modeling import META_ARCH_REGISTRY
if isinstance(model, Caffe2MetaArch): if isinstance(model, Caffe2MetaArch):
model._wrapped_model = _cast_detection_model(model._wrapped_model, device) model._wrapped_model = _cast_detection_model(model._wrapped_model, device)
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
import torch import torch
from d2go.quantization.modeling import set_backend_and_create_qconfig from d2go.quantization.modeling import set_backend_and_create_qconfig
from d2go.registry.builtin import META_ARCH_REGISTRY
from d2go.utils.testing.data_loader_helper import create_local_dataset from d2go.utils.testing.data_loader_helper import create_local_dataset
from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.structures import Boxes, ImageList, Instances from detectron2.structures import Boxes, ImageList, Instances
from torch.ao.quantization.quantize_fx import convert_fx, prepare_qat_fx from torch.ao.quantization.quantize_fx import convert_fx, prepare_qat_fx
......
...@@ -4,9 +4,10 @@ ...@@ -4,9 +4,10 @@
from typing import Optional, Type from typing import Optional, Type
from d2go.registry.builtin import META_ARCH_REGISTRY
from detectron2.data import DatasetCatalog, detection_utils as utils, MetadataCatalog from detectron2.data import DatasetCatalog, detection_utils as utils, MetadataCatalog
from detectron2.evaluation import DatasetEvaluator from detectron2.evaluation import DatasetEvaluator
from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.utils.events import get_event_storage from detectron2.utils.events import get_event_storage
from detectron2.utils.visualizer import Visualizer from detectron2.utils.visualizer import Visualizer
......
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from detectron2.modeling import detector_postprocess, META_ARCH_REGISTRY from d2go.registry.builtin import META_ARCH_REGISTRY
from detectron2.modeling import detector_postprocess
from detectron2.structures import BitMasks, Boxes, ImageList, Instances from detectron2.structures import BitMasks, Boxes, ImageList, Instances
from detr.datasets.coco import convert_coco_poly_to_mask from detr.datasets.coco import convert_coco_poly_to_mask
from detr.models.backbone import Joiner from detr.models.backbone import Joiner
......
...@@ -59,7 +59,7 @@ class TestMetaArchRegistryPopulation(unittest.TestCase, BaseRegistryPopulationTe ...@@ -59,7 +59,7 @@ class TestMetaArchRegistryPopulation(unittest.TestCase, BaseRegistryPopulationTe
self._package = d2go.modeling self._package = d2go.modeling
def get_registered_items(self): def get_registered_items(self):
from detectron2.modeling import META_ARCH_REGISTRY from d2go.registry.builtin import META_ARCH_REGISTRY
return [k for k, v in META_ARCH_REGISTRY] return [k for k, v in META_ARCH_REGISTRY]
......
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
from d2go.config import CfgNode from d2go.config import CfgNode
from d2go.modeling import build_model from d2go.modeling import build_model
from d2go.modeling.meta_arch import modeling_hook as mh from d2go.modeling.meta_arch import modeling_hook as mh
from detectron2.modeling import META_ARCH_REGISTRY from d2go.registry.builtin import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register() @META_ARCH_REGISTRY.register()
......
...@@ -9,12 +9,12 @@ import unittest ...@@ -9,12 +9,12 @@ import unittest
import d2go.runner.default_runner as default_runner import d2go.runner.default_runner as default_runner
import torch import torch
from d2go.registry.builtin import META_ARCH_REGISTRY
from d2go.runner import create_runner from d2go.runner import create_runner
from d2go.runner.training_hooks import TRAINER_HOOKS_REGISTRY from d2go.runner.training_hooks import TRAINER_HOOKS_REGISTRY
from d2go.utils.testing import helper from d2go.utils.testing import helper
from d2go.utils.testing.data_loader_helper import create_local_dataset from d2go.utils.testing.data_loader_helper import create_local_dataset
from detectron2.evaluation import COCOEvaluator, RotatedCOCOEvaluator from detectron2.evaluation import COCOEvaluator, RotatedCOCOEvaluator
from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.structures import Boxes, ImageList, Instances from detectron2.structures import Boxes, ImageList, Instances
from mobile_cv.arch.quantization.qconfig import ( from mobile_cv.arch.quantization.qconfig import (
updateable_symmetric_moving_avg_minmax_config, updateable_symmetric_moving_avg_minmax_config,
......
...@@ -11,12 +11,12 @@ import pytorch_lightning as pl # type: ignore ...@@ -11,12 +11,12 @@ import pytorch_lightning as pl # type: ignore
import torch import torch
from d2go.config import CfgNode, temp_defrost from d2go.config import CfgNode, temp_defrost
from d2go.quantization.modeling import set_backend_and_create_qconfig from d2go.quantization.modeling import set_backend_and_create_qconfig
from d2go.registry.builtin import META_ARCH_REGISTRY
from d2go.runner import create_runner from d2go.runner import create_runner
from d2go.runner.callbacks.quantization import QuantizationAwareTraining from d2go.runner.callbacks.quantization import QuantizationAwareTraining
from d2go.runner.lightning_task import GeneralizedRCNNTask from d2go.runner.lightning_task import GeneralizedRCNNTask
from d2go.utils.testing import meta_arch_helper as mah from d2go.utils.testing import meta_arch_helper as mah
from d2go.utils.testing.helper import tempdir from d2go.utils.testing.helper import tempdir
from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.utils.events import EventStorage from detectron2.utils.events import EventStorage
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torch import Tensor from torch import Tensor
......
...@@ -11,6 +11,7 @@ from typing import Dict, List, Optional, Tuple ...@@ -11,6 +11,7 @@ from typing import Dict, List, Optional, Tuple
import d2go.runner.default_runner as default_runner import d2go.runner.default_runner as default_runner
import numpy as np import numpy as np
import torch import torch
from d2go.registry.builtin import META_ARCH_REGISTRY
from d2go.utils.testing.data_loader_helper import ( from d2go.utils.testing.data_loader_helper import (
create_toy_dataset, create_toy_dataset,
LocalImageGenerator, LocalImageGenerator,
...@@ -18,7 +19,6 @@ from d2go.utils.testing.data_loader_helper import ( ...@@ -18,7 +19,6 @@ from d2go.utils.testing.data_loader_helper import (
from d2go.utils.testing.helper import tempdir from d2go.utils.testing.helper import tempdir
from d2go.utils.visualization import DataLoaderVisWrapper, VisualizerWrapper from d2go.utils.visualization import DataLoaderVisWrapper, VisualizerWrapper
from detectron2.data import DatasetCatalog, MetadataCatalog from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.structures import Boxes, Instances from detectron2.structures import Boxes, Instances
from detectron2.utils.events import EventStorage from detectron2.utils.events import EventStorage
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment