import torch
from mmcv.cnn import build_norm_layer
from torch import nn

from mmdet3d.ops import DynamicScatter
from ..registry import VOXEL_ENCODERS
from .utils import PFNLayer, get_paddings_indicator


@VOXEL_ENCODERS.register_module()
class PillarFeatureNet(nn.Module):
    """Pillar Feature Net.

    The network prepares the pillar features and performs forward pass
    through PFNLayers.

    Args:
        in_channels (int, optional): Number of input features,
            either x, y, z or x, y, z, r. Defaults to 4.
        feat_channels (tuple, optional): Number of features in each of the
            N PFNLayers. Must not be empty. Defaults to (64, ).
        with_distance (bool, optional): Whether to append the Euclidean
            norm of each point's first three features. Defaults to False.
        with_cluster_center (bool, optional): Whether to append each
            point's x, y, z offset from the mean of the points in its
            pillar. Defaults to True.
        with_voxel_center (bool, optional): Whether to append each point's
            x, y offset from the geometric center of its pillar.
            Defaults to True.
        voxel_size (tuple[float], optional): Size of voxels, only utilize x
            and y size. Defaults to (0.2, 0.2, 4).
        point_cloud_range (tuple[float], optional): Point cloud range, only
            utilizes x and y min. Defaults to (0, -40, -3, 70.4, 40, 1).
        norm_cfg (dict, optional): Normalization config forwarded to each
            PFNLayer. Defaults to
            dict(type='BN1d', eps=1e-3, momentum=0.01).
        mode (str, optional): The mode to gather point features. Options are
            'max' or 'avg'. Defaults to 'max'.
        legacy (bool, optional): Only used when ``with_voxel_center`` is
            True. If True, the pillar-center offsets are written in place
            into the x, y columns of the input ``features`` (the historic
            behavior, which also overwrites the raw x, y fed to the
            network and mutates the caller's tensor). If False, the
            offsets are computed out of place and the raw features are
            left untouched. Defaults to True so that existing configs and
            checkpoints keep their numerics.
    """

    def __init__(self,
                 in_channels=4,
                 feat_channels=(64, ),
                 with_distance=False,
                 with_cluster_center=True,
                 with_voxel_center=True,
                 voxel_size=(0.2, 0.2, 4),
                 point_cloud_range=(0, -40, -3, 70.4, 40, 1),
                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
                 mode='max',
                 legacy=True):
        super(PillarFeatureNet, self).__init__()
        assert len(feat_channels) > 0
        self.legacy = legacy

        # Each enabled decoration widens the per-point feature vector.
        if with_cluster_center:
            in_channels += 3
        if with_voxel_center:
            in_channels += 2
        if with_distance:
            in_channels += 1
        self._with_distance = with_distance
        self._with_cluster_center = with_cluster_center
        self._with_voxel_center = with_voxel_center

        # Create PillarFeatureNet layers
        self.in_channels = in_channels
        feat_channels = [in_channels] + list(feat_channels)
        num_layers = len(feat_channels) - 1
        pfn_layers = []
        for i in range(num_layers):
            pfn_layers.append(
                PFNLayer(
                    feat_channels[i],
                    feat_channels[i + 1],
                    norm_cfg=norm_cfg,
                    # Only the final layer is flagged as the last one.
                    last_layer=(i == num_layers - 1),
                    mode=mode))
        self.pfn_layers = nn.ModuleList(pfn_layers)

        # Need pillar (voxel) size and x/y offset in order to calculate
        # the pillar center: index * size + (size / 2 + range_min).
        self.vx = voxel_size[0]
        self.vy = voxel_size[1]
        self.x_offset = self.vx / 2 + point_cloud_range[0]
        self.y_offset = self.vy / 2 + point_cloud_range[1]
        self.point_cloud_range = point_cloud_range

    def forward(self, features, num_points, coors):
        """Forward function.

        Args:
            features (torch.Tensor): Point features or raw points in shape
                (N, M, C). With ``with_voxel_center=True`` and
                ``legacy=True`` the x, y columns of this tensor are
                modified in place.
            num_points (torch.Tensor): Number of points in each pillar.
            coors (torch.Tensor): Coordinates of each voxel. Column 3 is
                read as the x grid index and column 2 as the y grid index.

        Returns:
            torch.Tensor: Features of pillars, with the point axis (dim 1)
                squeezed out.
        """
        features_ls = [features]

        # Find distance of x, y, and z from cluster center
        if self._with_cluster_center:
            points_mean = features[:, :, :3].sum(
                dim=1, keepdim=True) / num_points.type_as(features).view(
                    -1, 1, 1)
            f_cluster = features[:, :, :3] - points_mean
            features_ls.append(f_cluster)

        # Find distance of x and y from pillar center
        if self._with_voxel_center:
            x_center = (
                coors[:, 3].type_as(features).unsqueeze(1) * self.vx +
                self.x_offset)
            y_center = (
                coors[:, 2].type_as(features).unsqueeze(1) * self.vy +
                self.y_offset)
            if self.legacy:
                # Historic behavior: ``f_center`` is a view, so these
                # writes also change ``features`` (and features_ls[0]).
                f_center = features[:, :, :2]
                f_center[:, :, 0] = f_center[:, :, 0] - x_center
                f_center[:, :, 1] = f_center[:, :, 1] - y_center
            else:
                # Out-of-place: raw x, y in ``features`` stay intact.
                f_center = torch.stack(
                    (features[:, :, 0] - x_center,
                     features[:, :, 1] - y_center),
                    dim=-1)
            features_ls.append(f_center)

        if self._with_distance:
            points_dist = torch.norm(features[:, :, :3], 2, 2, keepdim=True)
            features_ls.append(points_dist)

        # Combine together feature decorations
        features = torch.cat(features_ls, dim=-1)

        # The feature decorations were calculated without regard to whether
        # pillar was empty. Need to ensure that
        # empty pillars remain set to zeros.
        # NOTE(review): get_paddings_indicator is presumably a (N, M) mask
        # marking the real (non-padded) points - confirm in .utils.
        points_per_pillar = features.shape[1]
        mask = get_paddings_indicator(num_points, points_per_pillar, axis=0)
        mask = torch.unsqueeze(mask, -1).type_as(features)
        features *= mask

        for pfn in self.pfn_layers:
            features = pfn(features, num_points)

        # Drop only the point axis. A bare ``squeeze()`` would also drop
        # the pillar axis when N == 1 (or the channel axis when C == 1).
        # NOTE(review): assumes the last PFNLayer reduces dim 1 to size 1.
        return features.squeeze(1)


@VOXEL_ENCODERS.register_module()
class DynamicPillarFeatureNet(PillarFeatureNet):
    """Pillar Feature Net using dynamic voxelization.

    The network prepares the pillar features and performs forward pass
    through PFNLayers. The main difference is that it is used for
    dynamic voxels, which contains different number of points inside a voxel
    without limits.

    Args:
        in_channels (int, optional): Number of input features,
            either x, y, z or x, y, z, r. Defaults to 4.
        feat_channels (tuple, optional): Number of features in each of the
            N PFNLayers. Defaults to (64, ).
        with_distance (bool, optional): Whether to append the Euclidean
            norm of each point's first three features. Defaults to False.
        with_cluster_center (bool, optional): Whether to append each
            point's x, y, z offset from the mean of the points in its
            voxel. Defaults to True.
        with_voxel_center (bool, optional): Whether to append each point's
            x, y offset from the geometric center of its voxel.
            Defaults to True.
        voxel_size (tuple[float], optional): Size of voxels, only utilize x
            and y size. Defaults to (0.2, 0.2, 4).
        point_cloud_range (tuple[float], optional): Point cloud range, only
            utilizes x and y min. Defaults to (0, -40, -3, 70.4, 40, 1).
        norm_cfg (dict, optional): Normalization config passed to
            ``build_norm_layer``. Defaults to
            dict(type='BN1d', eps=1e-3, momentum=0.01).
        mode (str, optional): The mode to gather point features. Options are
            'max' or 'avg'. Defaults to 'max'.
    """

    def __init__(self,
                 in_channels=4,
                 feat_channels=(64, ),
                 with_distance=False,
                 with_cluster_center=True,
                 with_voxel_center=True,
                 voxel_size=(0.2, 0.2, 4),
                 point_cloud_range=(0, -40, -3, 70.4, 40, 1),
                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
                 mode='max'):
        super(DynamicPillarFeatureNet, self).__init__(
            in_channels,
            feat_channels,
            with_distance,
            with_cluster_center=with_cluster_center,
            with_voxel_center=with_voxel_center,
            voxel_size=voxel_size,
            point_cloud_range=point_cloud_range,
            norm_cfg=norm_cfg,
            mode=mode)
        # The parent built PFNLayers; they are replaced below by plain
        # Linear-Norm-ReLU stacks that operate on a flat list of points.
        feat_channels = [self.in_channels] + list(feat_channels)
        pfn_layers = []
        # TODO: currently only support one PFNLayer

        for i in range(len(feat_channels) - 1):
            in_filters = feat_channels[i]
            out_filters = feat_channels[i + 1]
            if i > 0:
                # Later layers receive point features concatenated with the
                # per-voxel features mapped back to the points (see forward).
                in_filters *= 2
            _, norm_layer = build_norm_layer(norm_cfg, out_filters)
            pfn_layers.append(
                nn.Sequential(
                    nn.Linear(in_filters, out_filters, bias=False),
                    norm_layer,
                    nn.ReLU(inplace=True)))
        self.num_pfn = len(pfn_layers)
        self.pfn_layers = nn.ModuleList(pfn_layers)
        # NOTE(review): the third positional argument is presumably
        # ``average_points`` (as named in the call below) - confirm.
        self.pfn_scatter = DynamicScatter(voxel_size, point_cloud_range,
                                          (mode != 'max'))
        self.cluster_scatter = DynamicScatter(
            voxel_size, point_cloud_range, average_points=True)

    def map_voxel_center_to_point(self, pts_coors, voxel_mean, voxel_coors):
        """Map the centers of voxels to its corresponding points.

        Args:
            pts_coors (torch.Tensor): The voxel coordinates of each point,
                one row per point (M rows). Column 0 is read as the batch
                index, column 2 as the y index and column 3 as the x
                index, so at least 4 columns are required.
            voxel_mean (torch.Tensor): The mean or aggregated features of a
                voxel, shape (N, C), where N is the number of voxels.
            voxel_coors (torch.Tensor): The coordinates of each voxel, read
                with the same column layout as ``pts_coors``.

        Returns:
            torch.Tensor: Corresponding voxel centers of each points, shape
                (M, C), where M is the number of points.
        """
        # Step 1: scatter voxel into canvas
        # Calculate necessary things for canvas creation
        canvas_y = int(
            (self.point_cloud_range[4] - self.point_cloud_range[1]) / self.vy)
        canvas_x = int(
            (self.point_cloud_range[3] - self.point_cloud_range[0]) / self.vx)
        canvas_channel = voxel_mean.size(1)
        # Batch size is taken from the last point's batch index, i.e. this
        # relies on the last row belonging to the highest batch index.
        batch_size = pts_coors[-1, 0] + 1
        canvas_len = canvas_y * canvas_x * batch_size
        # Create the canvas for this sample
        canvas = voxel_mean.new_zeros(canvas_channel, canvas_len)
        # Only include non-empty pillars
        indices = (
            voxel_coors[:, 0] * canvas_y * canvas_x +
            voxel_coors[:, 2] * canvas_x + voxel_coors[:, 3])
        # Scatter the blob back to the canvas
        canvas[:, indices.long()] = voxel_mean.t()

        # Step 2: get voxel mean for each point
        voxel_index = (
            pts_coors[:, 0] * canvas_y * canvas_x +
            pts_coors[:, 2] * canvas_x + pts_coors[:, 3])
        center_per_point = canvas[:, voxel_index.long()].t()
        return center_per_point

    def forward(self, features, coors):
        """Forward function.

        Args:
            features (torch.Tensor): Point features or raw points, indexed
                here as a 2-D tensor with one row per point whose first
                three columns are x, y, z.
            coors (torch.Tensor): Voxel coordinates of each point, one row
                per point. Column 0 is read as the batch index, column 2
                as the y index and column 3 as the x index.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The voxel features and the
                voxel coordinates produced by scattering the output of the
                last PFN layer.
        """
        features_ls = [features]
        # Find distance of x, y, and z from cluster center
        if self._with_cluster_center:
            voxel_mean, mean_coors = self.cluster_scatter(features, coors)
            points_mean = self.map_voxel_center_to_point(
                coors, voxel_mean, mean_coors)
            # TODO: maybe also do cluster for reflectivity
            f_cluster = features[:, :3] - points_mean[:, :3]
            features_ls.append(f_cluster)

        # Find distance of x and y from pillar center
        if self._with_voxel_center:
            # A fresh tensor is allocated, so ``features`` is not modified.
            f_center = features.new_zeros(size=(features.size(0), 2))
            f_center[:, 0] = features[:, 0] - (
                coors[:, 3].type_as(features) * self.vx + self.x_offset)
            f_center[:, 1] = features[:, 1] - (
                coors[:, 2].type_as(features) * self.vy + self.y_offset)
            features_ls.append(f_center)

        if self._with_distance:
            points_dist = torch.norm(features[:, :3], 2, 1, keepdim=True)
            features_ls.append(points_dist)

        # Combine together feature decorations
        features = torch.cat(features_ls, dim=-1)
        for i, pfn in enumerate(self.pfn_layers):
            point_feats = pfn(features)
            voxel_feats, voxel_coors = self.pfn_scatter(point_feats, coors)
            if i != len(self.pfn_layers) - 1:
                # need to concat voxel feats if it is not the last pfn
                feat_per_point = self.map_voxel_center_to_point(
                    coors, voxel_feats, voxel_coors)
                features = torch.cat([point_feats, feat_per_point], dim=1)

        return voxel_feats, voxel_coors