Commit e3b5253b authored by ZCMax's avatar ZCMax Committed by ChaimZhu
Browse files

Update all registries and fix some ut problems

parent 8dd8da12
......@@ -11,13 +11,12 @@ from torch.nn import functional as F
from mmdet3d.core import (Box3DMode, Coord3DMode, bbox3d2result,
merge_aug_bboxes_3d, show_result)
from mmdet3d.registry import MODELS
from mmdet.core import multi_apply
from .. import builder
from ..builder import DETECTORS
from .base import Base3DDetector
@DETECTORS.register_module()
@MODELS.register_module()
class MVXTwoStageDetector(Base3DDetector):
"""Base class of Multi-modality VoxelNet."""
......@@ -42,33 +41,30 @@ class MVXTwoStageDetector(Base3DDetector):
if pts_voxel_layer:
self.pts_voxel_layer = Voxelization(**pts_voxel_layer)
if pts_voxel_encoder:
self.pts_voxel_encoder = builder.build_voxel_encoder(
pts_voxel_encoder)
self.pts_voxel_encoder = MODELS.build(pts_voxel_encoder)
if pts_middle_encoder:
self.pts_middle_encoder = builder.build_middle_encoder(
pts_middle_encoder)
self.pts_middle_encoder = MODELS.build(pts_middle_encoder)
if pts_backbone:
self.pts_backbone = builder.build_backbone(pts_backbone)
self.pts_backbone = MODELS.build(pts_backbone)
if pts_fusion_layer:
self.pts_fusion_layer = builder.build_fusion_layer(
pts_fusion_layer)
self.pts_fusion_layer = MODELS.build(pts_fusion_layer)
if pts_neck is not None:
self.pts_neck = builder.build_neck(pts_neck)
self.pts_neck = MODELS.build(pts_neck)
if pts_bbox_head:
pts_train_cfg = train_cfg.pts if train_cfg else None
pts_bbox_head.update(train_cfg=pts_train_cfg)
pts_test_cfg = test_cfg.pts if test_cfg else None
pts_bbox_head.update(test_cfg=pts_test_cfg)
self.pts_bbox_head = builder.build_head(pts_bbox_head)
self.pts_bbox_head = MODELS.build(pts_bbox_head)
if img_backbone:
self.img_backbone = builder.build_backbone(img_backbone)
self.img_backbone = MODELS.build(img_backbone)
if img_neck is not None:
self.img_neck = builder.build_neck(img_neck)
self.img_neck = MODELS.build(img_neck)
if img_rpn_head is not None:
self.img_rpn_head = builder.build_head(img_rpn_head)
self.img_rpn_head = MODELS.build(img_rpn_head)
if img_roi_head is not None:
self.img_roi_head = builder.build_head(img_roi_head)
self.img_roi_head = MODELS.build(img_roi_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
......
......@@ -3,12 +3,11 @@ import torch
from mmcv.ops import Voxelization
from torch.nn import functional as F
from .. import builder
from ..builder import DETECTORS
from mmdet3d.registry import MODELS
from .two_stage import TwoStage3DDetector
@DETECTORS.register_module()
@MODELS.register_module()
class PartA2(TwoStage3DDetector):
r"""Part-A2 detector.
......@@ -37,8 +36,8 @@ class PartA2(TwoStage3DDetector):
pretrained=pretrained,
init_cfg=init_cfg)
self.voxel_layer = Voxelization(**voxel_layer)
self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
self.middle_encoder = builder.build_middle_encoder(middle_encoder)
self.voxel_encoder = MODELS.build(voxel_encoder)
self.middle_encoder = MODELS.build(middle_encoder)
def extract_feat(self, points, img_metas):
"""Extract features from points."""
......
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from ..builder import DETECTORS
from mmdet3d.registry import MODELS
from .two_stage import TwoStage3DDetector
@DETECTORS.register_module()
@MODELS.register_module()
class PointRCNN(TwoStage3DDetector):
r"""PointRCNN detector.
......
# Copyright (c) OpenMMLab. All rights reserved.
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from mmdet3d.registry import MODELS
from .base import Base3DDetector
@DETECTORS.register_module()
@MODELS.register_module()
class SingleStage3DDetector(Base3DDetector):
"""SingleStage3DDetector.
......@@ -30,12 +30,12 @@ class SingleStage3DDetector(Base3DDetector):
init_cfg=None,
pretrained=None):
super(SingleStage3DDetector, self).__init__(init_cfg)
self.backbone = build_backbone(backbone)
self.backbone = MODELS.build(backbone)
if neck is not None:
self.neck = build_neck(neck)
self.neck = MODELS.build(neck)
bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg)
self.bbox_head = build_head(bbox_head)
self.bbox_head = MODELS.build(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
......
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from os import path as osp
import mmcv
......@@ -9,11 +8,11 @@ from mmcv.parallel import DataContainer as DC
from mmdet3d.core import (CameraInstance3DBoxes, bbox3d2result,
show_multi_modality_result)
from mmdet.models.detectors import SingleStageDetector
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from mmdet3d.registry import MODELS
from mmdet.models.detectors.single_stage import SingleStageDetector
@DETECTORS.register_module()
@MODELS.register_module()
class SingleStageMono3DDetector(SingleStageDetector):
"""Base class for monocular 3D single-stage detectors.
......@@ -21,28 +20,6 @@ class SingleStageMono3DDetector(SingleStageDetector):
output features of the backbone+neck.
"""
def __init__(self,
backbone,
neck=None,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
super(SingleStageDetector, self).__init__(init_cfg)
if pretrained:
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
backbone.pretrained = pretrained
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg)
self.bbox_head = build_head(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
def extract_feats(self, imgs):
"""Directly extract features from the backbone+neck."""
assert isinstance(imgs, list)
......
# Copyright (c) OpenMMLab. All rights reserved.
from ..builder import DETECTORS
from mmdet3d.registry import MODELS
from .single_stage_mono3d import SingleStageMono3DDetector
@DETECTORS.register_module()
@MODELS.register_module()
class SMOKEMono3D(SingleStageMono3DDetector):
r"""SMOKE <https://arxiv.org/abs/2002.10111>`_ for monocular 3D object
detection.
......
# Copyright (c) OpenMMLab. All rights reserved.
from ..builder import DETECTORS
from mmdet3d.registry import MODELS
from .votenet import VoteNet
@DETECTORS.register_module()
@MODELS.register_module()
class SSD3DNet(VoteNet):
"""3DSSDNet model.
......
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from mmdet3d.registry import MODELS
from mmdet.models import TwoStageDetector
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .base import Base3DDetector
@DETECTORS.register_module()
@MODELS.register_module()
class TwoStage3DDetector(Base3DDetector, TwoStageDetector):
"""Base class of two-stage 3D detector.
......@@ -15,37 +13,5 @@ class TwoStage3DDetector(Base3DDetector, TwoStageDetector):
two-stage 3D detectors.
"""
def __init__(self,
backbone,
neck=None,
rpn_head=None,
roi_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
super(TwoStageDetector, self).__init__(init_cfg)
if pretrained:
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
backbone.pretrained = pretrained
self.backbone = build_backbone(backbone)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
if neck is not None:
self.neck = build_neck(neck)
if rpn_head is not None:
rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
rpn_head_ = rpn_head.copy()
rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
self.rpn_head = build_head(rpn_head_)
if roi_head is not None:
# update train and test cfg here for now
# TODO: refactor assigner & sampler
rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
roi_head.update(train_cfg=rcnn_train_cfg)
roi_head.update(test_cfg=test_cfg.rcnn)
roi_head.pretrained = pretrained
self.roi_head = build_head(roi_head)
def __init__(self, **kwargs):
super(TwoStage3DDetector, self).__init__(**kwargs)
......@@ -2,11 +2,11 @@
import torch
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from ..builder import DETECTORS
from mmdet3d.registry import MODELS
from .single_stage import SingleStage3DDetector
@DETECTORS.register_module()
@MODELS.register_module()
class VoteNet(SingleStage3DDetector):
r"""`VoteNet <https://arxiv.org/pdf/1904.09664.pdf>`_ for 3D detection."""
......
......@@ -5,12 +5,11 @@ from mmcv.runner import force_fp32
from torch.nn import functional as F
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from .. import builder
from ..builder import DETECTORS
from mmdet3d.registry import MODELS
from .single_stage import SingleStage3DDetector
@DETECTORS.register_module()
@MODELS.register_module()
class VoxelNet(SingleStage3DDetector):
r"""`VoxelNet <https://arxiv.org/abs/1711.06396>`_ for 3D detection."""
......@@ -34,8 +33,8 @@ class VoxelNet(SingleStage3DDetector):
init_cfg=init_cfg,
pretrained=pretrained)
self.voxel_layer = Voxelization(**voxel_layer)
self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
self.middle_encoder = builder.build_middle_encoder(middle_encoder)
self.voxel_encoder = MODELS.build(voxel_encoder)
self.middle_encoder = MODELS.build(middle_encoder)
def extract_feat(self, points, img_metas=None):
"""Extract features from points."""
......
......@@ -7,7 +7,7 @@ from torch.nn import functional as F
from mmdet3d.core.bbox.structures import (get_proj_mat_by_coord_type,
points_cam2img)
from ..builder import FUSION_LAYERS
from mmdet3d.registry import MODELS
from . import apply_3d_transformation
......@@ -91,7 +91,7 @@ def point_sample(img_meta,
return point_features.squeeze().t()
@FUSION_LAYERS.register_module()
@MODELS.register_module()
class PointFusion(BaseModule):
"""Fuse image features from multi-scale features.
......
......@@ -3,13 +3,13 @@ import torch
from torch import nn as nn
from mmdet3d.core.bbox import points_cam2img
from ..builder import FUSION_LAYERS
from mmdet3d.registry import MODELS
from . import apply_3d_transformation, bbox_2d_transform, coord_2d_transform
EPS = 1e-6
@FUSION_LAYERS.register_module()
@MODELS.register_module()
class VoteFusion(nn.Module):
"""Fuse 2d features from 3d seeds.
......
......@@ -2,9 +2,9 @@
import torch
from torch import nn as nn
from mmdet3d.registry import MODELS
from mmdet.models.losses.utils import weighted_loss
from ...core.bbox import AxisAlignedBboxOverlaps3D
from ..builder import LOSSES
@weighted_loss
......@@ -26,7 +26,7 @@ def axis_aligned_iou_loss(pred, target):
return iou_loss
@LOSSES.register_module()
@MODELS.register_module()
class AxisAlignedIoULoss(nn.Module):
"""Calculate the IoU loss (1-IoU) of axis aligned bounding boxes.
......
......@@ -3,7 +3,7 @@ import torch
from torch import nn as nn
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss
from ..builder import LOSSES
from mmdet3d.registry import MODELS
def chamfer_distance(src,
......@@ -72,7 +72,7 @@ def chamfer_distance(src,
return loss_src, loss_dst, indices1, indices2
@LOSSES.register_module()
@MODELS.register_module()
class ChamferDistance(nn.Module):
"""Calculate Chamfer Distance of two sets.
......
......@@ -3,8 +3,8 @@ import torch
from torch import nn as nn
from torch.nn import functional as F
from mmdet3d.registry import MODELS
from mmdet.models.losses.utils import weighted_loss
from ..builder import LOSSES
@weighted_loss
......@@ -57,7 +57,7 @@ def multibin_loss(pred_orientations, gt_orientations, num_dir_bins=4):
return cls_losses / num_dir_bins + reg_losses / reg_cnt
@LOSSES.register_module()
@MODELS.register_module()
class MultiBinLoss(nn.Module):
"""Multi-Bin Loss for orientation.
......
......@@ -3,8 +3,8 @@ import torch
from torch import nn as nn
from mmdet3d.ops import PAConv, PAConvCUDA
from mmdet3d.registry import MODELS
from mmdet.models.losses.utils import weight_reduce_loss
from ..builder import LOSSES
def weight_correlation(conv):
......@@ -68,7 +68,7 @@ def paconv_regularization_loss(modules, reduction):
return corr_loss
@LOSSES.register_module()
@MODELS.register_module()
class PAConvRegularizationLoss(nn.Module):
"""Calculate correlation loss of kernel weights in PAConv's weight bank.
......
......@@ -2,8 +2,8 @@
import torch
from torch import nn as nn
from mmdet3d.registry import MODELS
from mmdet.models.losses.utils import weighted_loss
from ..builder import LOSSES
@weighted_loss
......@@ -58,7 +58,7 @@ def uncertain_l1_loss(pred, target, sigma, alpha=1.0):
return loss
@LOSSES.register_module()
@MODELS.register_module()
class UncertainSmoothL1Loss(nn.Module):
r"""Smooth L1 loss with uncertainty.
......@@ -122,7 +122,7 @@ class UncertainSmoothL1Loss(nn.Module):
return loss_bbox
@LOSSES.register_module()
@MODELS.register_module()
class UncertainL1Loss(nn.Module):
"""L1 loss with uncertainty.
......
......@@ -3,10 +3,10 @@ import torch
from mmcv.runner import auto_fp16
from torch import nn
from ..builder import MIDDLE_ENCODERS
from mmdet3d.registry import MODELS
@MIDDLE_ENCODERS.register_module()
@MODELS.register_module()
class PointPillarsScatter(nn.Module):
"""Point Pillar's Scatter.
......
......@@ -6,8 +6,7 @@ from torch import nn as nn
from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.ops.spconv import IS_SPCONV2_AVAILABLE
from mmdet.models.losses import sigmoid_focal_loss, smooth_l1_loss
from ..builder import MIDDLE_ENCODERS
from mmdet3d.registry import MODELS
if IS_SPCONV2_AVAILABLE:
from spconv.pytorch import SparseConvTensor, SparseSequential
......@@ -15,7 +14,7 @@ else:
from mmcv.ops import SparseConvTensor, SparseSequential
@MIDDLE_ENCODERS.register_module()
@MODELS.register_module()
class SparseEncoder(nn.Module):
r"""Sparse encoder for SECOND and Part-A2.
......
......@@ -15,7 +15,7 @@ from mmdet3d.ops.sparse_block import replace_feature
from ..builder import MIDDLE_ENCODERS
@MIDDLE_ENCODERS.register_module()
@MODELS.register_module()
class SparseUNet(BaseModule):
r"""SparseUNet for PartA^2.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment