Commit e3b5253b authored by ZCMax's avatar ZCMax Committed by ChaimZhu
Browse files

Update all registries and fix some ut problems

parent 8dd8da12
...@@ -7,8 +7,7 @@ import numpy as np ...@@ -7,8 +7,7 @@ import numpy as np
from mmdet3d.core import instance_seg_eval, show_result, show_seg_result from mmdet3d.core import instance_seg_eval, show_result, show_seg_result
from mmdet3d.core.bbox import DepthInstance3DBoxes from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmseg.datasets import DATASETS as SEG_DATASETS from mmdet3d.registry import DATASETS
from .builder import DATASETS
from .custom_3d import Custom3DDataset from .custom_3d import Custom3DDataset
from .custom_3d_seg import Custom3DSegDataset from .custom_3d_seg import Custom3DSegDataset
from .pipelines import Compose from .pipelines import Compose
...@@ -253,7 +252,6 @@ class ScanNetDataset(Custom3DDataset): ...@@ -253,7 +252,6 @@ class ScanNetDataset(Custom3DDataset):
@DATASETS.register_module() @DATASETS.register_module()
@SEG_DATASETS.register_module()
class ScanNetSegDataset(Custom3DSegDataset): class ScanNetSegDataset(Custom3DSegDataset):
r"""ScanNet Dataset for Semantic Segmentation Task. r"""ScanNet Dataset for Semantic Segmentation Task.
...@@ -467,7 +465,6 @@ class ScanNetSegDataset(Custom3DSegDataset): ...@@ -467,7 +465,6 @@ class ScanNetSegDataset(Custom3DSegDataset):
@DATASETS.register_module() @DATASETS.register_module()
@SEG_DATASETS.register_module()
class ScanNetInstanceSegDataset(Custom3DSegDataset): class ScanNetInstanceSegDataset(Custom3DSegDataset):
CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'bookshelf', 'picture', 'counter', 'desk', 'curtain',
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from os import path as osp from os import path as osp
from .builder import DATASETS from mmdet3d.registry import DATASETS
from .custom_3d import Custom3DDataset from .custom_3d import Custom3DDataset
......
...@@ -6,8 +6,8 @@ import numpy as np ...@@ -6,8 +6,8 @@ import numpy as np
from mmdet3d.core import show_multi_modality_result, show_result from mmdet3d.core import show_multi_modality_result, show_result
from mmdet3d.core.bbox import DepthInstance3DBoxes from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.registry import DATASETS
from mmdet.core import eval_map from mmdet.core import eval_map
from .builder import DATASETS
from .custom_3d import Custom3DDataset from .custom_3d import Custom3DDataset
from .pipelines import Compose from .pipelines import Compose
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import mmcv import mmcv
from mmcv.transforms import LoadImageFromFile
# yapf: disable # yapf: disable
from mmdet3d.datasets.pipelines import (Collect3D, DefaultFormatBundle3D, from mmdet3d.datasets.pipelines import (Collect3D, DefaultFormatBundle3D,
...@@ -10,9 +11,9 @@ from mmdet3d.datasets.pipelines import (Collect3D, DefaultFormatBundle3D, ...@@ -10,9 +11,9 @@ from mmdet3d.datasets.pipelines import (Collect3D, DefaultFormatBundle3D,
LoadPointsFromMultiSweeps, LoadPointsFromMultiSweeps,
MultiScaleFlipAug3D, MultiScaleFlipAug3D,
PointSegClassMapping) PointSegClassMapping)
from mmdet.datasets.pipelines import LoadImageFromFile, MultiScaleFlipAug
# yapf: enable # yapf: enable
from .builder import PIPELINES from mmdet3d.registry import TRANSFORMS
from mmdet.datasets.pipelines import MultiScaleFlipAug
def is_loading_function(transform): def is_loading_function(transform):
...@@ -35,7 +36,7 @@ def is_loading_function(transform): ...@@ -35,7 +36,7 @@ def is_loading_function(transform):
Collect3D, LoadImageFromFileMono3D, Collect3D, LoadImageFromFileMono3D,
PointSegClassMapping) PointSegClassMapping)
if isinstance(transform, dict): if isinstance(transform, dict):
obj_cls = PIPELINES.get(transform['type']) obj_cls = TRANSFORMS.get(transform['type'])
if obj_cls is None: if obj_cls is None:
return False return False
if obj_cls in loading_functions: if obj_cls in loading_functions:
......
...@@ -8,8 +8,8 @@ import numpy as np ...@@ -8,8 +8,8 @@ import numpy as np
import torch import torch
from mmcv.utils import print_log from mmcv.utils import print_log
from mmdet3d.registry import DATASETS
from ..core.bbox import Box3DMode, points_cam2img from ..core.bbox import Box3DMode, points_cam2img
from .builder import DATASETS
from .kitti_dataset import KittiDataset from .kitti_dataset import KittiDataset
......
...@@ -3,10 +3,10 @@ from mmcv.runner import BaseModule, auto_fp16 ...@@ -3,10 +3,10 @@ from mmcv.runner import BaseModule, auto_fp16
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import DGCNNFAModule, DGCNNGFModule from mmdet3d.ops import DGCNNFAModule, DGCNNGFModule
from ..builder import BACKBONES from mmdet3d.registry import MODELS
@BACKBONES.register_module() @MODELS.register_module()
class DGCNNBackbone(BaseModule): class DGCNNBackbone(BaseModule):
"""Backbone network for DGCNN. """Backbone network for DGCNN.
......
...@@ -6,7 +6,7 @@ from mmcv.cnn import build_conv_layer, build_norm_layer ...@@ -6,7 +6,7 @@ from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from torch import nn from torch import nn
from ..builder import BACKBONES from mmdet3d.registry import MODELS
def dla_build_norm_layer(cfg, num_features): def dla_build_norm_layer(cfg, num_features):
...@@ -275,7 +275,7 @@ class Tree(BaseModule): ...@@ -275,7 +275,7 @@ class Tree(BaseModule):
return x return x
@BACKBONES.register_module() @MODELS.register_module()
class DLANet(BaseModule): class DLANet(BaseModule):
r"""`DLA backbone <https://arxiv.org/abs/1707.06484>`_. r"""`DLA backbone <https://arxiv.org/abs/1707.06484>`_.
......
...@@ -7,10 +7,11 @@ from mmcv.cnn import ConvModule ...@@ -7,10 +7,11 @@ from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16 from mmcv.runner import BaseModule, auto_fp16
from torch import nn as nn from torch import nn as nn
from ..builder import BACKBONES, build_backbone from mmdet3d.models.builder import build_backbone
from mmdet3d.registry import MODELS
@BACKBONES.register_module() @MODELS.register_module()
class MultiBackbone(BaseModule): class MultiBackbone(BaseModule):
"""MultiBackbone with different configs. """MultiBackbone with different configs.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmdet3d.registry import MODELS
from mmdet.models.backbones import RegNet from mmdet.models.backbones import RegNet
from ..builder import BACKBONES
@BACKBONES.register_module() @MODELS.register_module()
class NoStemRegNet(RegNet): class NoStemRegNet(RegNet):
"""RegNet backbone without Stem for 3D detection. """RegNet backbone without Stem for 3D detection.
......
...@@ -5,11 +5,11 @@ from mmcv.runner import auto_fp16 ...@@ -5,11 +5,11 @@ from mmcv.runner import auto_fp16
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import build_sa_module from mmdet3d.ops import build_sa_module
from ..builder import BACKBONES from mmdet3d.registry import MODELS
from .base_pointnet import BasePointNet from .base_pointnet import BasePointNet
@BACKBONES.register_module() @MODELS.register_module()
class PointNet2SAMSG(BasePointNet): class PointNet2SAMSG(BasePointNet):
"""PointNet2 with Multi-scale grouping. """PointNet2 with Multi-scale grouping.
......
...@@ -4,11 +4,11 @@ from mmcv.runner import auto_fp16 ...@@ -4,11 +4,11 @@ from mmcv.runner import auto_fp16
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import PointFPModule, build_sa_module from mmdet3d.ops import PointFPModule, build_sa_module
from ..builder import BACKBONES from mmdet3d.registry import MODELS
from .base_pointnet import BasePointNet from .base_pointnet import BasePointNet
@BACKBONES.register_module() @MODELS.register_module()
class PointNet2SASSG(BasePointNet): class PointNet2SASSG(BasePointNet):
"""PointNet2 with Single-scale grouping. """PointNet2 with Single-scale grouping.
......
...@@ -5,10 +5,10 @@ from mmcv.cnn import build_conv_layer, build_norm_layer ...@@ -5,10 +5,10 @@ from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from torch import nn as nn from torch import nn as nn
from ..builder import BACKBONES from mmdet3d.registry import MODELS
@BACKBONES.register_module() @MODELS.register_module()
class SECOND(BaseModule): class SECOND(BaseModule):
"""Backbone network for SECOND/PointPillars/PartA2/MVXNet. """Backbone network for SECOND/PointPillars/PartA2/MVXNet.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings import warnings
from mmcv.cnn import MODELS as MMCV_MODELS from mmdet3d.registry import MODELS
from mmcv.utils import Registry
from mmdet.models.builder import BACKBONES as MMDET_BACKBONES
from mmdet.models.builder import DETECTORS as MMDET_DETECTORS
from mmdet.models.builder import HEADS as MMDET_HEADS
from mmdet.models.builder import LOSSES as MMDET_LOSSES
from mmdet.models.builder import NECKS as MMDET_NECKS
from mmdet.models.builder import ROI_EXTRACTORS as MMDET_ROI_EXTRACTORS
from mmdet.models.builder import SHARED_HEADS as MMDET_SHARED_HEADS
from mmseg.models.builder import LOSSES as MMSEG_LOSSES
MODELS = Registry('models', parent=MMCV_MODELS)
BACKBONES = MODELS BACKBONES = MODELS
NECKS = MODELS NECKS = MODELS
...@@ -22,6 +10,7 @@ SHARED_HEADS = MODELS ...@@ -22,6 +10,7 @@ SHARED_HEADS = MODELS
HEADS = MODELS HEADS = MODELS
LOSSES = MODELS LOSSES = MODELS
DETECTORS = MODELS DETECTORS = MODELS
SEGMENTORS = MODELS
VOXEL_ENCODERS = MODELS VOXEL_ENCODERS = MODELS
MIDDLE_ENCODERS = MODELS MIDDLE_ENCODERS = MODELS
FUSION_LAYERS = MODELS FUSION_LAYERS = MODELS
...@@ -30,52 +19,47 @@ SEGMENTORS = MODELS ...@@ -30,52 +19,47 @@ SEGMENTORS = MODELS
def build_backbone(cfg): def build_backbone(cfg):
"""Build backbone.""" """Build backbone."""
if cfg['type'] in BACKBONES._module_dict.keys(): warnings.warn('``build_backbone`` would be deprecated soon, please use '
return BACKBONES.build(cfg) '``mmdet3d.registry.MODELS.build()`` ')
else:
return MMDET_BACKBONES.build(cfg) return BACKBONES.build(cfg)
def build_neck(cfg): def build_neck(cfg):
"""Build neck.""" """Build neck."""
if cfg['type'] in NECKS._module_dict.keys(): warnings.warn('``build_neck`` would be deprecated soon, please use '
return NECKS.build(cfg) '``mmdet3d.registry.MODELS.build()`` ')
else:
return MMDET_NECKS.build(cfg) return NECKS.build(cfg)
def build_roi_extractor(cfg): def build_roi_extractor(cfg):
"""Build RoI feature extractor.""" """Build roi extractor."""
if cfg['type'] in ROI_EXTRACTORS._module_dict.keys(): warnings.warn(
return ROI_EXTRACTORS.build(cfg) '``build_roi_extractor`` would be deprecated soon, please use '
else: '``mmdet3d.registry.MODELS.build()`` ')
return MMDET_ROI_EXTRACTORS.build(cfg) return ROI_EXTRACTORS.build(cfg)
def build_shared_head(cfg): def build_shared_head(cfg):
"""Build shared head of detector.""" """Build shared head."""
if cfg['type'] in SHARED_HEADS._module_dict.keys(): warnings.warn('``build_shared_head`` would be deprecated soon, please use '
return SHARED_HEADS.build(cfg) '``mmdet3d.registry.MODELS.build()`` ')
else: return SHARED_HEADS.build(cfg)
return MMDET_SHARED_HEADS.build(cfg)
def build_head(cfg): def build_head(cfg):
"""Build head.""" """Build head."""
if cfg['type'] in HEADS._module_dict.keys(): warnings.warn('``build_head`` would be deprecated soon, please use '
return HEADS.build(cfg) '``mmdet3d.registry.MODELS.build()`` ')
else: return HEADS.build(cfg)
return MMDET_HEADS.build(cfg)
def build_loss(cfg): def build_loss(cfg):
"""Build loss function.""" """Build loss."""
if cfg['type'] in LOSSES._module_dict.keys(): warnings.warn('``build_loss`` would be deprecated soon, please use '
return LOSSES.build(cfg) '``mmdet3d.registry.MODELS.build()`` ')
elif cfg['type'] in MMDET_LOSSES._module_dict.keys(): return LOSSES.build(cfg)
return MMDET_LOSSES.build(cfg)
else:
return MMSEG_LOSSES.build(cfg)
def build_detector(cfg, train_cfg=None, test_cfg=None): def build_detector(cfg, train_cfg=None, test_cfg=None):
...@@ -91,9 +75,6 @@ def build_detector(cfg, train_cfg=None, test_cfg=None): ...@@ -91,9 +75,6 @@ def build_detector(cfg, train_cfg=None, test_cfg=None):
if cfg['type'] in DETECTORS._module_dict.keys(): if cfg['type'] in DETECTORS._module_dict.keys():
return DETECTORS.build( return DETECTORS.build(
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
else:
return MMDET_DETECTORS.build(
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
def build_segmentor(cfg, train_cfg=None, test_cfg=None): def build_segmentor(cfg, train_cfg=None, test_cfg=None):
...@@ -124,14 +105,20 @@ def build_model(cfg, train_cfg=None, test_cfg=None): ...@@ -124,14 +105,20 @@ def build_model(cfg, train_cfg=None, test_cfg=None):
def build_voxel_encoder(cfg): def build_voxel_encoder(cfg):
"""Build voxel encoder.""" """Build voxel encoder."""
warnings.warn('``build_voxel_encoder`` would be deprecated soon, please '
'use ``mmdet3d.registry.MODELS.build()`` ')
return VOXEL_ENCODERS.build(cfg) return VOXEL_ENCODERS.build(cfg)
def build_middle_encoder(cfg): def build_middle_encoder(cfg):
"""Build middle level encoder.""" """Build middle level encoder."""
warnings.warn('``build_middle_encoder`` would be deprecated soon, please '
'use ``mmdet3d.registry.MODELS.build()`` ')
return MIDDLE_ENCODERS.build(cfg) return MIDDLE_ENCODERS.build(cfg)
def build_fusion_layer(cfg): def build_fusion_layer(cfg):
"""Build fusion layer.""" """Build fusion layer."""
warnings.warn('``build_fusion_layer`` would be deprecated soon, please '
'use ``mmdet3d.registry.MODELS.build()`` ')
return FUSION_LAYERS.build(cfg) return FUSION_LAYERS.build(cfg)
...@@ -5,7 +5,7 @@ from mmcv.cnn import normal_init ...@@ -5,7 +5,7 @@ from mmcv.cnn import normal_init
from mmcv.runner import BaseModule, auto_fp16, force_fp32 from mmcv.runner import BaseModule, auto_fp16, force_fp32
from torch import nn as nn from torch import nn as nn
from mmseg.models.builder import build_loss from ..builder import build_loss
class Base3DDecodeHead(BaseModule, metaclass=ABCMeta): class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
......
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
from mmcv.cnn.bricks import ConvModule from mmcv.cnn.bricks import ConvModule
from mmdet3d.ops import DGCNNFPModule from mmdet3d.ops import DGCNNFPModule
from ..builder import HEADS from mmdet3d.registry import MODELS
from .decode_head import Base3DDecodeHead from .decode_head import Base3DDecodeHead
@HEADS.register_module() @MODELS.register_module()
class DGCNNHead(Base3DDecodeHead): class DGCNNHead(Base3DDecodeHead):
r"""DGCNN decoder head. r"""DGCNN decoder head.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn.bricks import ConvModule from mmcv.cnn.bricks import ConvModule
from ..builder import HEADS from mmdet3d.registry import MODELS
from .pointnet2_head import PointNet2Head from .pointnet2_head import PointNet2Head
@HEADS.register_module() @MODELS.register_module()
class PAConvHead(PointNet2Head): class PAConvHead(PointNet2Head):
r"""PAConv decoder head. r"""PAConv decoder head.
......
...@@ -3,11 +3,11 @@ from mmcv.cnn.bricks import ConvModule ...@@ -3,11 +3,11 @@ from mmcv.cnn.bricks import ConvModule
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import PointFPModule from mmdet3d.ops import PointFPModule
from ..builder import HEADS from mmdet3d.registry import MODELS
from .decode_head import Base3DDecodeHead from .decode_head import Base3DDecodeHead
@HEADS.register_module() @MODELS.register_module()
class PointNet2Head(Base3DDecodeHead): class PointNet2Head(Base3DDecodeHead):
r"""PointNet2 decoder head. r"""PointNet2 decoder head.
......
...@@ -6,13 +6,14 @@ from torch import nn as nn ...@@ -6,13 +6,14 @@ from torch import nn as nn
from mmdet3d.core import (PseudoSampler, box3d_multiclass_nms, limit_period, from mmdet3d.core import (PseudoSampler, box3d_multiclass_nms, limit_period,
xywhr2xyxyr) xywhr2xyxyr)
from mmdet3d.registry import MODELS
from mmdet.core import (build_assigner, build_bbox_coder, from mmdet.core import (build_assigner, build_bbox_coder,
build_prior_generator, build_sampler, multi_apply) build_prior_generator, build_sampler, multi_apply)
from ..builder import HEADS, build_loss from ..builder import build_loss
from .train_mixins import AnchorTrainMixin from .train_mixins import AnchorTrainMixin
@HEADS.register_module() @MODELS.register_module()
class Anchor3DHead(BaseModule, AnchorTrainMixin): class Anchor3DHead(BaseModule, AnchorTrainMixin):
"""Anchor head for SECOND/PointPillars/MVXNet/PartA2. """Anchor head for SECOND/PointPillars/MVXNet/PartA2.
......
...@@ -6,12 +6,13 @@ from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init ...@@ -6,12 +6,13 @@ from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
from mmcv.runner import force_fp32 from mmcv.runner import force_fp32
from torch import nn as nn from torch import nn as nn
from mmdet3d.registry import MODELS
from mmdet.core import multi_apply from mmdet.core import multi_apply
from ..builder import HEADS, build_loss from ..builder import build_loss
from .base_mono3d_dense_head import BaseMono3DDenseHead from .base_mono3d_dense_head import BaseMono3DDenseHead
@HEADS.register_module() @MODELS.register_module()
class AnchorFreeMono3DHead(BaseMono3DDenseHead): class AnchorFreeMono3DHead(BaseMono3DDenseHead):
"""Anchor-free head for monocular 3D object detection. """Anchor-free head for monocular 3D object detection.
......
...@@ -4,10 +4,10 @@ from mmcv.cnn.bricks import build_conv_layer ...@@ -4,10 +4,10 @@ from mmcv.cnn.bricks import build_conv_layer
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from torch import nn as nn from torch import nn as nn
from ..builder import HEADS from mmdet3d.registry import MODELS
@HEADS.register_module() @MODELS.register_module()
class BaseConvBboxHead(BaseModule): class BaseConvBboxHead(BaseModule):
r"""More general bbox head, with shared conv layers and two optional r"""More general bbox head, with shared conv layers and two optional
separated branches. separated branches.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment