Commit e3b5253b authored by ZCMax's avatar ZCMax Committed by ChaimZhu
Browse files

Update all registries and fix some ut problems

parent 8dd8da12
...@@ -10,12 +10,13 @@ from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius, ...@@ -10,12 +10,13 @@ from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius,
xywhr2xyxyr) xywhr2xyxyr)
from mmdet3d.core.post_processing import nms_bev from mmdet3d.core.post_processing import nms_bev
from mmdet3d.models import builder from mmdet3d.models import builder
from mmdet3d.models.builder import build_loss
from mmdet3d.models.utils import clip_sigmoid from mmdet3d.models.utils import clip_sigmoid
from mmdet3d.registry import MODELS
from mmdet.core import build_bbox_coder, multi_apply from mmdet.core import build_bbox_coder, multi_apply
from ..builder import HEADS, build_loss
@HEADS.register_module() @MODELS.register_module()
class SeparateHead(BaseModule): class SeparateHead(BaseModule):
"""SeparateHead for CenterHead. """SeparateHead for CenterHead.
...@@ -121,7 +122,7 @@ class SeparateHead(BaseModule): ...@@ -121,7 +122,7 @@ class SeparateHead(BaseModule):
return ret_dict return ret_dict
@HEADS.register_module() @MODELS.register_module()
class DCNSeparateHead(BaseModule): class DCNSeparateHead(BaseModule):
r"""DCNSeparateHead for CenterHead. r"""DCNSeparateHead for CenterHead.
...@@ -240,7 +241,7 @@ class DCNSeparateHead(BaseModule): ...@@ -240,7 +241,7 @@ class DCNSeparateHead(BaseModule):
return ret return ret
@HEADS.register_module() @MODELS.register_module()
class CenterHead(BaseModule): class CenterHead(BaseModule):
"""CenterHead for CenterPoint. """CenterHead for CenterPoint.
......
...@@ -9,15 +9,16 @@ from torch import nn as nn ...@@ -9,15 +9,16 @@ from torch import nn as nn
from mmdet3d.core import (box3d_multiclass_nms, limit_period, points_img2cam, from mmdet3d.core import (box3d_multiclass_nms, limit_period, points_img2cam,
xywhr2xyxyr) xywhr2xyxyr)
from mmdet3d.models.builder import build_loss
from mmdet3d.registry import MODELS
from mmdet.core import multi_apply from mmdet.core import multi_apply
from mmdet.core.bbox.builder import build_bbox_coder from mmdet.core.bbox.builder import build_bbox_coder
from ..builder import HEADS, build_loss
from .anchor_free_mono3d_head import AnchorFreeMono3DHead from .anchor_free_mono3d_head import AnchorFreeMono3DHead
INF = 1e8 INF = 1e8
@HEADS.register_module() @MODELS.register_module()
class FCOSMono3DHead(AnchorFreeMono3DHead): class FCOSMono3DHead(AnchorFreeMono3DHead):
"""Anchor-free head used in FCOS3D. """Anchor-free head used in FCOS3D.
......
...@@ -4,12 +4,12 @@ from mmcv.runner import force_fp32 ...@@ -4,12 +4,12 @@ from mmcv.runner import force_fp32
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.core.bbox import bbox_overlaps_nearest_3d from mmdet3d.core.bbox import bbox_overlaps_nearest_3d
from ..builder import HEADS from mmdet3d.registry import MODELS
from .anchor3d_head import Anchor3DHead from .anchor3d_head import Anchor3DHead
from .train_mixins import get_direction_target from .train_mixins import get_direction_target
@HEADS.register_module() @MODELS.register_module()
class FreeAnchor3DHead(Anchor3DHead): class FreeAnchor3DHead(Anchor3DHead):
r"""`FreeAnchor <https://arxiv.org/abs/1909.02466>`_ head for 3D detection. r"""`FreeAnchor <https://arxiv.org/abs/1909.02466>`_ head for 3D detection.
......
...@@ -14,8 +14,9 @@ from torch import nn as nn ...@@ -14,8 +14,9 @@ from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.core.post_processing import aligned_3d_nms from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.registry import MODELS
from mmdet.core import build_bbox_coder, multi_apply from mmdet.core import build_bbox_coder, multi_apply
from ..builder import HEADS, build_loss from ..builder import build_loss
from .base_conv_bbox_head import BaseConvBboxHead from .base_conv_bbox_head import BaseConvBboxHead
EPS = 1e-6 EPS = 1e-6
...@@ -106,7 +107,7 @@ class GeneralSamplingModule(nn.Module): ...@@ -106,7 +107,7 @@ class GeneralSamplingModule(nn.Module):
return new_xyz, new_features, sample_inds return new_xyz, new_features, sample_inds
@HEADS.register_module() @MODELS.register_module()
class GroupFree3DHead(BaseModule): class GroupFree3DHead(BaseModule):
r"""Bbox head of `Group-Free 3D <https://arxiv.org/abs/2104.00678>`_. r"""Bbox head of `Group-Free 3D <https://arxiv.org/abs/2104.00678>`_.
......
...@@ -3,21 +3,22 @@ import torch ...@@ -3,21 +3,22 @@ import torch
from mmcv.cnn import xavier_init from mmcv.cnn import xavier_init
from torch import nn as nn from torch import nn as nn
from mmdet3d.core.bbox.builder import build_bbox_coder
from mmdet3d.core.utils import get_ellip_gaussian_2D from mmdet3d.core.utils import get_ellip_gaussian_2D
from mmdet3d.models.builder import build_loss
from mmdet3d.models.model_utils import EdgeFusionModule from mmdet3d.models.model_utils import EdgeFusionModule
from mmdet3d.models.utils import (filter_outside_objs, get_edge_indices, from mmdet3d.models.utils import (filter_outside_objs, get_edge_indices,
get_keypoints, handle_proj_objs) get_keypoints, handle_proj_objs)
from mmdet3d.registry import MODELS
from mmdet.core import multi_apply from mmdet.core import multi_apply
from mmdet.core.bbox.builder import build_bbox_coder
from mmdet.models.utils import gaussian_radius, gen_gaussian_target from mmdet.models.utils import gaussian_radius, gen_gaussian_target
from mmdet.models.utils.gaussian_target import (get_local_maximum, from mmdet.models.utils.gaussian_target import (get_local_maximum,
get_topk_from_heatmap, get_topk_from_heatmap,
transpose_and_gather_feat) transpose_and_gather_feat)
from ..builder import HEADS, build_loss
from .anchor_free_mono3d_head import AnchorFreeMono3DHead from .anchor_free_mono3d_head import AnchorFreeMono3DHead
@HEADS.register_module() @MODELS.register_module()
class MonoFlexHead(AnchorFreeMono3DHead): class MonoFlexHead(AnchorFreeMono3DHead):
r"""MonoFlex head used in `MonoFlex <https://arxiv.org/abs/2104.02323>`_ r"""MonoFlex head used in `MonoFlex <https://arxiv.org/abs/2104.02323>`_
......
...@@ -5,11 +5,11 @@ from mmcv.runner import force_fp32 ...@@ -5,11 +5,11 @@ from mmcv.runner import force_fp32
from mmdet3d.core import limit_period, xywhr2xyxyr from mmdet3d.core import limit_period, xywhr2xyxyr
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
from ..builder import HEADS from mmdet3d.registry import MODELS
from .anchor3d_head import Anchor3DHead from .anchor3d_head import Anchor3DHead
@HEADS.register_module() @MODELS.register_module()
class PartA2RPNHead(Anchor3DHead): class PartA2RPNHead(Anchor3DHead):
"""RPN head for PartA2. """RPN head for PartA2.
......
...@@ -8,12 +8,13 @@ from torch.nn import functional as F ...@@ -8,12 +8,13 @@ from torch.nn import functional as F
from mmdet3d.core import box3d_multiclass_nms, xywhr2xyxyr from mmdet3d.core import box3d_multiclass_nms, xywhr2xyxyr
from mmdet3d.core.bbox import points_cam2img, points_img2cam from mmdet3d.core.bbox import points_cam2img, points_img2cam
from mmdet3d.models.builder import build_loss
from mmdet3d.registry import MODELS
from mmdet.core import distance2bbox, multi_apply from mmdet.core import distance2bbox, multi_apply
from ..builder import HEADS, build_loss
from .fcos_mono3d_head import FCOSMono3DHead from .fcos_mono3d_head import FCOSMono3DHead
@HEADS.register_module() @MODELS.register_module()
class PGDHead(FCOSMono3DHead): class PGDHead(FCOSMono3DHead):
r"""Anchor-free head used in `PGD <https://arxiv.org/abs/2107.14160>`_. r"""Anchor-free head used in `PGD <https://arxiv.org/abs/2107.14160>`_.
......
...@@ -7,11 +7,12 @@ from mmdet3d.core import xywhr2xyxyr ...@@ -7,11 +7,12 @@ from mmdet3d.core import xywhr2xyxyr
from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes, from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
LiDARInstance3DBoxes) LiDARInstance3DBoxes)
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
from mmdet3d.models.builder import build_loss
from mmdet3d.registry import MODELS
from mmdet.core import build_bbox_coder, multi_apply from mmdet.core import build_bbox_coder, multi_apply
from ..builder import HEADS, build_loss
@HEADS.register_module() @MODELS.register_module()
class PointRPNHead(BaseModule): class PointRPNHead(BaseModule):
"""RPN module for PointRCNN. """RPN module for PointRCNN.
......
...@@ -8,12 +8,13 @@ from mmcv.runner import BaseModule ...@@ -8,12 +8,13 @@ from mmcv.runner import BaseModule
from torch import nn as nn from torch import nn as nn
from mmdet3d.core import box3d_multiclass_nms, limit_period, xywhr2xyxyr from mmdet3d.core import box3d_multiclass_nms, limit_period, xywhr2xyxyr
from mmdet3d.registry import MODELS
from mmdet.core import multi_apply from mmdet.core import multi_apply
from ..builder import HEADS, build_head from ..builder import build_head
from .anchor3d_head import Anchor3DHead from .anchor3d_head import Anchor3DHead
@HEADS.register_module() @MODELS.register_module()
class BaseShapeHead(BaseModule): class BaseShapeHead(BaseModule):
"""Base Shape-aware Head in Shape Signature Network. """Base Shape-aware Head in Shape Signature Network.
...@@ -164,7 +165,7 @@ class BaseShapeHead(BaseModule): ...@@ -164,7 +165,7 @@ class BaseShapeHead(BaseModule):
return ret return ret
@HEADS.register_module() @MODELS.register_module()
class ShapeAwareHead(Anchor3DHead): class ShapeAwareHead(Anchor3DHead):
"""Shape-aware grouping head for SSN. """Shape-aware grouping head for SSN.
......
...@@ -2,17 +2,17 @@ ...@@ -2,17 +2,17 @@
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.registry import MODELS
from mmdet.core import multi_apply from mmdet.core import multi_apply
from mmdet.core.bbox.builder import build_bbox_coder from mmdet.core.bbox.builder import build_bbox_coder
from mmdet.models.utils import gaussian_radius, gen_gaussian_target from mmdet.models.utils import gaussian_radius, gen_gaussian_target
from mmdet.models.utils.gaussian_target import (get_local_maximum, from mmdet.models.utils.gaussian_target import (get_local_maximum,
get_topk_from_heatmap, get_topk_from_heatmap,
transpose_and_gather_feat) transpose_and_gather_feat)
from ..builder import HEADS
from .anchor_free_mono3d_head import AnchorFreeMono3DHead from .anchor_free_mono3d_head import AnchorFreeMono3DHead
@HEADS.register_module() @MODELS.register_module()
class SMOKEMono3DHead(AnchorFreeMono3DHead): class SMOKEMono3DHead(AnchorFreeMono3DHead):
r"""Anchor-free head used in `SMOKE <https://arxiv.org/abs/2002.10111>`_ r"""Anchor-free head used in `SMOKE <https://arxiv.org/abs/2002.10111>`_
......
...@@ -7,12 +7,13 @@ from torch.nn import functional as F ...@@ -7,12 +7,13 @@ from torch.nn import functional as F
from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes, from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
LiDARInstance3DBoxes, LiDARInstance3DBoxes,
rotation_3d_in_axis) rotation_3d_in_axis)
from mmdet3d.registry import MODELS
from mmdet.core import multi_apply from mmdet.core import multi_apply
from ..builder import HEADS, build_loss from ..builder import build_loss
from .vote_head import VoteHead from .vote_head import VoteHead
@HEADS.register_module() @MODELS.register_module()
class SSD3DHead(VoteHead): class SSD3DHead(VoteHead):
r"""Bbox head of `3DSSD <https://arxiv.org/abs/2002.10187>`_. r"""Bbox head of `3DSSD <https://arxiv.org/abs/2002.10187>`_.
......
...@@ -6,15 +6,16 @@ from mmcv.runner import BaseModule, force_fp32 ...@@ -6,15 +6,16 @@ from mmcv.runner import BaseModule, force_fp32
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.core.post_processing import aligned_3d_nms from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.builder import build_loss
from mmdet3d.models.losses import chamfer_distance from mmdet3d.models.losses import chamfer_distance
from mmdet3d.models.model_utils import VoteModule from mmdet3d.models.model_utils import VoteModule
from mmdet3d.ops import build_sa_module from mmdet3d.ops import build_sa_module
from mmdet3d.registry import MODELS
from mmdet.core import build_bbox_coder, multi_apply from mmdet.core import build_bbox_coder, multi_apply
from ..builder import HEADS, build_loss
from .base_conv_bbox_head import BaseConvBboxHead from .base_conv_bbox_head import BaseConvBboxHead
@HEADS.register_module() @MODELS.register_module()
class VoteHead(BaseModule): class VoteHead(BaseModule):
r"""Bbox head of `Votenet <https://arxiv.org/abs/1904.09664>`_. r"""Bbox head of `Votenet <https://arxiv.org/abs/1904.09664>`_.
......
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
import torch import torch
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from ..builder import DETECTORS from mmdet3d.registry import MODELS
from .mvx_two_stage import MVXTwoStageDetector from .mvx_two_stage import MVXTwoStageDetector
@DETECTORS.register_module() @MODELS.register_module()
class CenterPoint(MVXTwoStageDetector): class CenterPoint(MVXTwoStageDetector):
"""Base class of Multi-modality VoxelNet.""" """Base class of Multi-modality VoxelNet."""
......
...@@ -3,11 +3,11 @@ import torch ...@@ -3,11 +3,11 @@ import torch
from mmcv.runner import force_fp32 from mmcv.runner import force_fp32
from torch.nn import functional as F from torch.nn import functional as F
from ..builder import DETECTORS from mmdet3d.registry import MODELS
from .voxelnet import VoxelNet from .voxelnet import VoxelNet
@DETECTORS.register_module() @MODELS.register_module()
class DynamicVoxelNet(VoxelNet): class DynamicVoxelNet(VoxelNet):
r"""VoxelNet using `dynamic voxelization <https://arxiv.org/abs/1910.06528>`_. r"""VoxelNet using `dynamic voxelization <https://arxiv.org/abs/1910.06528>`_.
""" """
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from ..builder import DETECTORS from mmdet3d.registry import MODELS
from .single_stage_mono3d import SingleStageMono3DDetector from .single_stage_mono3d import SingleStageMono3DDetector
@DETECTORS.register_module() @MODELS.register_module()
class FCOSMono3D(SingleStageMono3DDetector): class FCOSMono3D(SingleStageMono3DDetector):
r"""`FCOS3D <https://arxiv.org/abs/2104.10956>`_ for monocular 3D object detection. r"""`FCOS3D <https://arxiv.org/abs/2104.10956>`_ for monocular 3D object detection.
......
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
import torch import torch
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from ..builder import DETECTORS from mmdet3d.registry import MODELS
from .single_stage import SingleStage3DDetector from .single_stage import SingleStage3DDetector
@DETECTORS.register_module() @MODELS.register_module()
class GroupFree3DNet(SingleStage3DDetector): class GroupFree3DNet(SingleStage3DDetector):
"""`Group-Free 3D <https://arxiv.org/abs/2104.00678>`_.""" """`Group-Free 3D <https://arxiv.org/abs/2104.00678>`_."""
......
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
import torch import torch
from mmdet3d.core import merge_aug_bboxes_3d from mmdet3d.core import merge_aug_bboxes_3d
from ..builder import DETECTORS from mmdet3d.registry import MODELS
from .two_stage import TwoStage3DDetector from .two_stage import TwoStage3DDetector
@DETECTORS.register_module() @MODELS.register_module()
class H3DNet(TwoStage3DDetector): class H3DNet(TwoStage3DDetector):
r"""H3DNet model. r"""H3DNet model.
......
...@@ -6,8 +6,7 @@ import torch ...@@ -6,8 +6,7 @@ import torch
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet3d.models.utils import MLP from mmdet3d.models.utils import MLP
from .. import builder from mmdet3d.registry import MODELS
from ..builder import DETECTORS
from .base import Base3DDetector from .base import Base3DDetector
...@@ -53,7 +52,7 @@ def sample_valid_seeds(mask, num_sampled_seed=1024): ...@@ -53,7 +52,7 @@ def sample_valid_seeds(mask, num_sampled_seed=1024):
return sample_inds return sample_inds
@DETECTORS.register_module() @MODELS.register_module()
class ImVoteNet(Base3DDetector): class ImVoteNet(Base3DDetector):
r"""`ImVoteNet <https://arxiv.org/abs/2001.10692>`_ for 3D detection.""" r"""`ImVoteNet <https://arxiv.org/abs/2001.10692>`_ for 3D detection."""
...@@ -78,9 +77,9 @@ class ImVoteNet(Base3DDetector): ...@@ -78,9 +77,9 @@ class ImVoteNet(Base3DDetector):
# point branch # point branch
if pts_backbone is not None: if pts_backbone is not None:
self.pts_backbone = builder.build_backbone(pts_backbone) self.pts_backbone = MODELS.build(pts_backbone)
if pts_neck is not None: if pts_neck is not None:
self.pts_neck = builder.build_neck(pts_neck) self.pts_neck = MODELS.build(pts_neck)
if pts_bbox_heads is not None: if pts_bbox_heads is not None:
pts_bbox_head_common = pts_bbox_heads.common pts_bbox_head_common = pts_bbox_heads.common
pts_bbox_head_common.update( pts_bbox_head_common.update(
...@@ -93,9 +92,9 @@ class ImVoteNet(Base3DDetector): ...@@ -93,9 +92,9 @@ class ImVoteNet(Base3DDetector):
pts_bbox_head_img = pts_bbox_head_common.copy() pts_bbox_head_img = pts_bbox_head_common.copy()
pts_bbox_head_img.update(pts_bbox_heads.img) pts_bbox_head_img.update(pts_bbox_heads.img)
self.pts_bbox_head_joint = builder.build_head(pts_bbox_head_joint) self.pts_bbox_head_joint = MODELS.build(pts_bbox_head_joint)
self.pts_bbox_head_pts = builder.build_head(pts_bbox_head_pts) self.pts_bbox_head_pts = MODELS.build(pts_bbox_head_pts)
self.pts_bbox_head_img = builder.build_head(pts_bbox_head_img) self.pts_bbox_head_img = MODELS.build(pts_bbox_head_img)
self.pts_bbox_heads = [ self.pts_bbox_heads = [
self.pts_bbox_head_joint, self.pts_bbox_head_pts, self.pts_bbox_head_joint, self.pts_bbox_head_pts,
self.pts_bbox_head_img self.pts_bbox_head_img
...@@ -104,26 +103,26 @@ class ImVoteNet(Base3DDetector): ...@@ -104,26 +103,26 @@ class ImVoteNet(Base3DDetector):
# image branch # image branch
if img_backbone: if img_backbone:
self.img_backbone = builder.build_backbone(img_backbone) self.img_backbone = MODELS.build(img_backbone)
if img_neck is not None: if img_neck is not None:
self.img_neck = builder.build_neck(img_neck) self.img_neck = MODELS.build(img_neck)
if img_rpn_head is not None: if img_rpn_head is not None:
rpn_train_cfg = train_cfg.img_rpn if train_cfg \ rpn_train_cfg = train_cfg.img_rpn if train_cfg \
is not None else None is not None else None
img_rpn_head_ = img_rpn_head.copy() img_rpn_head_ = img_rpn_head.copy()
img_rpn_head_.update( img_rpn_head_.update(
train_cfg=rpn_train_cfg, test_cfg=test_cfg.img_rpn) train_cfg=rpn_train_cfg, test_cfg=test_cfg.img_rpn)
self.img_rpn_head = builder.build_head(img_rpn_head_) self.img_rpn_head = MODELS.build(img_rpn_head_)
if img_roi_head is not None: if img_roi_head is not None:
rcnn_train_cfg = train_cfg.img_rcnn if train_cfg \ rcnn_train_cfg = train_cfg.img_rcnn if train_cfg \
is not None else None is not None else None
img_roi_head.update( img_roi_head.update(
train_cfg=rcnn_train_cfg, test_cfg=test_cfg.img_rcnn) train_cfg=rcnn_train_cfg, test_cfg=test_cfg.img_rcnn)
self.img_roi_head = builder.build_head(img_roi_head) self.img_roi_head = MODELS.build(img_roi_head)
# fusion # fusion
if fusion_layer is not None: if fusion_layer is not None:
self.fusion_layer = builder.build_fusion_layer(fusion_layer) self.fusion_layer = MODELS.build(fusion_layer)
self.max_imvote_per_pixel = fusion_layer.max_imvote_per_pixel self.max_imvote_per_pixel = fusion_layer.max_imvote_per_pixel
self.freeze_img_branch = freeze_img_branch self.freeze_img_branch = freeze_img_branch
......
...@@ -3,11 +3,11 @@ import torch ...@@ -3,11 +3,11 @@ import torch
from mmdet3d.core import bbox3d2result, build_prior_generator from mmdet3d.core import bbox3d2result, build_prior_generator
from mmdet3d.models.fusion_layers.point_fusion import point_sample from mmdet3d.models.fusion_layers.point_fusion import point_sample
from mmdet3d.registry import MODELS
from mmdet.models.detectors import BaseDetector from mmdet.models.detectors import BaseDetector
from ..builder import DETECTORS, build_backbone, build_head, build_neck
@DETECTORS.register_module() @MODELS.register_module()
class ImVoxelNet(BaseDetector): class ImVoxelNet(BaseDetector):
r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_.""" r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_."""
...@@ -23,12 +23,12 @@ class ImVoxelNet(BaseDetector): ...@@ -23,12 +23,12 @@ class ImVoxelNet(BaseDetector):
pretrained=None, pretrained=None,
init_cfg=None): init_cfg=None):
super().__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
self.backbone = build_backbone(backbone) self.backbone = MODELS.build(backbone)
self.neck = build_neck(neck) self.neck = MODELS.build(neck)
self.neck_3d = build_neck(neck_3d) self.neck_3d = MODELS.build(neck_3d)
bbox_head.update(train_cfg=train_cfg) bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg) bbox_head.update(test_cfg=test_cfg)
self.bbox_head = build_head(bbox_head) self.bbox_head = MODELS.build(bbox_head)
self.n_voxels = n_voxels self.n_voxels = n_voxels
self.anchor_generator = build_prior_generator(anchor_generator) self.anchor_generator = build_prior_generator(anchor_generator)
self.train_cfg = train_cfg self.train_cfg = train_cfg
......
...@@ -3,11 +3,11 @@ import torch ...@@ -3,11 +3,11 @@ import torch
from mmcv.runner import force_fp32 from mmcv.runner import force_fp32
from torch.nn import functional as F from torch.nn import functional as F
from ..builder import DETECTORS from mmdet3d.registry import MODELS
from .mvx_two_stage import MVXTwoStageDetector from .mvx_two_stage import MVXTwoStageDetector
@DETECTORS.register_module() @MODELS.register_module()
class MVXFasterRCNN(MVXTwoStageDetector): class MVXFasterRCNN(MVXTwoStageDetector):
"""Multi-modality VoxelNet using Faster R-CNN.""" """Multi-modality VoxelNet using Faster R-CNN."""
...@@ -15,7 +15,7 @@ class MVXFasterRCNN(MVXTwoStageDetector): ...@@ -15,7 +15,7 @@ class MVXFasterRCNN(MVXTwoStageDetector):
super(MVXFasterRCNN, self).__init__(**kwargs) super(MVXFasterRCNN, self).__init__(**kwargs)
@DETECTORS.register_module() @MODELS.register_module()
class DynamicMVXFasterRCNN(MVXTwoStageDetector): class DynamicMVXFasterRCNN(MVXTwoStageDetector):
"""Multi-modality VoxelNet using Faster R-CNN and dynamic voxelization.""" """Multi-modality VoxelNet using Faster R-CNN and dynamic voxelization."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment