from torch import nn as nn

from mmdet3d.ops import make_sparse_convmodule
from mmdet3d.ops import spconv as spconv
from ..registry import MIDDLE_ENCODERS


@MIDDLE_ENCODERS.register_module()
class SparseEncoder(nn.Module):
    """Sparse encoder for SECOND.

    See https://arxiv.org/abs/1907.03670 for more details.

    Args:
        in_channels (int): The number of input channels.
        sparse_shape (list[int]): The sparse shape of the input tensor.
        order (tuple[str]): Order of the conv module,
            defaults to ('conv', 'norm', 'act').
        norm_cfg (dict): Config of the normalization layer.
        base_channels (int): Out channels of the conv_input layer.
        output_channels (int): Out channels of the conv_out layer.
        encoder_channels (tuple[tuple[int]]): Conv channels of
            each encode block.
        encoder_paddings (tuple[tuple[int]]): Paddings of each encode block.
    """

    def __init__(self,
                 in_channels,
                 sparse_shape,
                 order=('conv', 'norm', 'act'),
                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
                 base_channels=16,
                 output_channels=128,
                 encoder_channels=((16, ), (32, 32, 32), (64, 64, 64),
                                   (64, 64, 64)),
                 encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1),
                                   ((0, 1, 1), 1, 1))):
        super().__init__()
        self.sparse_shape = sparse_shape
        self.in_channels = in_channels
        self.order = order
        self.base_channels = base_channels
        self.output_channels = output_channels
        self.encoder_channels = encoder_channels
        self.encoder_paddings = encoder_paddings
        self.stage_num = len(self.encoder_channels)
        # Spconv initializes all weights on its own

        assert isinstance(order, tuple) and len(order) == 3
        assert set(order) == {'conv', 'norm', 'act'}

        if self.order[0] != 'conv':  # pre-activate
            self.conv_input = make_sparse_convmodule(
                in_channels,
                self.base_channels,
                3,
                norm_cfg=norm_cfg,
                padding=1,
                indice_key='subm1',
                conv_type='SubMConv3d',
                order=('conv', ))
        else:  # post-activate
            self.conv_input = make_sparse_convmodule(
                in_channels,
                self.base_channels,
                3,
                norm_cfg=norm_cfg,
                padding=1,
                indice_key='subm1',
                conv_type='SubMConv3d')

        encoder_out_channels = self.make_encoder_layers(
            make_sparse_convmodule, norm_cfg, self.base_channels)

        self.conv_out = make_sparse_convmodule(
            encoder_out_channels,
            self.output_channels,
            kernel_size=(3, 1, 1),
            stride=(2, 1, 1),
            norm_cfg=norm_cfg,
            padding=0,
            indice_key='spconv_down2',
            conv_type='SparseConv3d')

    def forward(self, voxel_features, coors, batch_size):
        """Forward of SparseEncoder.

        Args:
            voxel_features (torch.float32): Voxel features in shape [N, C].
            coors (torch.int32): Voxel coordinates in shape [N, 4], with the
                columns ordered as (batch_idx, z_idx, y_idx, x_idx).
            batch_size (int): Batch size.

        Returns:
            torch.Tensor: Backbone features in shape [N, C * D, H, W].
        """
        coors = coors.int()
        input_sp_tensor = spconv.SparseConvTensor(voxel_features, coors,
                                                  self.sparse_shape,
                                                  batch_size)
        x = self.conv_input(input_sp_tensor)

        encode_features = []
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x)
            encode_features.append(x)

        # for detection head
        # [200, 176, 5] -> [200, 176, 2]
        out = self.conv_out(encode_features[-1])
        spatial_features = out.dense()

        N, C, D, H, W = spatial_features.shape
        spatial_features = spatial_features.view(N, C * D, H, W)

        return spatial_features

    def make_encoder_layers(self, make_block, norm_cfg, in_channels):
        """Make encoder layers using sparse convs.

        Args:
            make_block (method): A bounded function to build blocks.
            norm_cfg (dict[str]): Config of the normalization layer.
            in_channels (int): The number of encoder input channels.

        Returns:
            int: The number of encoder output channels.
        """
        self.encoder_layers = spconv.SparseSequential()

        for i, blocks in enumerate(self.encoder_channels):
            blocks_list = []
            for j, out_channels in enumerate(tuple(blocks)):
                padding = tuple(self.encoder_paddings[i])[j]
                # each stage starts with a spconv layer
                # except the first stage
                if i != 0 and j == 0:
                    blocks_list.append(
                        make_block(
                            in_channels,
                            out_channels,
                            3,
                            norm_cfg=norm_cfg,
                            stride=2,
                            padding=padding,
                            indice_key=f'spconv{i + 1}',
                            conv_type='SparseConv3d'))
                else:
                    blocks_list.append(
                        make_block(
                            in_channels,
                            out_channels,
                            3,
                            norm_cfg=norm_cfg,
                            padding=padding,
                            indice_key=f'subm{i + 1}',
                            conv_type='SubMConv3d'))
                in_channels = out_channels
            stage_name = f'encoder_layer{i + 1}'
            stage_layers = spconv.SparseSequential(*blocks_list)
            self.encoder_layers.add_module(stage_name, stage_layers)
        return out_channels
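

# Minimal usage sketch (illustration only, not part of the original module):
# the sparse shape, voxel count, and random inputs below are hypothetical
# KITTI-style values, and running the snippet requires a CUDA build of spconv.
#
#   import torch
#
#   encoder = SparseEncoder(
#       in_channels=4, sparse_shape=[41, 1600, 1408]).cuda()
#   num_voxels = 100
#   voxel_features = torch.rand(num_voxels, 4).cuda()
#   # coors columns: (batch_idx, z_idx, y_idx, x_idx)
#   coors = torch.randint(0, 40, (num_voxels, 4)).int().cuda()
#   coors[:, 0] = 0  # single sample in the batch
#   spatial_features = encoder(voxel_features, coors, batch_size=1)
#   # With these (assumed) settings, the dense output is
#   # [1, 128 * 2, 200, 176], i.e. [N, C * D, H, W] after three stride-2
#   # stages and the final (2, 1, 1)-strided conv_out layer.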