Unverified Commit 0287048a authored by ChaimZhu's avatar ChaimZhu Committed by GitHub
Browse files

[Enhance] Update Registry in MMDet3D (#1412)

* Update Registry in MMDet3D

* fix compose pipeline bug

* update registry

* fix some bugs

* fix comments

* fix comments
parent e013bab5
...@@ -7,7 +7,7 @@ import numpy as np ...@@ -7,7 +7,7 @@ import numpy as np
from mmdet3d.core import show_multi_modality_result, show_result from mmdet3d.core import show_multi_modality_result, show_result
from mmdet3d.core.bbox import DepthInstance3DBoxes from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet.core import eval_map from mmdet.core import eval_map
from mmdet.datasets import DATASETS from .builder import DATASETS
from .custom_3d import Custom3DDataset from .custom_3d import Custom3DDataset
from .pipelines import Compose from .pipelines import Compose
......
...@@ -10,9 +10,9 @@ from mmdet3d.datasets.pipelines import (Collect3D, DefaultFormatBundle3D, ...@@ -10,9 +10,9 @@ from mmdet3d.datasets.pipelines import (Collect3D, DefaultFormatBundle3D,
LoadPointsFromMultiSweeps, LoadPointsFromMultiSweeps,
MultiScaleFlipAug3D, MultiScaleFlipAug3D,
PointSegClassMapping) PointSegClassMapping)
# yapf: enable
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import LoadImageFromFile, MultiScaleFlipAug from mmdet.datasets.pipelines import LoadImageFromFile, MultiScaleFlipAug
# yapf: enable
from .builder import PIPELINES
def is_loading_function(transform): def is_loading_function(transform):
......
...@@ -8,8 +8,8 @@ import numpy as np ...@@ -8,8 +8,8 @@ import numpy as np
import torch import torch
from mmcv.utils import print_log from mmcv.utils import print_log
from mmdet.datasets import DATASETS
from ..core.bbox import Box3DMode, points_cam2img from ..core.bbox import Box3DMode, points_cam2img
from .builder import DATASETS
from .kitti_dataset import KittiDataset from .kitti_dataset import KittiDataset
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .backbones import * # noqa: F401,F403 from .backbones import * # noqa: F401,F403
from .builder import (FUSION_LAYERS, MIDDLE_ENCODERS, VOXEL_ENCODERS, from .builder import (BACKBONES, DETECTORS, FUSION_LAYERS, HEADS, LOSSES,
build_backbone, build_detector, build_fusion_layer, MIDDLE_ENCODERS, NECKS, ROI_EXTRACTORS, SEGMENTORS,
build_head, build_loss, build_middle_encoder, SHARED_HEADS, VOXEL_ENCODERS, build_backbone,
build_model, build_neck, build_roi_extractor, build_detector, build_fusion_layer, build_head,
build_shared_head, build_voxel_encoder) build_loss, build_middle_encoder, build_model,
build_neck, build_roi_extractor, build_shared_head,
build_voxel_encoder)
from .decode_heads import * # noqa: F401,F403 from .decode_heads import * # noqa: F401,F403
from .dense_heads import * # noqa: F401,F403 from .dense_heads import * # noqa: F401,F403
from .detectors import * # noqa: F401,F403 from .detectors import * # noqa: F401,F403
...@@ -18,8 +20,10 @@ from .segmentors import * # noqa: F401,F403 ...@@ -18,8 +20,10 @@ from .segmentors import * # noqa: F401,F403
from .voxel_encoders import * # noqa: F401,F403 from .voxel_encoders import * # noqa: F401,F403
__all__ = [ __all__ = [
'VOXEL_ENCODERS', 'MIDDLE_ENCODERS', 'FUSION_LAYERS', 'build_backbone', 'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES',
'build_neck', 'build_roi_extractor', 'build_shared_head', 'build_head', 'DETECTORS', 'SEGMENTORS', 'VOXEL_ENCODERS', 'MIDDLE_ENCODERS',
'build_loss', 'build_detector', 'build_fusion_layer', 'build_model', 'FUSION_LAYERS', 'build_backbone', 'build_neck', 'build_roi_extractor',
'build_middle_encoder', 'build_voxel_encoder' 'build_shared_head', 'build_head', 'build_loss', 'build_detector',
'build_fusion_layer', 'build_model', 'build_middle_encoder',
'build_voxel_encoder'
] ]
...@@ -3,7 +3,7 @@ from mmcv.runner import BaseModule, auto_fp16 ...@@ -3,7 +3,7 @@ from mmcv.runner import BaseModule, auto_fp16
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import DGCNNFAModule, DGCNNGFModule from mmdet3d.ops import DGCNNFAModule, DGCNNGFModule
from mmdet.models import BACKBONES from ..builder import BACKBONES
@BACKBONES.register_module() @BACKBONES.register_module()
......
...@@ -6,7 +6,7 @@ from mmcv.cnn import build_conv_layer, build_norm_layer ...@@ -6,7 +6,7 @@ from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from torch import nn from torch import nn
from mmdet.models.builder import BACKBONES from ..builder import BACKBONES
def dla_build_norm_layer(cfg, num_features): def dla_build_norm_layer(cfg, num_features):
......
...@@ -7,7 +7,7 @@ from mmcv.cnn import ConvModule ...@@ -7,7 +7,7 @@ from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16 from mmcv.runner import BaseModule, auto_fp16
from torch import nn as nn from torch import nn as nn
from mmdet.models import BACKBONES, build_backbone from ..builder import BACKBONES, build_backbone
@BACKBONES.register_module() @BACKBONES.register_module()
......
...@@ -5,7 +5,7 @@ from mmcv.runner import auto_fp16 ...@@ -5,7 +5,7 @@ from mmcv.runner import auto_fp16
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import build_sa_module from mmdet3d.ops import build_sa_module
from mmdet.models import BACKBONES from ..builder import BACKBONES
from .base_pointnet import BasePointNet from .base_pointnet import BasePointNet
......
...@@ -4,7 +4,7 @@ from mmcv.runner import auto_fp16 ...@@ -4,7 +4,7 @@ from mmcv.runner import auto_fp16
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import PointFPModule, build_sa_module from mmdet3d.ops import PointFPModule, build_sa_module
from mmdet.models import BACKBONES from ..builder import BACKBONES
from .base_pointnet import BasePointNet from .base_pointnet import BasePointNet
......
...@@ -5,7 +5,7 @@ from mmcv.cnn import build_conv_layer, build_norm_layer ...@@ -5,7 +5,7 @@ from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from torch import nn as nn from torch import nn as nn
from mmdet.models import BACKBONES from ..builder import BACKBONES
@BACKBONES.register_module() @BACKBONES.register_module()
......
...@@ -4,45 +4,78 @@ import warnings ...@@ -4,45 +4,78 @@ import warnings
from mmcv.cnn import MODELS as MMCV_MODELS from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry from mmcv.utils import Registry
from mmdet.models.builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS, from mmdet.models.builder import BACKBONES as MMDET_BACKBONES
ROI_EXTRACTORS, SHARED_HEADS) from mmdet.models.builder import DETECTORS as MMDET_DETECTORS
from mmseg.models.builder import SEGMENTORS from mmdet.models.builder import HEADS as MMDET_HEADS
from mmdet.models.builder import LOSSES as MMDET_LOSSES
from mmdet.models.builder import NECKS as MMDET_NECKS
from mmdet.models.builder import ROI_EXTRACTORS as MMDET_ROI_EXTRACTORS
from mmdet.models.builder import SHARED_HEADS as MMDET_SHARED_HEADS
from mmseg.models.builder import LOSSES as MMSEG_LOSSES
MODELS = Registry('models', parent=MMCV_MODELS) MODELS = Registry('models', parent=MMCV_MODELS)
BACKBONES = MODELS
NECKS = MODELS
ROI_EXTRACTORS = MODELS
SHARED_HEADS = MODELS
HEADS = MODELS
LOSSES = MODELS
DETECTORS = MODELS
VOXEL_ENCODERS = MODELS VOXEL_ENCODERS = MODELS
MIDDLE_ENCODERS = MODELS MIDDLE_ENCODERS = MODELS
FUSION_LAYERS = MODELS FUSION_LAYERS = MODELS
SEGMENTORS = MODELS
def build_backbone(cfg): def build_backbone(cfg):
"""Build backbone.""" """Build backbone."""
return BACKBONES.build(cfg) if cfg['type'] in BACKBONES._module_dict.keys():
return BACKBONES.build(cfg)
else:
return MMDET_BACKBONES.build(cfg)
def build_neck(cfg): def build_neck(cfg):
"""Build neck.""" """Build neck."""
return NECKS.build(cfg) if cfg['type'] in NECKS._module_dict.keys():
return NECKS.build(cfg)
else:
return MMDET_NECKS.build(cfg)
def build_roi_extractor(cfg): def build_roi_extractor(cfg):
"""Build RoI feature extractor.""" """Build RoI feature extractor."""
return ROI_EXTRACTORS.build(cfg) if cfg['type'] in NECKS._module_dict.keys():
return ROI_EXTRACTORS.build(cfg)
else:
return MMDET_ROI_EXTRACTORS.build(cfg)
def build_shared_head(cfg): def build_shared_head(cfg):
"""Build shared head of detector.""" """Build shared head of detector."""
return SHARED_HEADS.build(cfg) if cfg['type'] in SHARED_HEADS._module_dict.keys():
return SHARED_HEADS.build(cfg)
else:
return MMDET_SHARED_HEADS.build(cfg)
def build_head(cfg): def build_head(cfg):
"""Build head.""" """Build head."""
return HEADS.build(cfg) if cfg['type'] in HEADS._module_dict.keys():
return HEADS.build(cfg)
else:
return MMDET_HEADS.build(cfg)
def build_loss(cfg): def build_loss(cfg):
"""Build loss function.""" """Build loss function."""
return LOSSES.build(cfg) if cfg['type'] in LOSSES._module_dict.keys():
return LOSSES.build(cfg)
elif cfg['type'] in MMDET_LOSSES._module_dict.keys():
return MMDET_LOSSES.build(cfg)
else:
return MMSEG_LOSSES.build(cfg)
def build_detector(cfg, train_cfg=None, test_cfg=None): def build_detector(cfg, train_cfg=None, test_cfg=None):
...@@ -55,8 +88,12 @@ def build_detector(cfg, train_cfg=None, test_cfg=None): ...@@ -55,8 +88,12 @@ def build_detector(cfg, train_cfg=None, test_cfg=None):
'train_cfg specified in both outer field and model field ' 'train_cfg specified in both outer field and model field '
assert cfg.get('test_cfg') is None or test_cfg is None, \ assert cfg.get('test_cfg') is None or test_cfg is None, \
'test_cfg specified in both outer field and model field ' 'test_cfg specified in both outer field and model field '
return DETECTORS.build( if cfg['type'] in DETECTORS._module_dict.keys():
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) return DETECTORS.build(
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
else:
return MMDET_DETECTORS.build(
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
def build_segmentor(cfg, train_cfg=None, test_cfg=None): def build_segmentor(cfg, train_cfg=None, test_cfg=None):
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from mmcv.cnn.bricks import ConvModule from mmcv.cnn.bricks import ConvModule
from mmdet3d.ops import DGCNNFPModule from mmdet3d.ops import DGCNNFPModule
from mmdet.models import HEADS from ..builder import HEADS
from .decode_head import Base3DDecodeHead from .decode_head import Base3DDecodeHead
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn.bricks import ConvModule from mmcv.cnn.bricks import ConvModule
from mmdet.models import HEADS from ..builder import HEADS
from .pointnet2_head import PointNet2Head from .pointnet2_head import PointNet2Head
......
...@@ -3,7 +3,7 @@ from mmcv.cnn.bricks import ConvModule ...@@ -3,7 +3,7 @@ from mmcv.cnn.bricks import ConvModule
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import PointFPModule from mmdet3d.ops import PointFPModule
from mmdet.models import HEADS from ..builder import HEADS
from .decode_head import Base3DDecodeHead from .decode_head import Base3DDecodeHead
......
...@@ -8,8 +8,7 @@ from mmdet3d.core import (PseudoSampler, box3d_multiclass_nms, limit_period, ...@@ -8,8 +8,7 @@ from mmdet3d.core import (PseudoSampler, box3d_multiclass_nms, limit_period,
xywhr2xyxyr) xywhr2xyxyr)
from mmdet.core import (build_assigner, build_bbox_coder, from mmdet.core import (build_assigner, build_bbox_coder,
build_prior_generator, build_sampler, multi_apply) build_prior_generator, build_sampler, multi_apply)
from mmdet.models import HEADS from ..builder import HEADS, build_loss
from ..builder import build_loss
from .train_mixins import AnchorTrainMixin from .train_mixins import AnchorTrainMixin
......
...@@ -7,7 +7,7 @@ from mmcv.runner import force_fp32 ...@@ -7,7 +7,7 @@ from mmcv.runner import force_fp32
from torch import nn as nn from torch import nn as nn
from mmdet.core import multi_apply from mmdet.core import multi_apply
from mmdet.models.builder import HEADS, build_loss from ..builder import HEADS, build_loss
from .base_mono3d_dense_head import BaseMono3DDenseHead from .base_mono3d_dense_head import BaseMono3DDenseHead
......
...@@ -4,7 +4,7 @@ from mmcv.cnn.bricks import build_conv_layer ...@@ -4,7 +4,7 @@ from mmcv.cnn.bricks import build_conv_layer
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from torch import nn as nn from torch import nn as nn
from mmdet.models.builder import HEADS from ..builder import HEADS
@HEADS.register_module() @HEADS.register_module()
......
...@@ -10,9 +10,9 @@ from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius, ...@@ -10,9 +10,9 @@ from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius,
xywhr2xyxyr) xywhr2xyxyr)
from mmdet3d.core.post_processing import nms_bev from mmdet3d.core.post_processing import nms_bev
from mmdet3d.models import builder from mmdet3d.models import builder
from mmdet3d.models.builder import HEADS, build_loss
from mmdet3d.models.utils import clip_sigmoid from mmdet3d.models.utils import clip_sigmoid
from mmdet.core import build_bbox_coder, multi_apply from mmdet.core import build_bbox_coder, multi_apply
from ..builder import HEADS, build_loss
@HEADS.register_module() @HEADS.register_module()
......
...@@ -11,7 +11,7 @@ from mmdet3d.core import (box3d_multiclass_nms, limit_period, points_img2cam, ...@@ -11,7 +11,7 @@ from mmdet3d.core import (box3d_multiclass_nms, limit_period, points_img2cam,
xywhr2xyxyr) xywhr2xyxyr)
from mmdet.core import multi_apply from mmdet.core import multi_apply
from mmdet.core.bbox.builder import build_bbox_coder from mmdet.core.bbox.builder import build_bbox_coder
from mmdet.models.builder import HEADS, build_loss from ..builder import HEADS, build_loss
from .anchor_free_mono3d_head import AnchorFreeMono3DHead from .anchor_free_mono3d_head import AnchorFreeMono3DHead
INF = 1e8 INF = 1e8
......
...@@ -4,7 +4,7 @@ from mmcv.runner import force_fp32 ...@@ -4,7 +4,7 @@ from mmcv.runner import force_fp32
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.core.bbox import bbox_overlaps_nearest_3d from mmdet3d.core.bbox import bbox_overlaps_nearest_3d
from mmdet.models import HEADS from ..builder import HEADS
from .anchor3d_head import Anchor3DHead from .anchor3d_head import Anchor3DHead
from .train_mixins import get_direction_target from .train_mixins import get_direction_target
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment