import torch
from torch import nn
from torch.nn import functional as F

from mmdet3d.ops import DynamicScatter
from .. import builder
from ..registry import VOXEL_ENCODERS
from ..utils import build_norm_layer
from .utils import Empty, VFELayer, get_paddings_indicator


@VOXEL_ENCODERS.register_module
class VoxelFeatureExtractor(nn.Module):

    def __init__(self,
                 num_input_features=4,
                 use_norm=True,
                 num_filters=[32, 128],
                 with_distance=False,
                 name='VoxelFeatureExtractor'):
        super(VoxelFeatureExtractor, self).__init__()
        self.name = name
        assert len(num_filters) == 2
        num_input_features += 3  # add mean features
        if with_distance:
            num_input_features += 1
        self._with_distance = with_distance
        self.vfe1 = VFELayer(num_input_features, num_filters[0], use_norm)
        self.vfe2 = VFELayer(num_filters[0], num_filters[1], use_norm)
        if use_norm:
            self.linear = nn.Linear(num_filters[1], num_filters[1], bias=False)
            self.norm = nn.BatchNorm1d(num_filters[1], eps=1e-3, momentum=0.01)
        else:
            self.linear = nn.Linear(num_filters[1], num_filters[1], bias=True)
            self.norm = Empty(num_filters[1])

    def forward(self, features, num_voxels, **kwargs):
        # features: [concated_num_points, num_voxel_size, 3(4)]
        # num_voxels: [concated_num_points]

        # Decorate each point with its offset from the voxel's point mean.
        points_mean = features[:, :, :3].sum(
            dim=1, keepdim=True) / num_voxels.type_as(features).view(-1, 1, 1)
        features_relative = features[:, :, :3] - points_mean
        if self._with_distance:
            points_dist = torch.norm(features[:, :, :3], 2, 2, keepdim=True)
            features = torch.cat([features, features_relative, points_dist],
                                 dim=-1)
        else:
            features = torch.cat([features, features_relative], dim=-1)
        # Zero out the padded entries of each voxel.
        voxel_count = features.shape[1]
        mask = get_paddings_indicator(num_voxels, voxel_count, axis=0)
        mask = torch.unsqueeze(mask, -1).type_as(features)

        x = self.vfe1(features)
        x *= mask
        x = self.vfe2(x)
        x *= mask
        x = self.linear(x)
        x = self.norm(x.permute(0, 2, 1).contiguous()).permute(
            0, 2, 1).contiguous()
        x = F.relu(x)
        x *= mask
        # x: [concated_num_points, num_voxel_size, 128]
        voxelwise = torch.max(x, dim=1)[0]
        return voxelwise


@VOXEL_ENCODERS.register_module
class VoxelFeatureExtractorV2(nn.Module):

    def __init__(self,
                 num_input_features=4,
                 use_norm=True,
                 num_filters=[32, 128],
                 with_distance=False,
                 name='VoxelFeatureExtractor'):
        super(VoxelFeatureExtractorV2, self).__init__()
        self.name = name
        assert len(num_filters) > 0
        num_input_features += 3
        if with_distance:
            num_input_features += 1
        self._with_distance = with_distance
        num_filters = [num_input_features] + num_filters
        filters_pairs = [[num_filters[i], num_filters[i + 1]]
                         for i in range(len(num_filters) - 1)]
        self.vfe_layers = nn.ModuleList(
            [VFELayer(i, o, use_norm) for i, o in filters_pairs])
        if use_norm:
            self.linear = nn.Linear(
                num_filters[-1], num_filters[-1], bias=False)
            self.norm = nn.BatchNorm1d(
                num_filters[-1], eps=1e-3, momentum=0.01)
        else:
            self.linear = nn.Linear(
                num_filters[-1], num_filters[-1], bias=True)
            self.norm = Empty(num_filters[-1])

    def forward(self, features, num_voxels, **kwargs):
        # features: [concated_num_points, num_voxel_size, 3(4)]
        # num_voxels: [concated_num_points]
        points_mean = features[:, :, :3].sum(
            dim=1, keepdim=True) / num_voxels.type_as(features).view(-1, 1, 1)
        features_relative = features[:, :, :3] - points_mean
        if self._with_distance:
            points_dist = torch.norm(features[:, :, :3], 2, 2, keepdim=True)
            features = torch.cat([features, features_relative, points_dist],
                                 dim=-1)
        else:
            features = torch.cat([features, features_relative], dim=-1)
        voxel_count = features.shape[1]
        mask = get_paddings_indicator(num_voxels, voxel_count, axis=0)
        mask = torch.unsqueeze(mask, -1).type_as(features)
        for vfe in self.vfe_layers:
            features = vfe(features)
            features *= mask
        features = self.linear(features)
        features = self.norm(features.permute(0, 2, 1).contiguous()).permute(
            0, 2, 1).contiguous()
        features = F.relu(features)
        features *= mask
        # features: [concated_num_points, num_voxel_size, num_filters[-1]]
        voxelwise = torch.max(features, dim=1)[0]
        return voxelwise

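
# A minimal usage sketch for the two hard-voxel extractors above.  The shapes
# and filter sizes are illustrative assumptions, not values from a real config,
# and the sketch assumes the VFELayer in .utils still accepts the legacy
# (in_channels, out_channels, use_norm) call used by these classes:
#
#   >>> vfe = VoxelFeatureExtractor(num_input_features=4, num_filters=[32, 128])
#   >>> features = torch.rand(100, 35, 4)            # 100 voxels, <=35 points
#   >>> num_voxels = torch.randint(1, 35, (100,))    # valid points per voxel
#   >>> out = vfe(features, num_voxels)              # -> (100, 128)
#   >>> # VoxelFeatureExtractorV2 takes the same inputs but stacks an
#   >>> # arbitrary number of VFE layers according to num_filters.
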
@VOXEL_ENCODERS.register_module
class VoxelFeatureExtractorV3(nn.Module):

    def __init__(self,
                 num_input_features=4,
                 use_norm=True,
                 num_filters=[32, 128],
                 with_distance=False,
                 name='VoxelFeatureExtractor'):
        super(VoxelFeatureExtractorV3, self).__init__()
        self.name = name

    def forward(self, features, num_points, coors):
        # features: [concated_num_points, num_voxel_size, 3(4)]
        # num_points: [concated_num_points]
        points_mean = features[:, :, :4].sum(
            dim=1, keepdim=False) / num_points.type_as(features).view(-1, 1)
        return points_mean.contiguous()


@VOXEL_ENCODERS.register_module
class DynamicVFEV3(nn.Module):

    def __init__(self,
                 num_input_features=4,
                 voxel_size=(0.2, 0.2, 4),
                 point_cloud_range=(0, -40, -3, 70.4, 40, 1)):
        super(DynamicVFEV3, self).__init__()
        self.scatter = DynamicScatter(voxel_size, point_cloud_range, True)

    @torch.no_grad()
    def forward(self, features, coors):
        # Called at the very beginning of the VoxelNet forward pass.
        # features: [concated_num_points, num_features]
        # coors: [concated_num_points, 1 + ndim] (batch index first)
        features, features_coors = self.scatter(features, coors)
        return features, features_coors

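
# Minimal sketch of how the two "mean" encoders above are fed.  The shapes and
# the voxel grid below are assumptions for illustration only, and the dynamic
# path needs the compiled DynamicScatter op (CUDA tensors):
#
#   >>> # Hard voxelization: padded voxels plus a per-voxel point count.
#   >>> v3 = VoxelFeatureExtractorV3()
#   >>> feats = v3(torch.rand(100, 35, 4), torch.randint(1, 35, (100,)), None)
#   >>> feats.shape                                  # (100, 4) per-voxel mean
#
#   >>> # Dynamic voxelization: per-point features plus integer coordinates
#   >>> # laid out as (batch_idx, z, y, x).
#   >>> dv3 = DynamicVFEV3(voxel_size=(0.2, 0.2, 4),
#   ...                    point_cloud_range=(0, -40, -3, 70.4, 40, 1))
#   >>> pts = torch.rand(1000, 4)
#   >>> coors = ...   # (1000, 4) tensor produced by the dynamic voxel layer
#   >>> voxel_feats, voxel_coors = dv3(pts, coors)
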
@VOXEL_ENCODERS.register_module
class DynamicVFE(nn.Module):

    def __init__(self,
                 num_input_features=4,
                 num_filters=[],
                 with_distance=False,
                 with_cluster_center=False,
                 with_voxel_center=False,
                 voxel_size=(0.2, 0.2, 4),
                 point_cloud_range=(0, -40, -3, 70.4, 40, 1),
                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
                 mode='max',
                 fusion_layer=None,
                 return_point_feats=False):
        super(DynamicVFE, self).__init__()
        assert len(num_filters) > 0
        if with_cluster_center:
            num_input_features += 3
        if with_voxel_center:
            num_input_features += 3
        if with_distance:
            num_input_features += 1  # the distance decoration is a single scalar
        self.num_input_features = num_input_features
        self._with_distance = with_distance
        self._with_cluster_center = with_cluster_center
        self._with_voxel_center = with_voxel_center
        self.return_point_feats = return_point_feats

        # Need pillar (voxel) size and x/y offset in order to calculate offset
        self.vx = voxel_size[0]
        self.vy = voxel_size[1]
        self.vz = voxel_size[2]
        self.x_offset = self.vx / 2 + point_cloud_range[0]
        self.y_offset = self.vy / 2 + point_cloud_range[1]
        self.z_offset = self.vz / 2 + point_cloud_range[2]
        self.point_cloud_range = point_cloud_range
        self.scatter = DynamicScatter(voxel_size, point_cloud_range, True)

        num_filters = [self.num_input_features] + list(num_filters)
        vfe_layers = []
        for i in range(len(num_filters) - 1):
            in_filters = num_filters[i]
            out_filters = num_filters[i + 1]
            if i > 0:
                in_filters *= 2
            norm_name, norm_layer = build_norm_layer(norm_cfg, out_filters)
            vfe_layers.append(
                nn.Sequential(
                    nn.Linear(in_filters, out_filters, bias=False), norm_layer,
                    nn.ReLU(inplace=True)))
        self.vfe_layers = nn.ModuleList(vfe_layers)
        self.num_vfe = len(vfe_layers)
        self.vfe_scatter = DynamicScatter(voxel_size, point_cloud_range,
                                          (mode != 'max'))
        self.cluster_scatter = DynamicScatter(
            voxel_size, point_cloud_range, average_points=True)
        self.fusion_layer = None
        if fusion_layer is not None:
            self.fusion_layer = builder.build_fusion_layer(fusion_layer)

    def map_voxel_center_to_point(self, pts_coors, voxel_mean, voxel_coors):
        # Step 1: scatter voxel into canvas
        # Calculate necessary things for canvas creation
        canvas_z = int(
            (self.point_cloud_range[5] - self.point_cloud_range[2]) / self.vz)
        canvas_y = int(
            (self.point_cloud_range[4] - self.point_cloud_range[1]) / self.vy)
        canvas_x = int(
            (self.point_cloud_range[3] - self.point_cloud_range[0]) / self.vx)
        batch_size = pts_coors[-1, 0] + 1
        canvas_len = canvas_z * canvas_y * canvas_x * batch_size
        # Create the canvas for this sample
        canvas = voxel_mean.new_zeros(canvas_len, dtype=torch.long)
        # Only include non-empty pillars
        indices = (
            voxel_coors[:, 0] * canvas_z * canvas_y * canvas_x +
            voxel_coors[:, 1] * canvas_y * canvas_x +
            voxel_coors[:, 2] * canvas_x + voxel_coors[:, 3])
        # Scatter the blob back to the canvas
        canvas[indices.long()] = torch.arange(
            start=0, end=voxel_mean.size(0), device=voxel_mean.device)

        # Step 2: get voxel mean for each point
        voxel_index = (
            pts_coors[:, 0] * canvas_z * canvas_y * canvas_x +
            pts_coors[:, 1] * canvas_y * canvas_x +
            pts_coors[:, 2] * canvas_x + pts_coors[:, 3])
        voxel_inds = canvas[voxel_index.long()]
        center_per_point = voxel_mean[voxel_inds, ...]
        return center_per_point

    def forward(self,
                features,
                coors,
                points=None,
                img_feats=None,
                img_meta=None):
        """
        Args:
            features (torch.Tensor): NxC point features.
            coors (torch.Tensor): Nx(1+NDim) voxel coordinates
                (batch index first).
        """
        features_ls = [features]
        # Find distance of x, y, and z from cluster center
        if self._with_cluster_center:
            voxel_mean, mean_coors = self.cluster_scatter(features, coors)
            points_mean = self.map_voxel_center_to_point(
                coors, voxel_mean, mean_coors)
            # TODO: maybe also do cluster for reflectivity
            f_cluster = features[:, :3] - points_mean[:, :3]
            features_ls.append(f_cluster)

        # Find distance of x, y, and z from pillar center
        if self._with_voxel_center:
            f_center = features.new_zeros(size=(features.size(0), 3))
            f_center[:, 0] = features[:, 0] - (
                coors[:, 3].type_as(features) * self.vx + self.x_offset)
            f_center[:, 1] = features[:, 1] - (
                coors[:, 2].type_as(features) * self.vy + self.y_offset)
            f_center[:, 2] = features[:, 2] - (
                coors[:, 1].type_as(features) * self.vz + self.z_offset)
            features_ls.append(f_center)

        if self._with_distance:
            points_dist = torch.norm(features[:, :3], 2, 1, keepdim=True)
            features_ls.append(points_dist)

        # Combine together feature decorations
        features = torch.cat(features_ls, dim=-1)
        for i, vfe in enumerate(self.vfe_layers):
            point_feats = vfe(features)
            if (i == len(self.vfe_layers) - 1
                    and self.fusion_layer is not None
                    and img_feats is not None):
                point_feats = self.fusion_layer(img_feats, points, point_feats,
                                                img_meta)
            voxel_feats, voxel_coors = self.vfe_scatter(point_feats, coors)
            if i != len(self.vfe_layers) - 1:
                # need to concat voxel feats if it is not the last vfe
                feat_per_point = self.map_voxel_center_to_point(
                    coors, voxel_feats, voxel_coors)
                features = torch.cat([point_feats, feat_per_point], dim=1)

        if self.return_point_feats:
            return point_feats
        return voxel_feats, voxel_coors

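
# `map_voxel_center_to_point` above broadcasts per-voxel statistics back to the
# points by scattering each voxel's row index onto a flattened canvas of size
# batch * Z * Y * X and then gathering with the same flat index computed from
# the point coordinates.  A toy illustration with made-up numbers (not taken
# from any real config):
#
#   >>> Z, Y, X = 1, 2, 2                            # canvas of 4 cells
#   >>> voxel_coors = torch.tensor([[0, 0, 1, 1]])   # batch 0, z=0, y=1, x=1
#   >>> canvas = torch.zeros(1 * Z * Y * X, dtype=torch.long)
#   >>> flat = (voxel_coors[:, 0] * Z * Y * X + voxel_coors[:, 1] * Y * X +
#   ...         voxel_coors[:, 2] * X + voxel_coors[:, 3])
#   >>> canvas[flat] = torch.arange(voxel_coors.size(0))
#   >>> # any point whose coors map to the same flat index recovers row 0 of
#   >>> # the voxel-wise tensor (e.g. voxel_mean) via voxel_mean[canvas[flat]]
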
@VOXEL_ENCODERS.register_module
class HardVFE(nn.Module):

    def __init__(self,
                 num_input_features=4,
                 num_filters=[],
                 with_distance=False,
                 with_cluster_center=False,
                 with_voxel_center=False,
                 voxel_size=(0.2, 0.2, 4),
                 point_cloud_range=(0, -40, -3, 70.4, 40, 1),
                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
                 mode='max',
                 fusion_layer=None,
                 return_point_feats=False):
        super(HardVFE, self).__init__()
        assert len(num_filters) > 0
        if with_cluster_center:
            num_input_features += 3
        if with_voxel_center:
            num_input_features += 3
        if with_distance:
            num_input_features += 1  # the distance decoration is a single scalar
        self.num_input_features = num_input_features
        self._with_distance = with_distance
        self._with_cluster_center = with_cluster_center
        self._with_voxel_center = with_voxel_center
        self.return_point_feats = return_point_feats

        # Need pillar (voxel) size and x/y offset to calculate pillar offset
        self.vx = voxel_size[0]
        self.vy = voxel_size[1]
        self.vz = voxel_size[2]
        self.x_offset = self.vx / 2 + point_cloud_range[0]
        self.y_offset = self.vy / 2 + point_cloud_range[1]
        self.z_offset = self.vz / 2 + point_cloud_range[2]
        self.point_cloud_range = point_cloud_range
        self.scatter = DynamicScatter(voxel_size, point_cloud_range, True)

        num_filters = [self.num_input_features] + list(num_filters)
        vfe_layers = []
        for i in range(len(num_filters) - 1):
            in_filters = num_filters[i]
            out_filters = num_filters[i + 1]
            if i > 0:
                in_filters *= 2
            # TODO: pass norm_cfg to VFE
            # norm_name, norm_layer = build_norm_layer(norm_cfg, out_filters)
            if i == (len(num_filters) - 2):
                # The last VFE layer keeps per-point features (no cat_max) so
                # that an optional image fusion layer can consume them.
                cat_max = False
                max_out = True
                if fusion_layer:
                    max_out = False
            else:
                max_out = True
                cat_max = True
            vfe_layers.append(
                VFELayer(
                    in_filters,
                    out_filters,
                    norm_cfg=norm_cfg,
                    max_out=max_out,
                    cat_max=cat_max))
        self.vfe_layers = nn.ModuleList(vfe_layers)
        self.num_vfe = len(vfe_layers)

        self.fusion_layer = None
        if fusion_layer is not None:
            self.fusion_layer = builder.build_fusion_layer(fusion_layer)

    def forward(self,
                features,
                num_points,
                coors,
                img_feats=None,
                img_meta=None):
        """
        Args:
            features (torch.Tensor): NxMxC features of the padded voxels.
            num_points (torch.Tensor): N numbers of valid points per voxel.
            coors (torch.Tensor): Nx(1+NDim) voxel coordinates
                (batch index first).
        """
        features_ls = [features]
        # Find distance of x, y, and z from cluster center
        if self._with_cluster_center:
            points_mean = (
                features[:, :, :3].sum(dim=1, keepdim=True) /
                num_points.type_as(features).view(-1, 1, 1))
            # TODO: maybe also do cluster for reflectivity
            f_cluster = features[:, :, :3] - points_mean
            features_ls.append(f_cluster)

        # Find distance of x, y, and z from pillar center
        if self._with_voxel_center:
            f_center = features.new_zeros(
                size=(features.size(0), features.size(1), 3))
            f_center[:, :, 0] = features[:, :, 0] - (
                coors[:, 3].type_as(features).unsqueeze(1) * self.vx +
                self.x_offset)
            f_center[:, :, 1] = features[:, :, 1] - (
                coors[:, 2].type_as(features).unsqueeze(1) * self.vy +
                self.y_offset)
            f_center[:, :, 2] = features[:, :, 2] - (
                coors[:, 1].type_as(features).unsqueeze(1) * self.vz +
                self.z_offset)
            features_ls.append(f_center)

        if self._with_distance:
            points_dist = torch.norm(features[:, :, :3], 2, 2, keepdim=True)
            features_ls.append(points_dist)

        # Combine together feature decorations
        voxel_feats = torch.cat(features_ls, dim=-1)
        # The feature decorations were calculated without regard to whether
        # the pillar was empty. Need to ensure that empty voxels remain set
        # to zeros.
        voxel_count = voxel_feats.shape[1]
        mask = get_paddings_indicator(num_points, voxel_count, axis=0)
        voxel_feats *= mask.unsqueeze(-1).type_as(voxel_feats)

        for i, vfe in enumerate(self.vfe_layers):
            voxel_feats = vfe(voxel_feats)

        if (self.fusion_layer is not None and img_feats is not None):
            voxel_feats = self.fusion_with_mask(features, mask, voxel_feats,
                                                coors, img_feats, img_meta)

        return voxel_feats

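    # `fusion_with_mask` below feeds only the valid (non-padded) points of
    # each sample to the image-point fusion layer and then scatters the fused
    # features back onto a zero canvas before the per-voxel max.  Rough shape
    # walk-through (letters are placeholders, not fixed sizes):
    #   voxel_feats: (V, M, C), mask: (V, M) bool
    #   point_feats = voxel_feats[mask]                  -> (P, C)
    #   fused = fusion_layer(img_feats, points, ...)     -> (P, C_out)
    #   canvas[mask] = fused; canvas.max(dim=1)[0]       -> (V, C_out)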
    def fusion_with_mask(self, features, mask, voxel_feats, coors, img_feats,
                         img_meta):
        # `features` holds the raw points of the whole batch; split them back
        # into per-sample lists of valid points for the fusion layer.
        batch_size = coors[-1, 0] + 1
        points = []
        for i in range(batch_size):
            single_mask = (coors[:, 0] == i)
            points.append(features[single_mask][mask[single_mask]])

        point_feats = voxel_feats[mask]
        point_feats = self.fusion_layer(img_feats, points, point_feats,
                                        img_meta)

        # Scatter the fused point features back onto a zero canvas and take
        # the per-voxel max.
        voxel_canvas = voxel_feats.new_zeros(
            size=(voxel_feats.size(0), voxel_feats.size(1),
                  point_feats.size(-1)))
        voxel_canvas[mask] = point_feats
        out = torch.max(voxel_canvas, dim=1)[0]
        return out
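
# A usage sketch for HardVFE as a PointPillars-style encoder.  All shapes,
# filter sizes and ranges below are assumptions, not values from a real config:
#
#   >>> vfe = HardVFE(num_input_features=4,
#   ...               num_filters=[64],
#   ...               with_cluster_center=True,
#   ...               with_voxel_center=True,
#   ...               voxel_size=(0.16, 0.16, 4),
#   ...               point_cloud_range=(0, -39.68, -3, 69.12, 39.68, 1))
#   >>> features = torch.rand(100, 32, 4)              # 100 pillars, <=32 points
#   >>> num_points = torch.randint(1, 32, (100,))
#   >>> coors = torch.zeros(100, 4, dtype=torch.long)  # (batch_idx, z, y, x)
#   >>> out = vfe(features, num_points, coors)         # per-pillar features of
#   >>> # width num_filters[-1]; with a fusion layer, per-point features are
#   >>> # fused with image features before the final per-pillar max.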