Commit e3b5253b authored by ZCMax's avatar ZCMax Committed by ChaimZhu
Browse files

Update all registries and fix some ut problems

parent 8dd8da12
...@@ -11,13 +11,12 @@ from torch.nn import functional as F ...@@ -11,13 +11,12 @@ from torch.nn import functional as F
from mmdet3d.core import (Box3DMode, Coord3DMode, bbox3d2result, from mmdet3d.core import (Box3DMode, Coord3DMode, bbox3d2result,
merge_aug_bboxes_3d, show_result) merge_aug_bboxes_3d, show_result)
from mmdet3d.registry import MODELS
from mmdet.core import multi_apply from mmdet.core import multi_apply
from .. import builder
from ..builder import DETECTORS
from .base import Base3DDetector from .base import Base3DDetector
@DETECTORS.register_module() @MODELS.register_module()
class MVXTwoStageDetector(Base3DDetector): class MVXTwoStageDetector(Base3DDetector):
"""Base class of Multi-modality VoxelNet.""" """Base class of Multi-modality VoxelNet."""
...@@ -42,33 +41,30 @@ class MVXTwoStageDetector(Base3DDetector): ...@@ -42,33 +41,30 @@ class MVXTwoStageDetector(Base3DDetector):
if pts_voxel_layer: if pts_voxel_layer:
self.pts_voxel_layer = Voxelization(**pts_voxel_layer) self.pts_voxel_layer = Voxelization(**pts_voxel_layer)
if pts_voxel_encoder: if pts_voxel_encoder:
self.pts_voxel_encoder = builder.build_voxel_encoder( self.pts_voxel_encoder = MODELS.build(pts_voxel_encoder)
pts_voxel_encoder)
if pts_middle_encoder: if pts_middle_encoder:
self.pts_middle_encoder = builder.build_middle_encoder( self.pts_middle_encoder = MODELS.build(pts_middle_encoder)
pts_middle_encoder)
if pts_backbone: if pts_backbone:
self.pts_backbone = builder.build_backbone(pts_backbone) self.pts_backbone = MODELS.build(pts_backbone)
if pts_fusion_layer: if pts_fusion_layer:
self.pts_fusion_layer = builder.build_fusion_layer( self.pts_fusion_layer = MODELS.build(pts_fusion_layer)
pts_fusion_layer)
if pts_neck is not None: if pts_neck is not None:
self.pts_neck = builder.build_neck(pts_neck) self.pts_neck = MODELS.build(pts_neck)
if pts_bbox_head: if pts_bbox_head:
pts_train_cfg = train_cfg.pts if train_cfg else None pts_train_cfg = train_cfg.pts if train_cfg else None
pts_bbox_head.update(train_cfg=pts_train_cfg) pts_bbox_head.update(train_cfg=pts_train_cfg)
pts_test_cfg = test_cfg.pts if test_cfg else None pts_test_cfg = test_cfg.pts if test_cfg else None
pts_bbox_head.update(test_cfg=pts_test_cfg) pts_bbox_head.update(test_cfg=pts_test_cfg)
self.pts_bbox_head = builder.build_head(pts_bbox_head) self.pts_bbox_head = MODELS.build(pts_bbox_head)
if img_backbone: if img_backbone:
self.img_backbone = builder.build_backbone(img_backbone) self.img_backbone = MODELS.build(img_backbone)
if img_neck is not None: if img_neck is not None:
self.img_neck = builder.build_neck(img_neck) self.img_neck = MODELS.build(img_neck)
if img_rpn_head is not None: if img_rpn_head is not None:
self.img_rpn_head = builder.build_head(img_rpn_head) self.img_rpn_head = MODELS.build(img_rpn_head)
if img_roi_head is not None: if img_roi_head is not None:
self.img_roi_head = builder.build_head(img_roi_head) self.img_roi_head = MODELS.build(img_roi_head)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
......
...@@ -3,12 +3,11 @@ import torch ...@@ -3,12 +3,11 @@ import torch
from mmcv.ops import Voxelization from mmcv.ops import Voxelization
from torch.nn import functional as F from torch.nn import functional as F
from .. import builder from mmdet3d.registry import MODELS
from ..builder import DETECTORS
from .two_stage import TwoStage3DDetector from .two_stage import TwoStage3DDetector
@DETECTORS.register_module() @MODELS.register_module()
class PartA2(TwoStage3DDetector): class PartA2(TwoStage3DDetector):
r"""Part-A2 detector. r"""Part-A2 detector.
...@@ -37,8 +36,8 @@ class PartA2(TwoStage3DDetector): ...@@ -37,8 +36,8 @@ class PartA2(TwoStage3DDetector):
pretrained=pretrained, pretrained=pretrained,
init_cfg=init_cfg) init_cfg=init_cfg)
self.voxel_layer = Voxelization(**voxel_layer) self.voxel_layer = Voxelization(**voxel_layer)
self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder) self.voxel_encoder = MODELS.build(voxel_encoder)
self.middle_encoder = builder.build_middle_encoder(middle_encoder) self.middle_encoder = MODELS.build(middle_encoder)
def extract_feat(self, points, img_metas): def extract_feat(self, points, img_metas):
"""Extract features from points.""" """Extract features from points."""
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch import torch
from ..builder import DETECTORS from mmdet3d.registry import MODELS
from .two_stage import TwoStage3DDetector from .two_stage import TwoStage3DDetector
@DETECTORS.register_module() @MODELS.register_module()
class PointRCNN(TwoStage3DDetector): class PointRCNN(TwoStage3DDetector):
r"""PointRCNN detector. r"""PointRCNN detector.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from ..builder import DETECTORS, build_backbone, build_head, build_neck from mmdet3d.registry import MODELS
from .base import Base3DDetector from .base import Base3DDetector
@DETECTORS.register_module() @MODELS.register_module()
class SingleStage3DDetector(Base3DDetector): class SingleStage3DDetector(Base3DDetector):
"""SingleStage3DDetector. """SingleStage3DDetector.
...@@ -30,12 +30,12 @@ class SingleStage3DDetector(Base3DDetector): ...@@ -30,12 +30,12 @@ class SingleStage3DDetector(Base3DDetector):
init_cfg=None, init_cfg=None,
pretrained=None): pretrained=None):
super(SingleStage3DDetector, self).__init__(init_cfg) super(SingleStage3DDetector, self).__init__(init_cfg)
self.backbone = build_backbone(backbone) self.backbone = MODELS.build(backbone)
if neck is not None: if neck is not None:
self.neck = build_neck(neck) self.neck = MODELS.build(neck)
bbox_head.update(train_cfg=train_cfg) bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg) bbox_head.update(test_cfg=test_cfg)
self.bbox_head = build_head(bbox_head) self.bbox_head = MODELS.build(bbox_head)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings
from os import path as osp from os import path as osp
import mmcv import mmcv
...@@ -9,11 +8,11 @@ from mmcv.parallel import DataContainer as DC ...@@ -9,11 +8,11 @@ from mmcv.parallel import DataContainer as DC
from mmdet3d.core import (CameraInstance3DBoxes, bbox3d2result, from mmdet3d.core import (CameraInstance3DBoxes, bbox3d2result,
show_multi_modality_result) show_multi_modality_result)
from mmdet.models.detectors import SingleStageDetector from mmdet3d.registry import MODELS
from ..builder import DETECTORS, build_backbone, build_head, build_neck from mmdet.models.detectors.single_stage import SingleStageDetector
@DETECTORS.register_module() @MODELS.register_module()
class SingleStageMono3DDetector(SingleStageDetector): class SingleStageMono3DDetector(SingleStageDetector):
"""Base class for monocular 3D single-stage detectors. """Base class for monocular 3D single-stage detectors.
...@@ -21,28 +20,6 @@ class SingleStageMono3DDetector(SingleStageDetector): ...@@ -21,28 +20,6 @@ class SingleStageMono3DDetector(SingleStageDetector):
output features of the backbone+neck. output features of the backbone+neck.
""" """
def __init__(self,
backbone,
neck=None,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
super(SingleStageDetector, self).__init__(init_cfg)
if pretrained:
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
backbone.pretrained = pretrained
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg)
self.bbox_head = build_head(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
def extract_feats(self, imgs): def extract_feats(self, imgs):
"""Directly extract features from the backbone+neck.""" """Directly extract features from the backbone+neck."""
assert isinstance(imgs, list) assert isinstance(imgs, list)
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from ..builder import DETECTORS from mmdet3d.registry import MODELS
from .single_stage_mono3d import SingleStageMono3DDetector from .single_stage_mono3d import SingleStageMono3DDetector
@DETECTORS.register_module() @MODELS.register_module()
class SMOKEMono3D(SingleStageMono3DDetector): class SMOKEMono3D(SingleStageMono3DDetector):
r"""SMOKE <https://arxiv.org/abs/2002.10111>`_ for monocular 3D object r"""SMOKE <https://arxiv.org/abs/2002.10111>`_ for monocular 3D object
detection. detection.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from ..builder import DETECTORS from mmdet3d.registry import MODELS
from .votenet import VoteNet from .votenet import VoteNet
@DETECTORS.register_module() @MODELS.register_module()
class SSD3DNet(VoteNet): class SSD3DNet(VoteNet):
"""3DSSDNet model. """3DSSDNet model.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings from mmdet3d.registry import MODELS
from mmdet.models import TwoStageDetector from mmdet.models import TwoStageDetector
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .base import Base3DDetector from .base import Base3DDetector
@DETECTORS.register_module() @MODELS.register_module()
class TwoStage3DDetector(Base3DDetector, TwoStageDetector): class TwoStage3DDetector(Base3DDetector, TwoStageDetector):
"""Base class of two-stage 3D detector. """Base class of two-stage 3D detector.
...@@ -15,37 +13,5 @@ class TwoStage3DDetector(Base3DDetector, TwoStageDetector): ...@@ -15,37 +13,5 @@ class TwoStage3DDetector(Base3DDetector, TwoStageDetector):
two-stage 3D detectors. two-stage 3D detectors.
""" """
def __init__(self, def __init__(self, **kwargs):
backbone, super(TwoStage3DDetector, self).__init__(**kwargs)
neck=None,
rpn_head=None,
roi_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
super(TwoStageDetector, self).__init__(init_cfg)
if pretrained:
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
backbone.pretrained = pretrained
self.backbone = build_backbone(backbone)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
if neck is not None:
self.neck = build_neck(neck)
if rpn_head is not None:
rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
rpn_head_ = rpn_head.copy()
rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
self.rpn_head = build_head(rpn_head_)
if roi_head is not None:
# update train and test cfg here for now
# TODO: refactor assigner & sampler
rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
roi_head.update(train_cfg=rcnn_train_cfg)
roi_head.update(test_cfg=test_cfg.rcnn)
roi_head.pretrained = pretrained
self.roi_head = build_head(roi_head)
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
import torch import torch
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from ..builder import DETECTORS from mmdet3d.registry import MODELS
from .single_stage import SingleStage3DDetector from .single_stage import SingleStage3DDetector
@DETECTORS.register_module() @MODELS.register_module()
class VoteNet(SingleStage3DDetector): class VoteNet(SingleStage3DDetector):
r"""`VoteNet <https://arxiv.org/pdf/1904.09664.pdf>`_ for 3D detection.""" r"""`VoteNet <https://arxiv.org/pdf/1904.09664.pdf>`_ for 3D detection."""
......
...@@ -5,12 +5,11 @@ from mmcv.runner import force_fp32 ...@@ -5,12 +5,11 @@ from mmcv.runner import force_fp32
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from .. import builder from mmdet3d.registry import MODELS
from ..builder import DETECTORS
from .single_stage import SingleStage3DDetector from .single_stage import SingleStage3DDetector
@DETECTORS.register_module() @MODELS.register_module()
class VoxelNet(SingleStage3DDetector): class VoxelNet(SingleStage3DDetector):
r"""`VoxelNet <https://arxiv.org/abs/1711.06396>`_ for 3D detection.""" r"""`VoxelNet <https://arxiv.org/abs/1711.06396>`_ for 3D detection."""
...@@ -34,8 +33,8 @@ class VoxelNet(SingleStage3DDetector): ...@@ -34,8 +33,8 @@ class VoxelNet(SingleStage3DDetector):
init_cfg=init_cfg, init_cfg=init_cfg,
pretrained=pretrained) pretrained=pretrained)
self.voxel_layer = Voxelization(**voxel_layer) self.voxel_layer = Voxelization(**voxel_layer)
self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder) self.voxel_encoder = MODELS.build(voxel_encoder)
self.middle_encoder = builder.build_middle_encoder(middle_encoder) self.middle_encoder = MODELS.build(middle_encoder)
def extract_feat(self, points, img_metas=None): def extract_feat(self, points, img_metas=None):
"""Extract features from points.""" """Extract features from points."""
......
...@@ -7,7 +7,7 @@ from torch.nn import functional as F ...@@ -7,7 +7,7 @@ from torch.nn import functional as F
from mmdet3d.core.bbox.structures import (get_proj_mat_by_coord_type, from mmdet3d.core.bbox.structures import (get_proj_mat_by_coord_type,
points_cam2img) points_cam2img)
from ..builder import FUSION_LAYERS from mmdet3d.registry import MODELS
from . import apply_3d_transformation from . import apply_3d_transformation
...@@ -91,7 +91,7 @@ def point_sample(img_meta, ...@@ -91,7 +91,7 @@ def point_sample(img_meta,
return point_features.squeeze().t() return point_features.squeeze().t()
@FUSION_LAYERS.register_module() @MODELS.register_module()
class PointFusion(BaseModule): class PointFusion(BaseModule):
"""Fuse image features from multi-scale features. """Fuse image features from multi-scale features.
......
...@@ -3,13 +3,13 @@ import torch ...@@ -3,13 +3,13 @@ import torch
from torch import nn as nn from torch import nn as nn
from mmdet3d.core.bbox import points_cam2img from mmdet3d.core.bbox import points_cam2img
from ..builder import FUSION_LAYERS from mmdet3d.registry import MODELS
from . import apply_3d_transformation, bbox_2d_transform, coord_2d_transform from . import apply_3d_transformation, bbox_2d_transform, coord_2d_transform
EPS = 1e-6 EPS = 1e-6
@FUSION_LAYERS.register_module() @MODELS.register_module()
class VoteFusion(nn.Module): class VoteFusion(nn.Module):
"""Fuse 2d features from 3d seeds. """Fuse 2d features from 3d seeds.
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
import torch import torch
from torch import nn as nn from torch import nn as nn
from mmdet3d.registry import MODELS
from mmdet.models.losses.utils import weighted_loss from mmdet.models.losses.utils import weighted_loss
from ...core.bbox import AxisAlignedBboxOverlaps3D from ...core.bbox import AxisAlignedBboxOverlaps3D
from ..builder import LOSSES
@weighted_loss @weighted_loss
...@@ -26,7 +26,7 @@ def axis_aligned_iou_loss(pred, target): ...@@ -26,7 +26,7 @@ def axis_aligned_iou_loss(pred, target):
return iou_loss return iou_loss
@LOSSES.register_module() @MODELS.register_module()
class AxisAlignedIoULoss(nn.Module): class AxisAlignedIoULoss(nn.Module):
"""Calculate the IoU loss (1-IoU) of axis aligned bounding boxes. """Calculate the IoU loss (1-IoU) of axis aligned bounding boxes.
......
...@@ -3,7 +3,7 @@ import torch ...@@ -3,7 +3,7 @@ import torch
from torch import nn as nn from torch import nn as nn
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss
from ..builder import LOSSES from mmdet3d.registry import MODELS
def chamfer_distance(src, def chamfer_distance(src,
...@@ -72,7 +72,7 @@ def chamfer_distance(src, ...@@ -72,7 +72,7 @@ def chamfer_distance(src,
return loss_src, loss_dst, indices1, indices2 return loss_src, loss_dst, indices1, indices2
@LOSSES.register_module() @MODELS.register_module()
class ChamferDistance(nn.Module): class ChamferDistance(nn.Module):
"""Calculate Chamfer Distance of two sets. """Calculate Chamfer Distance of two sets.
......
...@@ -3,8 +3,8 @@ import torch ...@@ -3,8 +3,8 @@ import torch
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.registry import MODELS
from mmdet.models.losses.utils import weighted_loss from mmdet.models.losses.utils import weighted_loss
from ..builder import LOSSES
@weighted_loss @weighted_loss
...@@ -57,7 +57,7 @@ def multibin_loss(pred_orientations, gt_orientations, num_dir_bins=4): ...@@ -57,7 +57,7 @@ def multibin_loss(pred_orientations, gt_orientations, num_dir_bins=4):
return cls_losses / num_dir_bins + reg_losses / reg_cnt return cls_losses / num_dir_bins + reg_losses / reg_cnt
@LOSSES.register_module() @MODELS.register_module()
class MultiBinLoss(nn.Module): class MultiBinLoss(nn.Module):
"""Multi-Bin Loss for orientation. """Multi-Bin Loss for orientation.
......
...@@ -3,8 +3,8 @@ import torch ...@@ -3,8 +3,8 @@ import torch
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import PAConv, PAConvCUDA from mmdet3d.ops import PAConv, PAConvCUDA
from mmdet3d.registry import MODELS
from mmdet.models.losses.utils import weight_reduce_loss from mmdet.models.losses.utils import weight_reduce_loss
from ..builder import LOSSES
def weight_correlation(conv): def weight_correlation(conv):
...@@ -68,7 +68,7 @@ def paconv_regularization_loss(modules, reduction): ...@@ -68,7 +68,7 @@ def paconv_regularization_loss(modules, reduction):
return corr_loss return corr_loss
@LOSSES.register_module() @MODELS.register_module()
class PAConvRegularizationLoss(nn.Module): class PAConvRegularizationLoss(nn.Module):
"""Calculate correlation loss of kernel weights in PAConv's weight bank. """Calculate correlation loss of kernel weights in PAConv's weight bank.
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
import torch import torch
from torch import nn as nn from torch import nn as nn
from mmdet3d.registry import MODELS
from mmdet.models.losses.utils import weighted_loss from mmdet.models.losses.utils import weighted_loss
from ..builder import LOSSES
@weighted_loss @weighted_loss
...@@ -58,7 +58,7 @@ def uncertain_l1_loss(pred, target, sigma, alpha=1.0): ...@@ -58,7 +58,7 @@ def uncertain_l1_loss(pred, target, sigma, alpha=1.0):
return loss return loss
@LOSSES.register_module() @MODELS.register_module()
class UncertainSmoothL1Loss(nn.Module): class UncertainSmoothL1Loss(nn.Module):
r"""Smooth L1 loss with uncertainty. r"""Smooth L1 loss with uncertainty.
...@@ -122,7 +122,7 @@ class UncertainSmoothL1Loss(nn.Module): ...@@ -122,7 +122,7 @@ class UncertainSmoothL1Loss(nn.Module):
return loss_bbox return loss_bbox
@LOSSES.register_module() @MODELS.register_module()
class UncertainL1Loss(nn.Module): class UncertainL1Loss(nn.Module):
"""L1 loss with uncertainty. """L1 loss with uncertainty.
......
...@@ -3,10 +3,10 @@ import torch ...@@ -3,10 +3,10 @@ import torch
from mmcv.runner import auto_fp16 from mmcv.runner import auto_fp16
from torch import nn from torch import nn
from ..builder import MIDDLE_ENCODERS from mmdet3d.registry import MODELS
@MIDDLE_ENCODERS.register_module() @MODELS.register_module()
class PointPillarsScatter(nn.Module): class PointPillarsScatter(nn.Module):
"""Point Pillar's Scatter. """Point Pillar's Scatter.
......
...@@ -6,8 +6,7 @@ from torch import nn as nn ...@@ -6,8 +6,7 @@ from torch import nn as nn
from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.ops.spconv import IS_SPCONV2_AVAILABLE from mmdet3d.ops.spconv import IS_SPCONV2_AVAILABLE
from mmdet.models.losses import sigmoid_focal_loss, smooth_l1_loss from mmdet3d.registry import MODELS
from ..builder import MIDDLE_ENCODERS
if IS_SPCONV2_AVAILABLE: if IS_SPCONV2_AVAILABLE:
from spconv.pytorch import SparseConvTensor, SparseSequential from spconv.pytorch import SparseConvTensor, SparseSequential
...@@ -15,7 +14,7 @@ else: ...@@ -15,7 +14,7 @@ else:
from mmcv.ops import SparseConvTensor, SparseSequential from mmcv.ops import SparseConvTensor, SparseSequential
@MIDDLE_ENCODERS.register_module() @MODELS.register_module()
class SparseEncoder(nn.Module): class SparseEncoder(nn.Module):
r"""Sparse encoder for SECOND and Part-A2. r"""Sparse encoder for SECOND and Part-A2.
......
...@@ -15,7 +15,7 @@ from mmdet3d.ops.sparse_block import replace_feature ...@@ -15,7 +15,7 @@ from mmdet3d.ops.sparse_block import replace_feature
from ..builder import MIDDLE_ENCODERS from ..builder import MIDDLE_ENCODERS
@MIDDLE_ENCODERS.register_module() @MODELS.register_module()
class SparseUNet(BaseModule): class SparseUNet(BaseModule):
r"""SparseUNet for PartA^2. r"""SparseUNet for PartA^2.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment