# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import points_in_boxes_all, three_interpolate, three_nn
from mmdet.models.losses import sigmoid_focal_loss, smooth_l1_loss
from torch import nn as nn

from mmdet3d.models.layers import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.models.layers.spconv import IS_SPCONV2_AVAILABLE
from mmdet3d.registry import MODELS

if IS_SPCONV2_AVAILABLE:
    from spconv.pytorch import SparseConvTensor, SparseSequential
else:
    from mmcv.ops import SparseConvTensor, SparseSequential


@MODELS.register_module()
class SparseEncoder(nn.Module):
    r"""Sparse encoder for SECOND and Part-A2.

    Args:
        in_channels (int): The number of input channels.
        sparse_shape (list[int]): The sparse shape of input tensor.
        order (tuple[str], optional): Order of conv module.
            Defaults to ('conv', 'norm', 'act').
        norm_cfg (dict, optional): Config of normalization layer. Defaults to
            dict(type='BN1d', eps=1e-3, momentum=0.01).
        base_channels (int, optional): Out channels for conv_input layer.
            Defaults to 16.
        output_channels (int, optional): Out channels for conv_out layer.
            Defaults to 128.
        encoder_channels (tuple[tuple[int]], optional):
            Convolutional channels of each encode block.
            Defaults to ((16, ), (32, 32, 32), (64, 64, 64), (64, 64, 64)).
        encoder_paddings (tuple[tuple[int]], optional):
            Paddings of each encode block.
            Defaults to ((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, 1)).
        block_type (str, optional): Type of the block to use.
            Defaults to 'conv_module'.
    """

    def __init__(self,
                 in_channels,
                 sparse_shape,
                 order=('conv', 'norm', 'act'),
                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
                 base_channels=16,
                 output_channels=128,
                 encoder_channels=((16, ), (32, 32, 32), (64, 64, 64),
                                   (64, 64, 64)),
                 encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1),
                                   ((0, 1, 1), 1, 1)),
                 block_type='conv_module'):
        super().__init__()
        assert block_type in ['conv_module', 'basicblock']
        self.sparse_shape = sparse_shape
        self.in_channels = in_channels
        self.order = order
        self.base_channels = base_channels
        self.output_channels = output_channels
        self.encoder_channels = encoder_channels
        self.encoder_paddings = encoder_paddings
        self.stage_num = len(self.encoder_channels)
        self.fp16_enabled = False
        # Spconv init all weight on its own

        assert isinstance(order, tuple) and len(order) == 3
        assert set(order) == {'conv', 'norm', 'act'}

        if self.order[0] != 'conv':  # pre activate
            self.conv_input = make_sparse_convmodule(
                in_channels,
                self.base_channels,
                3,
                norm_cfg=norm_cfg,
                padding=1,
                indice_key='subm1',
                conv_type='SubMConv3d',
                order=('conv', ))
        else:  # post activate
            self.conv_input = make_sparse_convmodule(
                in_channels,
                self.base_channels,
                3,
                norm_cfg=norm_cfg,
                padding=1,
                indice_key='subm1',
                conv_type='SubMConv3d')

        encoder_out_channels = self.make_encoder_layers(
            make_sparse_convmodule,
            norm_cfg,
            self.base_channels,
            block_type=block_type)

        self.conv_out = make_sparse_convmodule(
            encoder_out_channels,
            self.output_channels,
            kernel_size=(3, 1, 1),
            stride=(2, 1, 1),
            norm_cfg=norm_cfg,
            padding=0,
            indice_key='spconv_down2',
            conv_type='SparseConv3d')

    def forward(self, voxel_features, coors, batch_size):
        """Forward of SparseEncoder.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, C).
            coors (torch.Tensor): Coordinates in shape (N, 4),
                the columns in the order of
                (batch_idx, z_idx, y_idx, x_idx).
            batch_size (int): Batch size.

        Returns:
            torch.Tensor: Backbone features in shape (N, C * D, H, W).
        """
        coors = coors.int()
        input_sp_tensor = SparseConvTensor(voxel_features, coors,
                                           self.sparse_shape, batch_size)
        x = self.conv_input(input_sp_tensor)

        encode_features = []
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x)
            encode_features.append(x)

        # for detection head
        # [200, 176, 5] -> [200, 176, 2]
        out = self.conv_out(encode_features[-1])
        spatial_features = out.dense()

        N, C, D, H, W = spatial_features.shape
        spatial_features = spatial_features.view(N, C * D, H, W)

        return spatial_features
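    # A minimal usage sketch (hypothetical numbers for a KITTI-style setup,
    # not part of the module): with sparse_shape=[41, 1600, 1408], the three
    # stride-2 stages plus conv_out reduce the grid to D=2, H=200, W=176,
    # and the depth axis is folded into the channels (128 * 2 = 256).
    #
    #   encoder = SparseEncoder(in_channels=4, sparse_shape=[41, 1600, 1408])
    #   voxel_features = torch.rand(1000, 4)              # (N, C)
    #   coors = torch.zeros(1000, 4, dtype=torch.int32)   # batch_idx = 0
    #   coors[:, 1] = torch.randint(0, 41, (1000, ))      # z_idx
    #   coors[:, 2] = torch.randint(0, 1600, (1000, ))    # y_idx
    #   coors[:, 3] = torch.randint(0, 1408, (1000, ))    # x_idx
    #   feats = encoder(voxel_features, coors, batch_size=1)
    #   # feats.shape == (1, 256, 200, 176); real coors come from a
    #   # voxelization layer and are unique per voxel.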
""" coors = coors.int() input_sp_tensor = SparseConvTensor(voxel_features, coors, self.sparse_shape, batch_size) x = self.conv_input(input_sp_tensor) encode_features = [] for encoder_layer in self.encoder_layers: x = encoder_layer(x) encode_features.append(x) # for detection head # [200, 176, 5] -> [200, 176, 2] out = self.conv_out(encode_features[-1]) spatial_features = out.dense() N, C, D, H, W = spatial_features.shape spatial_features = spatial_features.view(N, C * D, H, W) return spatial_features def make_encoder_layers(self, make_block, norm_cfg, in_channels, block_type='conv_module', conv_cfg=dict(type='SubMConv3d')): """make encoder layers using sparse convs. Args: make_block (method): A bounded function to build blocks. norm_cfg (dict[str]): Config of normalization layer. in_channels (int): The number of encoder input channels. block_type (str, optional): Type of the block to use. Defaults to 'conv_module'. conv_cfg (dict, optional): Config of conv layer. Defaults to dict(type='SubMConv3d'). Returns: int: The number of encoder output channels. """ assert block_type in ['conv_module', 'basicblock'] self.encoder_layers = SparseSequential() for i, blocks in enumerate(self.encoder_channels): blocks_list = [] for j, out_channels in enumerate(tuple(blocks)): padding = tuple(self.encoder_paddings[i])[j] # each stage started with a spconv layer # except the first stage if i != 0 and j == 0 and block_type == 'conv_module': blocks_list.append( make_block( in_channels, out_channels, 3, norm_cfg=norm_cfg, stride=2, padding=padding, indice_key=f'spconv{i + 1}', conv_type='SparseConv3d')) elif block_type == 'basicblock': if j == len(blocks) - 1 and i != len( self.encoder_channels) - 1: blocks_list.append( make_block( in_channels, out_channels, 3, norm_cfg=norm_cfg, stride=2, padding=padding, indice_key=f'spconv{i + 1}', conv_type='SparseConv3d')) else: blocks_list.append( SparseBasicBlock( out_channels, out_channels, norm_cfg=norm_cfg, conv_cfg=conv_cfg)) else: blocks_list.append( make_block( in_channels, out_channels, 3, norm_cfg=norm_cfg, padding=padding, indice_key=f'subm{i + 1}', conv_type='SubMConv3d')) in_channels = out_channels stage_name = f'encoder_layer{i + 1}' stage_layers = SparseSequential(*blocks_list) self.encoder_layers.add_module(stage_name, stage_layers) return out_channels @MIDDLE_ENCODERS.register_module() class SparseEncoderSASSD(SparseEncoder): r"""Sparse encoder for `SASSD `_ Args: in_channels (int): The number of input channels. sparse_shape (list[int]): The sparse shape of input tensor. order (list[str], optional): Order of conv module. Defaults to ('conv', 'norm', 'act'). norm_cfg (dict, optional): Config of normalization layer. Defaults to dict(type='BN1d', eps=1e-3, momentum=0.01). base_channels (int, optional): Out channels for conv_input layer. Defaults to 16. output_channels (int, optional): Out channels for conv_out layer. Defaults to 128. encoder_channels (tuple[tuple[int]], optional): Convolutional channels of each encode block. Defaults to ((16, ), (32, 32, 32), (64, 64, 64), (64, 64, 64)). encoder_paddings (tuple[tuple[int]], optional): Paddings of each encode block. Defaults to ((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, 1)). block_type (str, optional): Type of the block to use. Defaults to 'conv_module'. 
""" def __init__(self, in_channels, sparse_shape, order=('conv', 'norm', 'act'), norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01), base_channels=16, output_channels=128, encoder_channels=((16, ), (32, 32, 32), (64, 64, 64), (64, 64, 64)), encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, 1)), block_type='conv_module'): super(SparseEncoderSASSD, self).__init__( in_channels=in_channels, sparse_shape=sparse_shape, order=order, norm_cfg=norm_cfg, base_channels=base_channels, output_channels=output_channels, encoder_channels=encoder_channels, encoder_paddings=encoder_paddings, block_type=block_type) self.point_fc = nn.Linear(112, 64, bias=False) self.point_cls = nn.Linear(64, 1, bias=False) self.point_reg = nn.Linear(64, 3, bias=False) @auto_fp16(apply_to=('voxel_features', )) def forward(self, voxel_features, coors, batch_size, test_mode=False): """Forward of SparseEncoder. Args: voxel_features (torch.Tensor): Voxel features in shape (N, C). coors (torch.Tensor): Coordinates in shape (N, 4), the columns in the order of (batch_idx, z_idx, y_idx, x_idx). batch_size (int): Batch size. test_mode (bool, optional): Whether in test mode. Defaults to False. Returns: dict: Backbone features. tuple[torch.Tensor]: Mean feature value of the points, Classificaion result of the points, Regression offsets of the points. """ coors = coors.int() input_sp_tensor = SparseConvTensor(voxel_features, coors, self.sparse_shape, batch_size) x = self.conv_input(input_sp_tensor) encode_features = [] for encoder_layer in self.encoder_layers: x = encoder_layer(x) encode_features.append(x) # for detection head # [200, 176, 5] -> [200, 176, 2] out = self.conv_out(encode_features[-1]) spatial_features = out.dense() N, C, D, H, W = spatial_features.shape spatial_features = spatial_features.view(N, C * D, H, W) if test_mode: return spatial_features, None points_mean = torch.zeros_like(voxel_features) points_mean[:, 0] = coors[:, 0] points_mean[:, 1:] = voxel_features[:, :3] # auxiliary network p0 = self.make_auxiliary_points( encode_features[0], points_mean, offset=(0, -40., -3.), voxel_size=(.1, .1, .2)) p1 = self.make_auxiliary_points( encode_features[1], points_mean, offset=(0, -40., -3.), voxel_size=(.2, .2, .4)) p2 = self.make_auxiliary_points( encode_features[2], points_mean, offset=(0, -40., -3.), voxel_size=(.4, .4, .8)) pointwise = torch.cat([p0, p1, p2], dim=-1) pointwise = self.point_fc(pointwise) point_cls = self.point_cls(pointwise) point_reg = self.point_reg(pointwise) point_misc = (points_mean, point_cls, point_reg) return spatial_features, point_misc def get_auxiliary_targets(self, nxyz, gt_boxes3d, enlarge=1.0): """Get auxiliary target. Args: nxyz (torch.Tensor): Mean features of the points. gt_boxes3d (torch.Tensor): Coordinates in shape (N, 4), the columns in the order of (batch_idx, z_idx, y_idx, x_idx). enlarge (int, optional): Enlaged scale. Defaults to 1.0. Returns: tuple[torch.Tensor]: Label of the points and center offsets of the points. 
""" center_offsets = list() pts_labels = list() for i in range(len(gt_boxes3d)): boxes3d = gt_boxes3d[i].tensor.cpu() idx = torch.nonzero(nxyz[:, 0] == i).view(-1) new_xyz = nxyz[idx, 1:].cpu() boxes3d[:, 3:6] *= enlarge pts_in_flag, center_offset = self.calculate_pts_offsets( new_xyz, boxes3d) pts_label = pts_in_flag.max(0)[0].byte() pts_labels.append(pts_label) center_offsets.append(center_offset) center_offsets = torch.cat(center_offsets).cuda() pts_labels = torch.cat(pts_labels).to(center_offsets.device) return pts_labels, center_offsets def calculate_pts_offsets(self, points, boxes): """Find all boxes in which each point is, as well as the offsets from the box centers. Args: points (torch.Tensor): [M, 3], [x, y, z] in LiDAR/DEPTH coordinate boxes (torch.Tensor): [T, 7], num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz], (x, y, z) is the bottom center. Returns: tuple[torch.Tensor]: Point indices of boxes with the shape of (T, M). Default background = 0. And offsets from the box centers of points, if it belows to the box, with the shape of (M, 3). Default background = 0. """ boxes_num = len(boxes) pts_num = len(points) points = points.cuda() boxes = boxes.to(points.device) box_idxs_of_pts = points_in_boxes_all(points[None, ...], boxes[None, ...]) pts_indices = box_idxs_of_pts.squeeze(0).transpose(0, 1) center_offsets = torch.zeros_like(points).to(points.device) for i in range(boxes_num): for j in range(pts_num): if pts_indices[i][j] == 1: center_offsets[j][0] = points[j][0] - boxes[i][0] center_offsets[j][1] = points[j][1] - boxes[i][1] center_offsets[j][2] = ( points[j][2] - (boxes[i][2] + boxes[i][2] / 2.0)) return pts_indices.cpu(), center_offsets.cpu() def aux_loss(self, points, point_cls, point_reg, gt_bboxes): """Calculate auxiliary loss. Args: points (torch.Tensor): Mean feature value of the points. point_cls (torch.Tensor): Classificaion result of the points. point_reg (torch.Tensor): Regression offsets of the points. gt_bboxes (list[:obj:`BaseInstance3DBoxes`]): Ground truth boxes for each sample. Returns: dict: Backbone features. """ num_boxes = len(gt_bboxes) pts_labels, center_targets = self.get_auxiliary_targets( points, gt_bboxes) rpn_cls_target = pts_labels.long() pos = (pts_labels > 0).float() neg = (pts_labels == 0).float() pos_normalizer = pos.sum().clamp(min=1.0) cls_weights = pos + neg reg_weights = pos reg_weights = reg_weights / pos_normalizer aux_loss_cls = sigmoid_focal_loss( point_cls, rpn_cls_target, weight=cls_weights, avg_factor=pos_normalizer) aux_loss_cls /= num_boxes weight = reg_weights[..., None] aux_loss_reg = smooth_l1_loss(point_reg, center_targets, beta=1 / 9.) aux_loss_reg = torch.sum(aux_loss_reg * weight)[None] aux_loss_reg /= num_boxes aux_loss_cls, aux_loss_reg = [aux_loss_cls], [aux_loss_reg] return dict(aux_loss_cls=aux_loss_cls, aux_loss_reg=aux_loss_reg) def make_auxiliary_points(self, source_tensor, target, offset=(0., -40., -3.), voxel_size=(.05, .05, .1)): """Make auxiliary points for loss computation. Args: source_tensor (torch.Tensor): (M, C) features to be propigated. target (torch.Tensor): (N, 4) bxyz positions of the target features. offset (tuple[float], optional): Voxelization offset. Defaults to (0., -40., -3.) voxel_size (tuple[float], optional): Voxelization size. Defaults to (.05, .05, .1) Returns: torch.Tensor: (N, C) tensor of the features of the target features. 
""" # Tansfer tensor to points source = source_tensor.indices.float() offset = torch.Tensor(offset).to(source.device) voxel_size = torch.Tensor(voxel_size).to(source.device) source[:, 1:] = ( source[:, [3, 2, 1]] * voxel_size + offset + .5 * voxel_size) source_feats = source_tensor.features[None, ...].transpose(1, 2) # Interplate auxiliary points dist, idx = three_nn(target[None, ...], source[None, ...]) dist_recip = 1.0 / (dist + 1e-8) norm = torch.sum(dist_recip, dim=2, keepdim=True) weight = dist_recip / norm new_features = three_interpolate(source_feats.contiguous(), idx, weight) return new_features.squeeze(0).transpose(0, 1)