Commit e3b5253b authored by ZCMax's avatar ZCMax Committed by ChaimZhu
Browse files

Update all registries and fix some ut problems

parent 8dd8da12
......@@ -7,8 +7,7 @@ import numpy as np
from mmdet3d.core import instance_seg_eval, show_result, show_seg_result
from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmseg.datasets import DATASETS as SEG_DATASETS
from .builder import DATASETS
from mmdet3d.registry import DATASETS
from .custom_3d import Custom3DDataset
from .custom_3d_seg import Custom3DSegDataset
from .pipelines import Compose
......@@ -253,7 +252,6 @@ class ScanNetDataset(Custom3DDataset):
@DATASETS.register_module()
@SEG_DATASETS.register_module()
class ScanNetSegDataset(Custom3DSegDataset):
r"""ScanNet Dataset for Semantic Segmentation Task.
......@@ -467,7 +465,6 @@ class ScanNetSegDataset(Custom3DSegDataset):
@DATASETS.register_module()
@SEG_DATASETS.register_module()
class ScanNetInstanceSegDataset(Custom3DSegDataset):
CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
'bookshelf', 'picture', 'counter', 'desk', 'curtain',
......
# Copyright (c) OpenMMLab. All rights reserved.
from os import path as osp
from .builder import DATASETS
from mmdet3d.registry import DATASETS
from .custom_3d import Custom3DDataset
......
......@@ -6,8 +6,8 @@ import numpy as np
from mmdet3d.core import show_multi_modality_result, show_result
from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.registry import DATASETS
from mmdet.core import eval_map
from .builder import DATASETS
from .custom_3d import Custom3DDataset
from .pipelines import Compose
......
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
from mmcv.transforms import LoadImageFromFile
# yapf: disable
from mmdet3d.datasets.pipelines import (Collect3D, DefaultFormatBundle3D,
......@@ -10,9 +11,9 @@ from mmdet3d.datasets.pipelines import (Collect3D, DefaultFormatBundle3D,
LoadPointsFromMultiSweeps,
MultiScaleFlipAug3D,
PointSegClassMapping)
from mmdet.datasets.pipelines import LoadImageFromFile, MultiScaleFlipAug
# yapf: enable
from .builder import PIPELINES
from mmdet3d.registry import TRANSFORMS
from mmdet.datasets.pipelines import MultiScaleFlipAug
def is_loading_function(transform):
......@@ -35,7 +36,7 @@ def is_loading_function(transform):
Collect3D, LoadImageFromFileMono3D,
PointSegClassMapping)
if isinstance(transform, dict):
obj_cls = PIPELINES.get(transform['type'])
obj_cls = TRANSFORMS.get(transform['type'])
if obj_cls is None:
return False
if obj_cls in loading_functions:
......
......@@ -8,8 +8,8 @@ import numpy as np
import torch
from mmcv.utils import print_log
from mmdet3d.registry import DATASETS
from ..core.bbox import Box3DMode, points_cam2img
from .builder import DATASETS
from .kitti_dataset import KittiDataset
......
......@@ -3,10 +3,10 @@ from mmcv.runner import BaseModule, auto_fp16
from torch import nn as nn
from mmdet3d.ops import DGCNNFAModule, DGCNNGFModule
from ..builder import BACKBONES
from mmdet3d.registry import MODELS
@BACKBONES.register_module()
@MODELS.register_module()
class DGCNNBackbone(BaseModule):
"""Backbone network for DGCNN.
......
......@@ -6,7 +6,7 @@ from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule
from torch import nn
from ..builder import BACKBONES
from mmdet3d.registry import MODELS
def dla_build_norm_layer(cfg, num_features):
......@@ -275,7 +275,7 @@ class Tree(BaseModule):
return x
@BACKBONES.register_module()
@MODELS.register_module()
class DLANet(BaseModule):
r"""`DLA backbone <https://arxiv.org/abs/1707.06484>`_.
......
......@@ -7,10 +7,11 @@ from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16
from torch import nn as nn
from ..builder import BACKBONES, build_backbone
from mmdet3d.models.builder import build_backbone
from mmdet3d.registry import MODELS
@BACKBONES.register_module()
@MODELS.register_module()
class MultiBackbone(BaseModule):
"""MultiBackbone with different configs.
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet3d.registry import MODELS
from mmdet.models.backbones import RegNet
from ..builder import BACKBONES
@BACKBONES.register_module()
@MODELS.register_module()
class NoStemRegNet(RegNet):
"""RegNet backbone without Stem for 3D detection.
......
......@@ -5,11 +5,11 @@ from mmcv.runner import auto_fp16
from torch import nn as nn
from mmdet3d.ops import build_sa_module
from ..builder import BACKBONES
from mmdet3d.registry import MODELS
from .base_pointnet import BasePointNet
@BACKBONES.register_module()
@MODELS.register_module()
class PointNet2SAMSG(BasePointNet):
"""PointNet2 with Multi-scale grouping.
......
......@@ -4,11 +4,11 @@ from mmcv.runner import auto_fp16
from torch import nn as nn
from mmdet3d.ops import PointFPModule, build_sa_module
from ..builder import BACKBONES
from mmdet3d.registry import MODELS
from .base_pointnet import BasePointNet
@BACKBONES.register_module()
@MODELS.register_module()
class PointNet2SASSG(BasePointNet):
"""PointNet2 with Single-scale grouping.
......
......@@ -5,10 +5,10 @@ from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule
from torch import nn as nn
from ..builder import BACKBONES
from mmdet3d.registry import MODELS
@BACKBONES.register_module()
@MODELS.register_module()
class SECOND(BaseModule):
"""Backbone network for SECOND/PointPillars/PartA2/MVXNet.
......
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry
from mmdet.models.builder import BACKBONES as MMDET_BACKBONES
from mmdet.models.builder import DETECTORS as MMDET_DETECTORS
from mmdet.models.builder import HEADS as MMDET_HEADS
from mmdet.models.builder import LOSSES as MMDET_LOSSES
from mmdet.models.builder import NECKS as MMDET_NECKS
from mmdet.models.builder import ROI_EXTRACTORS as MMDET_ROI_EXTRACTORS
from mmdet.models.builder import SHARED_HEADS as MMDET_SHARED_HEADS
from mmseg.models.builder import LOSSES as MMSEG_LOSSES
MODELS = Registry('models', parent=MMCV_MODELS)
from mmdet3d.registry import MODELS
BACKBONES = MODELS
NECKS = MODELS
......@@ -22,6 +10,7 @@ SHARED_HEADS = MODELS
HEADS = MODELS
LOSSES = MODELS
DETECTORS = MODELS
SEGMENTORS = MODELS
VOXEL_ENCODERS = MODELS
MIDDLE_ENCODERS = MODELS
FUSION_LAYERS = MODELS
......@@ -30,52 +19,47 @@ SEGMENTORS = MODELS
def build_backbone(cfg):
"""Build backbone."""
if cfg['type'] in BACKBONES._module_dict.keys():
return BACKBONES.build(cfg)
else:
return MMDET_BACKBONES.build(cfg)
warnings.warn('``build_backbone`` would be deprecated soon, please use '
'``mmdet3d.registry.MODELS.build()`` ')
return BACKBONES.build(cfg)
def build_neck(cfg):
"""Build neck."""
if cfg['type'] in NECKS._module_dict.keys():
return NECKS.build(cfg)
else:
return MMDET_NECKS.build(cfg)
warnings.warn('``build_neck`` would be deprecated soon, please use '
'``mmdet3d.registry.MODELS.build()`` ')
return NECKS.build(cfg)
def build_roi_extractor(cfg):
"""Build RoI feature extractor."""
if cfg['type'] in ROI_EXTRACTORS._module_dict.keys():
return ROI_EXTRACTORS.build(cfg)
else:
return MMDET_ROI_EXTRACTORS.build(cfg)
"""Build roi extractor."""
warnings.warn(
'``build_roi_extractor`` would be deprecated soon, please use '
'``mmdet3d.registry.MODELS.build()`` ')
return ROI_EXTRACTORS.build(cfg)
def build_shared_head(cfg):
"""Build shared head of detector."""
if cfg['type'] in SHARED_HEADS._module_dict.keys():
return SHARED_HEADS.build(cfg)
else:
return MMDET_SHARED_HEADS.build(cfg)
"""Build shared head."""
warnings.warn('``build_shared_head`` would be deprecated soon, please use '
'``mmdet3d.registry.MODELS.build()`` ')
return SHARED_HEADS.build(cfg)
def build_head(cfg):
"""Build head."""
if cfg['type'] in HEADS._module_dict.keys():
return HEADS.build(cfg)
else:
return MMDET_HEADS.build(cfg)
warnings.warn('``build_head`` would be deprecated soon, please use '
'``mmdet3d.registry.MODELS.build()`` ')
return HEADS.build(cfg)
def build_loss(cfg):
"""Build loss function."""
if cfg['type'] in LOSSES._module_dict.keys():
return LOSSES.build(cfg)
elif cfg['type'] in MMDET_LOSSES._module_dict.keys():
return MMDET_LOSSES.build(cfg)
else:
return MMSEG_LOSSES.build(cfg)
"""Build loss."""
warnings.warn('``build_loss`` would be deprecated soon, please use '
'``mmdet3d.registry.MODELS.build()`` ')
return LOSSES.build(cfg)
def build_detector(cfg, train_cfg=None, test_cfg=None):
......@@ -91,9 +75,6 @@ def build_detector(cfg, train_cfg=None, test_cfg=None):
if cfg['type'] in DETECTORS._module_dict.keys():
return DETECTORS.build(
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
else:
return MMDET_DETECTORS.build(
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
def build_segmentor(cfg, train_cfg=None, test_cfg=None):
......@@ -124,14 +105,20 @@ def build_model(cfg, train_cfg=None, test_cfg=None):
def build_voxel_encoder(cfg):
"""Build voxel encoder."""
warnings.warn('``build_voxel_encoder`` would be deprecated soon, please '
'use ``mmdet3d.registry.MODELS.build()`` ')
return VOXEL_ENCODERS.build(cfg)
def build_middle_encoder(cfg):
"""Build middle level encoder."""
warnings.warn('``build_middle_encoder`` would be deprecated soon, please '
'use ``mmdet3d.registry.MODELS.build()`` ')
return MIDDLE_ENCODERS.build(cfg)
def build_fusion_layer(cfg):
"""Build fusion layer."""
warnings.warn('``build_fusion_layer`` would be deprecated soon, please '
'use ``mmdet3d.registry.MODELS.build()`` ')
return FUSION_LAYERS.build(cfg)
......@@ -5,7 +5,7 @@ from mmcv.cnn import normal_init
from mmcv.runner import BaseModule, auto_fp16, force_fp32
from torch import nn as nn
from mmseg.models.builder import build_loss
from ..builder import build_loss
class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
......
......@@ -2,11 +2,11 @@
from mmcv.cnn.bricks import ConvModule
from mmdet3d.ops import DGCNNFPModule
from ..builder import HEADS
from mmdet3d.registry import MODELS
from .decode_head import Base3DDecodeHead
@HEADS.register_module()
@MODELS.register_module()
class DGCNNHead(Base3DDecodeHead):
r"""DGCNN decoder head.
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn.bricks import ConvModule
from ..builder import HEADS
from mmdet3d.registry import MODELS
from .pointnet2_head import PointNet2Head
@HEADS.register_module()
@MODELS.register_module()
class PAConvHead(PointNet2Head):
r"""PAConv decoder head.
......
......@@ -3,11 +3,11 @@ from mmcv.cnn.bricks import ConvModule
from torch import nn as nn
from mmdet3d.ops import PointFPModule
from ..builder import HEADS
from mmdet3d.registry import MODELS
from .decode_head import Base3DDecodeHead
@HEADS.register_module()
@MODELS.register_module()
class PointNet2Head(Base3DDecodeHead):
r"""PointNet2 decoder head.
......
......@@ -6,13 +6,14 @@ from torch import nn as nn
from mmdet3d.core import (PseudoSampler, box3d_multiclass_nms, limit_period,
xywhr2xyxyr)
from mmdet3d.registry import MODELS
from mmdet.core import (build_assigner, build_bbox_coder,
build_prior_generator, build_sampler, multi_apply)
from ..builder import HEADS, build_loss
from ..builder import build_loss
from .train_mixins import AnchorTrainMixin
@HEADS.register_module()
@MODELS.register_module()
class Anchor3DHead(BaseModule, AnchorTrainMixin):
"""Anchor head for SECOND/PointPillars/MVXNet/PartA2.
......
......@@ -6,12 +6,13 @@ from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
from mmcv.runner import force_fp32
from torch import nn as nn
from mmdet3d.registry import MODELS
from mmdet.core import multi_apply
from ..builder import HEADS, build_loss
from ..builder import build_loss
from .base_mono3d_dense_head import BaseMono3DDenseHead
@HEADS.register_module()
@MODELS.register_module()
class AnchorFreeMono3DHead(BaseMono3DDenseHead):
"""Anchor-free head for monocular 3D object detection.
......
......@@ -4,10 +4,10 @@ from mmcv.cnn.bricks import build_conv_layer
from mmcv.runner import BaseModule
from torch import nn as nn
from ..builder import HEADS
from mmdet3d.registry import MODELS
@HEADS.register_module()
@MODELS.register_module()
class BaseConvBboxHead(BaseModule):
r"""More general bbox head, with shared conv layers and two optional
separated branches.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment