Commit e3b5253b authored by ZCMax, committed by ChaimZhu

Update all registries and fix some ut problems

parent 8dd8da12
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdet3d.ops.spconv import IS_SPCONV2_AVAILABLE

if IS_SPCONV2_AVAILABLE:
    from spconv.pytorch import SparseConvTensor, SparseSequential
else:
    from mmcv.ops import SparseConvTensor, SparseSequential

from mmcv.runner import BaseModule, auto_fp16

from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
from ..builder import MIDDLE_ENCODERS


@MIDDLE_ENCODERS.register_module()
class SparseUNet(BaseModule):
    r"""SparseUNet for PartA^2.

    See the `paper <https://arxiv.org/abs/1907.03670>`_ for more details.

    Args:
        in_channels (int): The number of input channels.
        sparse_shape (list[int]): The sparse shape of input tensor.
        order (tuple[str]): Order of conv/norm/activation layers. It is a
            sequence of 'conv', 'norm' and 'act'. Defaults to
            ('conv', 'norm', 'act').
        norm_cfg (dict): Config of normalization layer.
        base_channels (int): Out channels for conv_input layer.
        output_channels (int): Out channels for conv_out layer.
        encoder_channels (tuple[tuple[int]]):
            Convolutional channels of each encode block.
        encoder_paddings (tuple[tuple[int]]): Paddings of each encode block.
        decoder_channels (tuple[tuple[int]]):
            Convolutional channels of each decode block.
        decoder_paddings (tuple[tuple[int]]): Paddings of each decode block.
    """

    def __init__(self,
                 in_channels,
                 sparse_shape,
                 order=('conv', 'norm', 'act'),
                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
                 base_channels=16,
                 output_channels=128,
                 encoder_channels=((16, ), (32, 32, 32), (64, 64, 64), (64, 64,
                                                                        64)),
                 encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
                                                                 1)),
                 decoder_channels=((64, 64, 64), (64, 64, 32), (32, 32, 16),
                                   (16, 16, 16)),
                 decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1)),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.sparse_shape = sparse_shape
        self.in_channels = in_channels
        self.order = order
        self.base_channels = base_channels
        self.output_channels = output_channels
        self.encoder_channels = encoder_channels
        self.encoder_paddings = encoder_paddings
        self.decoder_channels = decoder_channels
        self.decoder_paddings = decoder_paddings
        self.stage_num = len(self.encoder_channels)
        self.fp16_enabled = False
        # Spconv init all weight on its own

        assert isinstance(order, tuple) and len(order) == 3
        assert set(order) == {'conv', 'norm', 'act'}

        if self.order[0] != 'conv':  # pre activate
            self.conv_input = make_sparse_convmodule(
                in_channels,
                self.base_channels,
                3,
                norm_cfg=norm_cfg,
                padding=1,
                indice_key='subm1',
                conv_type='SubMConv3d',
                order=('conv', ))
        else:  # post activate
            self.conv_input = make_sparse_convmodule(
                in_channels,
                self.base_channels,
                3,
                norm_cfg=norm_cfg,
                padding=1,
                indice_key='subm1',
                conv_type='SubMConv3d')

        encoder_out_channels = self.make_encoder_layers(
            make_sparse_convmodule, norm_cfg, self.base_channels)
        self.make_decoder_layers(make_sparse_convmodule, norm_cfg,
                                 encoder_out_channels)

        self.conv_out = make_sparse_convmodule(
            encoder_out_channels,
            self.output_channels,
            kernel_size=(3, 1, 1),
            stride=(2, 1, 1),
            norm_cfg=norm_cfg,
            padding=0,
            indice_key='spconv_down2',
            conv_type='SparseConv3d')

    @auto_fp16(apply_to=('voxel_features', ))
    def forward(self, voxel_features, coors, batch_size):
        """Forward of SparseUNet.

        Args:
            voxel_features (torch.float32): Voxel features in shape [N, C].
            coors (torch.int32): Coordinates in shape [N, 4],
                the columns in the order of (batch_idx, z_idx, y_idx, x_idx).
            batch_size (int): Batch size.

        Returns:
            dict[str, torch.Tensor]: Backbone features.
        """
        coors = coors.int()
        input_sp_tensor = SparseConvTensor(voxel_features, coors,
                                           self.sparse_shape, batch_size)
        x = self.conv_input(input_sp_tensor)

        encode_features = []
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x)
            encode_features.append(x)

        # for detection head
        # [200, 176, 5] -> [200, 176, 2]
        out = self.conv_out(encode_features[-1])
        spatial_features = out.dense()

        N, C, D, H, W = spatial_features.shape
        spatial_features = spatial_features.view(N, C * D, H, W)

        # for segmentation head, with output shape:
        # [400, 352, 11] <- [200, 176, 5]
        # [800, 704, 21] <- [400, 352, 11]
        # [1600, 1408, 41] <- [800, 704, 21]
        # [1600, 1408, 41] <- [1600, 1408, 41]
        decode_features = []
        x = encode_features[-1]
        for i in range(self.stage_num, 0, -1):
            x = self.decoder_layer_forward(encode_features[i - 1], x,
                                           getattr(self, f'lateral_layer{i}'),
                                           getattr(self, f'merge_layer{i}'),
                                           getattr(self, f'upsample_layer{i}'))
            decode_features.append(x)

        seg_features = decode_features[-1].features

        ret = dict(
            spatial_features=spatial_features, seg_features=seg_features)

        return ret

    def decoder_layer_forward(self, x_lateral, x_bottom, lateral_layer,
                              merge_layer, upsample_layer):
        """Forward of upsample and residual block.

        Args:
            x_lateral (:obj:`SparseConvTensor`): Lateral tensor.
            x_bottom (:obj:`SparseConvTensor`): Feature from bottom layer.
            lateral_layer (SparseBasicBlock): Convolution for lateral tensor.
            merge_layer (SparseSequential): Convolution for merging features.
            upsample_layer (SparseSequential): Convolution for upsampling.

        Returns:
            :obj:`SparseConvTensor`: Upsampled feature.
        """
        x = lateral_layer(x_lateral)
        if IS_SPCONV2_AVAILABLE:
            # TODO: try to clean this since it is for the compatibility
            # between spconv 1.x & 2
            x = x.replace_feature(
                torch.cat((x_bottom.features, x.features), dim=1))
        else:
            x.features = torch.cat((x_bottom.features, x.features), dim=1)
        x_merge = merge_layer(x)
        x = self.reduce_channel(x, x_merge.features.shape[1])
        if IS_SPCONV2_AVAILABLE:
            # TODO: try to clean this since it is for the compatibility
            # between spconv 1.x & 2
            x = x.replace_feature(x_merge.features + x.features)
        else:
            x.features = x_merge.features + x.features
        x = upsample_layer(x)
        return x

    @staticmethod
    def reduce_channel(x, out_channels):
        """Reduce channel for element-wise addition.

        Args:
            x (:obj:`SparseConvTensor`): Sparse tensor, ``x.features``
                are in shape (N, C1).
            out_channels (int): The number of channel after reduction.

        Returns:
            :obj:`SparseConvTensor`: Channel reduced feature.
        """
        features = x.features
        n, in_channels = features.shape
        assert (in_channels % out_channels
                == 0) and (in_channels >= out_channels)

        x.features = features.view(n, out_channels, -1).sum(dim=2)
        return x

    def make_encoder_layers(self, make_block, norm_cfg, in_channels):
        """Make encoder layers using sparse convs.

        Args:
            make_block (method): A bound function to build blocks.
            norm_cfg (dict[str]): Config of normalization layer.
            in_channels (int): The number of encoder input channels.

        Returns:
            int: The number of encoder output channels.
        """
        self.encoder_layers = SparseSequential()

        for i, blocks in enumerate(self.encoder_channels):
            blocks_list = []
            for j, out_channels in enumerate(tuple(blocks)):
                padding = tuple(self.encoder_paddings[i])[j]
                # each stage started with a spconv layer
                # except the first stage
                if i != 0 and j == 0:
                    blocks_list.append(
                        make_block(
                            in_channels,
                            out_channels,
                            3,
                            norm_cfg=norm_cfg,
                            stride=2,
                            padding=padding,
                            indice_key=f'spconv{i + 1}',
                            conv_type='SparseConv3d'))
                else:
                    blocks_list.append(
                        make_block(
                            in_channels,
                            out_channels,
                            3,
                            norm_cfg=norm_cfg,
                            padding=padding,
                            indice_key=f'subm{i + 1}',
                            conv_type='SubMConv3d'))
                in_channels = out_channels
            stage_name = f'encoder_layer{i + 1}'
            stage_layers = SparseSequential(*blocks_list)
            self.encoder_layers.add_module(stage_name, stage_layers)
        return out_channels

    def make_decoder_layers(self, make_block, norm_cfg, in_channels):
        """Make decoder layers using sparse convs.

        Args:
            make_block (method): A bound function to build blocks.
            norm_cfg (dict[str]): Config of normalization layer.
            in_channels (int): The number of decoder input channels.

        Returns:
            int: The number of decoder output channels.
        """
        block_num = len(self.decoder_channels)
        for i, block_channels in enumerate(self.decoder_channels):
            paddings = self.decoder_paddings[i]
            setattr(
                self, f'lateral_layer{block_num - i}',
                SparseBasicBlock(
                    in_channels,
                    block_channels[0],
                    conv_cfg=dict(
                        type='SubMConv3d', indice_key=f'subm{block_num - i}'),
                    norm_cfg=norm_cfg))
            setattr(
                self, f'merge_layer{block_num - i}',
                make_block(
                    in_channels * 2,
                    block_channels[1],
                    3,
                    norm_cfg=norm_cfg,
                    padding=paddings[0],
                    indice_key=f'subm{block_num - i}',
                    conv_type='SubMConv3d'))
            if block_num - i != 1:
                setattr(
                    self, f'upsample_layer{block_num - i}',
                    make_block(
                        in_channels,
                        block_channels[2],
                        3,
                        norm_cfg=norm_cfg,
                        indice_key=f'spconv{block_num - i}',
                        conv_type='SparseInverseConv3d'))
            else:
                # use submanifold conv instead of inverse conv
                # in the last block
                setattr(
                    self, f'upsample_layer{block_num - i}',
                    make_block(
                        in_channels,
                        block_channels[2],
                        3,
                        norm_cfg=norm_cfg,
                        padding=paddings[1],
                        indice_key='subm1',
                        conv_type='SubMConv3d'))
            in_channels = block_channels[2]
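
For orientation, a minimal usage sketch of the module above (an illustration added for this write-up, not part of the commit): it assumes an mmdet3d installation with a CUDA-enabled sparse-conv backend (spconv 2.x or mmcv.ops) and borrows the sparse_shape used by the PartA^2 KITTI configs; all shapes and values are toy examples.

import torch

from mmdet3d.models.middle_encoders import SparseUNet

# sparse_shape is [z, y, x]; [41, 1600, 1408] matches the PartA^2 KITTI setup.
unet = SparseUNet(in_channels=4, sparse_shape=[41, 1600, 1408]).cuda()

num_voxels = 100
voxel_features = torch.rand(num_voxels, 4).cuda()           # [N, C]
coors = torch.randint(0, 41, (num_voxels, 4)).int().cuda()  # [N, (batch, z, y, x)]
coors[:, 0] = 0                                             # one sample in the batch

out = unet(voxel_features, coors, batch_size=1)
# out['spatial_features']: dense BEV map, expected [1, 256, 200, 176] here
# out['seg_features']: per-voxel features, [N, 16] with the default channels
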
@@ -6,7 +6,7 @@ from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import BaseModule
from torch import nn as nn
-from ..builder import NECKS
+from mmdet3d.registry import MODELS
def fill_up_weights(up):
@@ -167,7 +167,7 @@ class DLAUpsample(BaseModule):
        return outs
-@NECKS.register_module()
+@MODELS.register_module()
class DLANeck(BaseModule):
    """DLA Neck.
...
@@ -2,10 +2,10 @@
from mmcv.cnn import ConvModule
from torch import nn
-from ..builder import NECKS
+from mmdet3d.registry import MODELS
-@NECKS.register_module()
+@MODELS.register_module()
class OutdoorImVoxelNeck(nn.Module):
    """Neck for ImVoxelNet outdoor scenario.
...
@@ -3,10 +3,10 @@ from mmcv.runner import BaseModule
from torch import nn as nn
from mmdet3d.ops import PointFPModule
-from ..builder import NECKS
+from mmdet3d.registry import MODELS
-@NECKS.register_module()
+@MODELS.register_module()
class PointNetFPNeck(BaseModule):
    r"""PointNet FP Module used in PointRCNN.
...
@@ -5,10 +5,10 @@ from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer
from mmcv.runner import BaseModule, auto_fp16
from torch import nn as nn
-from ..builder import NECKS
+from mmdet3d.registry import MODELS
-@NECKS.register_module()
+@MODELS.register_module()
class SECONDFPN(BaseModule):
    """FPN used in SECOND/PointPillars/PartA2/MVXNet.
...
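
All four neck diffs above follow the same migration pattern: the per-type NECKS registry from the models builder is replaced by the unified MODELS registry exposed in mmdet3d.registry. A minimal sketch of what registration looks like after this change (my addition; CustomNeck is a made-up name used only for illustration):

from torch import nn

from mmdet3d.registry import MODELS


@MODELS.register_module()
class CustomNeck(nn.Module):
    """Hypothetical neck; only the registry import and decorator changed."""

    def forward(self, x):
        return x
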
@@ -5,15 +5,17 @@ from mmcv.runner import BaseModule
from torch import nn as nn
from torch.nn import functional as F
+from mmdet3d.core import build_bbox_coder
from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.core.post_processing import aligned_3d_nms
-from mmdet3d.models.builder import HEADS, build_loss
+from mmdet3d.models.builder import build_loss
from mmdet3d.models.losses import chamfer_distance
from mmdet3d.ops import build_sa_module
-from mmdet.core import build_bbox_coder, multi_apply
+from mmdet3d.registry import MODELS
+from mmdet.core import multi_apply
-@HEADS.register_module()
+@MODELS.register_module()
class H3DBboxHead(BaseModule):
    r"""Bbox head of `H3DNet <https://arxiv.org/abs/2006.05682>`_.
...
@@ -14,15 +14,17 @@ else:
from mmcv.runner import BaseModule
from torch import nn as nn
+from mmdet3d.core import build_bbox_coder
from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes,
                                           rotation_3d_in_axis, xywhr2xyxyr)
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
-from mmdet3d.models.builder import HEADS, build_loss
+from mmdet3d.models.builder import build_loss
from mmdet3d.ops import make_sparse_convmodule
-from mmdet.core import build_bbox_coder, multi_apply
+from mmdet3d.registry import MODELS
+from mmdet.core import multi_apply
-@HEADS.register_module()
+@MODELS.register_module()
class PartA2BboxHead(BaseModule):
    """PartA2 RoI head.
...
@@ -6,15 +6,17 @@ from mmcv.cnn.bricks import build_conv_layer
from mmcv.runner import BaseModule
from torch import nn as nn
+from mmdet3d.core import build_bbox_coder
from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes,
                                           rotation_3d_in_axis, xywhr2xyxyr)
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
-from mmdet3d.models.builder import HEADS, build_loss
+from mmdet3d.models.builder import build_loss
from mmdet3d.ops import build_sa_module
-from mmdet.core import build_bbox_coder, multi_apply
+from mmdet3d.registry import MODELS
+from mmdet.core import multi_apply
-@HEADS.register_module()
+@MODELS.register_module()
class PointRCNNBboxHead(BaseModule):
    """PointRCNN RoI Bbox head.
...
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet3d.core.bbox import bbox3d2result
-from ..builder import HEADS, build_head
+from mmdet3d.registry import MODELS
from .base_3droi_head import Base3DRoIHead
-@HEADS.register_module()
+@MODELS.register_module()
class H3DRoIHead(Base3DRoIHead):
    """H3D roi head for H3DNet.
@@ -30,9 +30,9 @@ class H3DRoIHead(Base3DRoIHead):
            init_cfg=init_cfg)
        # Primitive module
        assert len(primitive_list) == 3
-        self.primitive_z = build_head(primitive_list[0])
-        self.primitive_xy = build_head(primitive_list[1])
-        self.primitive_line = build_head(primitive_list[2])
+        self.primitive_z = MODELS.build(primitive_list[0])
+        self.primitive_xy = MODELS.build(primitive_list[1])
+        self.primitive_line = MODELS.build(primitive_list[2])
    def init_mask_head(self):
        """Initialize mask head, skip since ``H3DROIHead`` does not have
@@ -43,7 +43,7 @@ class H3DRoIHead(Base3DRoIHead):
        """Initialize box head."""
        bbox_head['train_cfg'] = self.train_cfg
        bbox_head['test_cfg'] = self.test_cfg
-        self.bbox_head = build_head(bbox_head)
+        self.bbox_head = MODELS.build(bbox_head)
    def init_assigner_sampler(self):
        """Initialize assigner and sampler."""
...
@@ -5,11 +5,12 @@ from torch import nn as nn
from torch.nn import functional as F
from mmdet3d.core.bbox.structures import rotation_3d_in_axis
-from mmdet3d.models.builder import HEADS, build_loss
+from mmdet3d.models.builder import build_loss
+from mmdet3d.registry import MODELS
from mmdet.core import multi_apply
-@HEADS.register_module()
+@MODELS.register_module()
class PointwiseSemanticHead(BaseModule):
    """Semantic segmentation head for point-wise segmentation.
...
@@ -6,13 +6,14 @@ from mmcv.runner import BaseModule
from torch import nn as nn
from torch.nn import functional as F
-from mmdet3d.models.builder import HEADS, build_loss
+from mmdet3d.models.builder import build_loss
from mmdet3d.models.model_utils import VoteModule
from mmdet3d.ops import build_sa_module
+from mmdet3d.registry import MODELS
from mmdet.core import multi_apply
-@HEADS.register_module()
+@MODELS.register_module()
class PrimitiveHead(BaseModule):
    r"""Primitive head of `H3DNet <https://arxiv.org/abs/2006.05682>`_.
...
@@ -3,14 +3,13 @@ import warnings
from torch.nn import functional as F
-from mmdet3d.core import AssignResult
+from mmdet3d.core import AssignResult, build_assigner, build_sampler
from mmdet3d.core.bbox import bbox3d2result, bbox3d2roi
-from mmdet.core import build_assigner, build_sampler
-from ..builder import HEADS, build_head, build_roi_extractor
+from mmdet3d.registry import MODELS
from .base_3droi_head import Base3DRoIHead
-@HEADS.register_module()
+@MODELS.register_module()
class PartAggregationROIHead(Base3DRoIHead):
    """Part aggregation roi head for PartA2.
@@ -41,12 +40,12 @@ class PartAggregationROIHead(Base3DRoIHead):
            init_cfg=init_cfg)
        self.num_classes = num_classes
        assert semantic_head is not None
-        self.semantic_head = build_head(semantic_head)
+        self.semantic_head = MODELS.build(semantic_head)
        if seg_roi_extractor is not None:
-            self.seg_roi_extractor = build_roi_extractor(seg_roi_extractor)
+            self.seg_roi_extractor = MODELS.build(seg_roi_extractor)
        if part_roi_extractor is not None:
-            self.part_roi_extractor = build_roi_extractor(part_roi_extractor)
+            self.part_roi_extractor = MODELS.build(part_roi_extractor)
        self.init_assigner_sampler()
@@ -64,7 +63,7 @@ class PartAggregationROIHead(Base3DRoIHead):
    def init_bbox_head(self, bbox_head):
        """Initialize box head."""
-        self.bbox_head = build_head(bbox_head)
+        self.bbox_head = MODELS.build(bbox_head)
    def init_assigner_sampler(self):
        """Initialize assigner and sampler."""
...
@@ -2,14 +2,13 @@
import torch
from torch.nn import functional as F
-from mmdet3d.core import AssignResult
+from mmdet3d.core import AssignResult, build_assigner, build_sampler
from mmdet3d.core.bbox import bbox3d2result, bbox3d2roi
-from mmdet.core import build_assigner, build_sampler
-from ..builder import HEADS, build_head, build_roi_extractor
+from mmdet3d.registry import MODELS
from .base_3droi_head import Base3DRoIHead
-@HEADS.register_module()
+@MODELS.register_module()
class PointRCNNRoIHead(Base3DRoIHead):
    """RoI head for PointRCNN.
@@ -40,7 +39,7 @@ class PointRCNNRoIHead(Base3DRoIHead):
        self.depth_normalizer = depth_normalizer
        if point_roi_extractor is not None:
-            self.point_roi_extractor = build_roi_extractor(point_roi_extractor)
+            self.point_roi_extractor = MODELS.build(point_roi_extractor)
        self.init_assigner_sampler()
@@ -50,7 +49,7 @@
        Args:
            bbox_head (dict): Config dict of RoI Head.
        """
-        self.bbox_head = build_head(bbox_head)
+        self.bbox_head = MODELS.build(bbox_head)
    def init_mask_head(self):
        """Initialize maek head."""
...
@@ -3,10 +3,10 @@ import torch
from mmcv import ops
from mmcv.runner import BaseModule
-from mmdet3d.models.builder import ROI_EXTRACTORS
+from mmdet3d.registry import MODELS
-@ROI_EXTRACTORS.register_module()
+@MODELS.register_module()
class Single3DRoIAwareExtractor(BaseModule):
    """Point-wise roi-aware Extractor.
...
@@ -4,10 +4,10 @@ from mmcv import ops
from torch import nn as nn
from mmdet3d.core.bbox.structures import rotation_3d_in_axis
-from mmdet3d.models.builder import ROI_EXTRACTORS
+from mmdet3d.registry import MODELS
-@ROI_EXTRACTORS.register_module()
+@MODELS.register_module()
class Single3DRoIPointExtractor(nn.Module):
    """Point-wise roi-aware Extractor.
...
@@ -4,13 +4,12 @@ import torch
from torch import nn as nn
from torch.nn import functional as F
+from mmdet3d.registry import MODELS
from mmseg.core import add_prefix
-from ..builder import (SEGMENTORS, build_backbone, build_head, build_loss,
-                       build_neck)
from .base import Base3DSegmentor
-@SEGMENTORS.register_module()
+@MODELS.register_module()
class EncoderDecoder3D(Base3DSegmentor):
    """3D Encoder Decoder segmentors.
@@ -30,9 +29,9 @@ class EncoderDecoder3D(Base3DSegmentor):
                 pretrained=None,
                 init_cfg=None):
        super(EncoderDecoder3D, self).__init__(init_cfg=init_cfg)
-        self.backbone = build_backbone(backbone)
+        self.backbone = MODELS.build(backbone)
        if neck is not None:
-            self.neck = build_neck(neck)
+            self.neck = MODELS.build(neck)
        self._init_decode_head(decode_head)
        self._init_auxiliary_head(auxiliary_head)
        self._init_loss_regularization(loss_regularization)
@@ -44,7 +43,7 @@ class EncoderDecoder3D(Base3DSegmentor):
    def _init_decode_head(self, decode_head):
        """Initialize ``decode_head``"""
-        self.decode_head = build_head(decode_head)
+        self.decode_head = MODELS.build(decode_head)
        self.num_classes = self.decode_head.num_classes
    def _init_auxiliary_head(self, auxiliary_head):
@@ -53,9 +52,9 @@ class EncoderDecoder3D(Base3DSegmentor):
        if isinstance(auxiliary_head, list):
            self.auxiliary_head = nn.ModuleList()
            for head_cfg in auxiliary_head:
-                self.auxiliary_head.append(build_head(head_cfg))
+                self.auxiliary_head.append(MODELS.build(head_cfg))
        else:
-            self.auxiliary_head = build_head(auxiliary_head)
+            self.auxiliary_head = MODELS.build(auxiliary_head)
    def _init_loss_regularization(self, loss_regularization):
        """Initialize ``loss_regularization``"""
@@ -63,9 +62,9 @@ class EncoderDecoder3D(Base3DSegmentor):
        if isinstance(loss_regularization, list):
            self.loss_regularization = nn.ModuleList()
            for loss_cfg in loss_regularization:
-                self.loss_regularization.append(build_loss(loss_cfg))
+                self.loss_regularization.append(MODELS.build(loss_cfg))
        else:
-            self.loss_regularization = build_loss(loss_regularization)
+            self.loss_regularization = MODELS.build(loss_regularization)
    def extract_feat(self, points):
        """Extract features from points."""
...
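
The RoI-head, segmentor and dense-head diffs in this commit all make the same substitution: the type-specific helpers (build_backbone, build_neck, build_head, build_loss, build_roi_extractor) collapse into MODELS.build(cfg) on the unified registry, while the config dicts themselves stay unchanged. A minimal, self-contained sketch of that call path (my addition; ToyHead is a hypothetical module used only for illustration):

from torch import nn

from mmdet3d.registry import MODELS


@MODELS.register_module()
class ToyHead(nn.Module):
    """Hypothetical head registered only to demonstrate MODELS.build()."""

    def __init__(self, in_channels=16, num_classes=3):
        super().__init__()
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        return self.fc(x)


# What used to be `self.bbox_head = build_head(cfg)` is now
# `self.bbox_head = MODELS.build(cfg)`; the cfg dict is untouched.
head = MODELS.build(dict(type='ToyHead', in_channels=16, num_classes=3))
assert isinstance(head, ToyHead)
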
@@ -5,11 +5,11 @@ from mmcv.ops import DynamicScatter
from mmcv.runner import force_fp32
from torch import nn
-from ..builder import VOXEL_ENCODERS
+from mmdet3d.registry import MODELS
from .utils import PFNLayer, get_paddings_indicator
-@VOXEL_ENCODERS.register_module()
+@MODELS.register_module()
class PillarFeatureNet(nn.Module):
    """Pillar Feature Net.
@@ -159,7 +159,7 @@ class PillarFeatureNet(nn.Module):
        return features.squeeze(1)
-@VOXEL_ENCODERS.register_module()
+@MODELS.register_module()
class DynamicPillarFeatureNet(PillarFeatureNet):
    """Pillar Feature Net using dynamic voxelization.
...
@@ -5,12 +5,12 @@ from mmcv.ops import DynamicScatter
from mmcv.runner import force_fp32
from torch import nn
+from mmdet3d.registry import MODELS
from .. import builder
-from ..builder import VOXEL_ENCODERS
from .utils import VFELayer, get_paddings_indicator
-@VOXEL_ENCODERS.register_module()
+@MODELS.register_module()
class HardSimpleVFE(nn.Module):
    """Simple voxel feature encoder used in SECOND.
@@ -45,7 +45,7 @@ class HardSimpleVFE(nn.Module):
        return points_mean.contiguous()
-@VOXEL_ENCODERS.register_module()
+@MODELS.register_module()
class DynamicSimpleVFE(nn.Module):
    """Simple dynamic voxel feature encoder used in DV-SECOND.
@@ -84,7 +84,7 @@ class DynamicSimpleVFE(nn.Module):
        return features, features_coors
-@VOXEL_ENCODERS.register_module()
+@MODELS.register_module()
class DynamicVFE(nn.Module):
    """Dynamic Voxel feature encoder used in DV-SECOND.
@@ -286,7 +286,7 @@ class DynamicVFE(nn.Module):
        return voxel_feats, voxel_coors
-@VOXEL_ENCODERS.register_module()
+@MODELS.register_module()
class HardVFE(nn.Module):
    """Voxel feature encoder used in DV-SECOND.
...
# Copyright (c) OpenMMLab. All rights reserved.
-from .registry import OBJECTSAMPLERS, TRANSFORMS
+from .registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
+                       MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS,
+                       OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS,
+                       RUNNERS, TASK_UTILS, TRANSFORMS, VISBACKENDS,
+                       VISUALIZERS, WEIGHT_INITIALIZERS)
-__all__ = ['TRANSFORMS', 'OBJECTSAMPLERS']
+__all__ = [
+    'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS',
+    'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS',
+    'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS',
+    'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS'
+]
# Copyright (c) OpenMMLab. All rights reserved.
+"""MMDetection3D provides 17 registry nodes to support using modules across
+projects. Each node is a child of the root registry in MMEngine.
+
+More details can be found at
+https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
+"""
+from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS
+from mmengine.registry import DATASETS as MMENGINE_DATASETS
+from mmengine.registry import HOOKS as MMENGINE_HOOKS
+from mmengine.registry import LOOPS as MMENGINE_LOOPS
+from mmengine.registry import METRICS as MMENGINE_METRICS
+from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS
+from mmengine.registry import MODELS as MMENGINE_MODELS
+from mmengine.registry import \
+    OPTIMIZER_CONSTRUCTORS as MMENGINE_OPTIMIZER_CONSTRUCTORS
+from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS
+from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS
+from mmengine.registry import \
+    RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS
+from mmengine.registry import RUNNERS as MMENGINE_RUNNERS
+from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS
from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS
+from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS
+from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS
+from mmengine.registry import \
+    WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS
from mmengine.registry import Registry
+# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
+RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS)
+# manage runner constructors that define how to initialize runners
+RUNNER_CONSTRUCTORS = Registry(
+    'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS)
+# manage all kinds of loops like `EpochBasedTrainLoop`
+LOOPS = Registry('loop', parent=MMENGINE_LOOPS)
+# manage all kinds of hooks like `CheckpointHook`
+HOOKS = Registry('hook', parent=MMENGINE_HOOKS)
+# manage data-related modules
+DATASETS = Registry('dataset', parent=MMENGINE_DATASETS)
+DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)
-OBJECTSAMPLERS = Registry('Object sampler')
+# manage all kinds of modules inheriting `nn.Module`
+MODELS = Registry('model', parent=MMENGINE_MODELS)
+# manage all kinds of model wrappers like 'MMDistributedDataParallel'
+MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS)
+# manage all kinds of weight initialization modules like `Uniform`
+WEIGHT_INITIALIZERS = Registry(
+    'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS)
+# manage all kinds of optimizers like `SGD` and `Adam`
+OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
+# manage constructors that customize the optimization hyperparameters.
+OPTIMIZER_CONSTRUCTORS = Registry(
+    'optimizer constructor', parent=MMENGINE_OPTIMIZER_CONSTRUCTORS)
+# manage all kinds of parameter schedulers like `MultiStepLR`
+PARAM_SCHEDULERS = Registry(
+    'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)
+# manage all kinds of metrics
+METRICS = Registry('metric', parent=MMENGINE_METRICS)
+# manage task-specific modules like anchor generators and box coders
+TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS)
+# manage visualizer
+VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS)
+# manage visualizer backend
+VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS)
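
A quick sanity-check sketch of the parent/child layout defined above (my addition, not part of the commit; it assumes this registry.py is importable as mmdet3d.registry and relies on mmengine's Registry.parent and Registry.scope properties):

from mmengine.registry import MODELS as MMENGINE_MODELS

from mmdet3d.registry import MODELS

# Every node above is a child of the corresponding MMEngine root registry,
# so modules registered in mmdet3d stay in their own scope while remaining
# reachable from other OpenMMLab repos through the parent/child links.
assert MODELS.parent is MMENGINE_MODELS
print(MODELS.scope)  # expected: 'mmdet3d'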