Commit fa39f1c5 authored by Shaoshuai Shi
Browse files

Support sectorized-proposal-centric (SPC) keypoint sampling

parent 8922371e
import math
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -40,6 +42,85 @@ def bilinear_interpolate_torch(im, x, y): ...@@ -40,6 +42,85 @@ def bilinear_interpolate_torch(im, x, y):
return ans return ans
def sample_points_with_roi(rois, points, sample_radius_with_roi, num_max_points_of_part=200000):
    """
    Keep the points that lie close to their nearest RoI.

    A point is kept when its distance to the nearest RoI center (rois[:, 0:3]) is
    smaller than the half-diagonal of that RoI (norm of rois[:, 3:6] / 2) plus
    sample_radius_with_roi.

    Args:
        rois: (M, 7 + C)
        points: (N, 3)
        sample_radius_with_roi: margin added to the half-diagonal of the nearest RoI
        num_max_points_of_part: maximum number of points handled per chunk, which
            bounds the size of the pairwise (chunk, M, 3) temporary. None means
            all points are processed in a single chunk.
    Returns:
        sampled_points: (N_out, 3), or points[:1] when no point passes the test
        point_mask: (N,) bool mask over points (all False in the fallback case)
    Raises:
        ValueError: if num_max_points_of_part is given but not positive
    """
    num_points = points.shape[0]
    if num_max_points_of_part is None:
        chunk_size = max(num_points, 1)
    elif num_max_points_of_part <= 0:
        # A non-positive step would never advance through the points.
        raise ValueError(
            f'num_max_points_of_part must be positive or None, got {num_max_points_of_part}'
        )
    else:
        chunk_size = num_max_points_of_part

    if num_points == 0 or rois.shape[0] == 0:
        # Nothing can be "near a RoI"; min() over an empty RoI axis would raise.
        point_mask = torch.zeros(num_points, dtype=torch.bool, device=points.device)
    else:
        roi_centers = rois[None, :, 0:3]
        roi_half_diagonals = (rois[:, 3:6] / 2).norm(dim=-1)  # (M,)
        mask_chunks = []
        for start_idx in range(0, num_points, chunk_size):
            cur_points = points[start_idx:start_idx + chunk_size, None, :]
            distance = (cur_points - roi_centers).norm(dim=-1)  # (chunk, M)
            min_dis, min_dis_roi_idx = distance.min(dim=-1)
            mask_chunks.append(
                min_dis < roi_half_diagonals[min_dis_roi_idx] + sample_radius_with_roi
            )
        point_mask = torch.cat(mask_chunks, dim=0)

    # Fall back to the first point so that callers never receive an empty set
    # just because no point was close enough to a RoI.
    sampled_points = points[point_mask, :] if point_mask.any() else points[:1]
    return sampled_points, point_mask
def sector_fps(points, num_sampled_points, num_sectors):
    """
    Farthest point sampling performed independently inside azimuthal sectors.

    Points are binned into num_sectors equal angular sectors around the origin
    in the x-y plane. Every non-empty sector receives a sampling quota
    proportional to its share of the points, rounded up and capped by the
    number of points it holds, so N_out is not guaranteed to equal
    num_sampled_points.

    Args:
        points: (N, 3)
        num_sampled_points: int
        num_sectors: int
    Returns:
        sampled_points: (N_out, 3)
    """
    sector_size = np.pi * 2 / num_sectors
    # atan2 lies in [-pi, pi], so the shifted angle lies in [0, 2 * pi].
    point_angles = torch.atan2(points[:, 1], points[:, 0]) + np.pi
    # An angle of exactly 2 * pi floors to num_sectors, which none of the
    # range(num_sectors) iterations below would match; clamp it into the last
    # valid sector so that such points are not silently dropped.
    sector_idx = (point_angles / sector_size).floor().clamp(min=0, max=num_sectors - 1)

    xyz_points_list = []
    xyz_batch_cnt = []
    num_sampled_points_list = []
    for k in range(num_sectors):
        mask = (sector_idx == k)
        cur_num_points = mask.sum().item()
        if cur_num_points > 0:
            xyz_points_list.append(points[mask])
            xyz_batch_cnt.append(cur_num_points)
            ratio = cur_num_points / points.shape[0]
            num_sampled_points_list.append(
                min(cur_num_points, math.ceil(ratio * num_sampled_points))
            )

    if len(xyz_batch_cnt) == 0:
        # No point landed in any sector: sample from all points as one group.
        xyz_points_list.append(points)
        xyz_batch_cnt.append(len(points))
        num_sampled_points_list.append(num_sampled_points)
        print(f'Warning: empty sector points detected in SectorFPS: points.shape={points.shape}')

    # Sectors are stacked into one tensor; the per-sector counts tell the
    # stacked FPS op where each group starts and how many points to draw.
    xyz = torch.cat(xyz_points_list, dim=0)
    xyz_batch_cnt = torch.tensor(xyz_batch_cnt, device=points.device).int()
    sampled_points_batch_cnt = torch.tensor(num_sampled_points_list, device=points.device).int()

    sampled_pt_idxs = pointnet2_stack_utils.stack_farthest_point_sample(
        xyz.contiguous(), xyz_batch_cnt, sampled_points_batch_cnt
    ).long()

    # The returned indices are used to index the stacked tensor, not `points`.
    sampled_points = xyz[sampled_pt_idxs]
    return sampled_points
class VoxelSetAbstraction(nn.Module): class VoxelSetAbstraction(nn.Module):
def __init__(self, model_cfg, voxel_size, point_cloud_range, num_bev_features=None, def __init__(self, model_cfg, voxel_size, point_cloud_range, num_bev_features=None,
num_rawpoint_features=None, **kwargs): num_rawpoint_features=None, **kwargs):
...@@ -100,23 +181,64 @@ class VoxelSetAbstraction(nn.Module): ...@@ -100,23 +181,64 @@ class VoxelSetAbstraction(nn.Module):
self.num_point_features_before_fusion = c_in self.num_point_features_before_fusion = c_in
def interpolate_from_bev_features(self, keypoints, bev_features, batch_size, bev_stride): def interpolate_from_bev_features(self, keypoints, bev_features, batch_size, bev_stride):
x_idxs = (keypoints[:, :, 0] - self.point_cloud_range[0]) / self.voxel_size[0] """
y_idxs = (keypoints[:, :, 1] - self.point_cloud_range[1]) / self.voxel_size[1] Args:
keypoints: (N1 + N2 + ..., 4)
bev_features: (B, C, H, W)
batch_size:
bev_stride:
Returns:
point_bev_features: (N1 + N2 + ..., C)
"""
x_idxs = (keypoints[:, 1] - self.point_cloud_range[0]) / self.voxel_size[0]
y_idxs = (keypoints[:, 2] - self.point_cloud_range[1]) / self.voxel_size[1]
x_idxs = x_idxs / bev_stride x_idxs = x_idxs / bev_stride
y_idxs = y_idxs / bev_stride y_idxs = y_idxs / bev_stride
point_bev_features_list = [] point_bev_features_list = []
for k in range(batch_size): for k in range(batch_size):
cur_x_idxs = x_idxs[k] bs_mask = (keypoints[:, 0] == k)
cur_y_idxs = y_idxs[k]
cur_x_idxs = x_idxs[bs_mask]
cur_y_idxs = y_idxs[bs_mask]
cur_bev_features = bev_features[k].permute(1, 2, 0) # (H, W, C) cur_bev_features = bev_features[k].permute(1, 2, 0) # (H, W, C)
point_bev_features = bilinear_interpolate_torch(cur_bev_features, cur_x_idxs, cur_y_idxs) point_bev_features = bilinear_interpolate_torch(cur_bev_features, cur_x_idxs, cur_y_idxs)
point_bev_features_list.append(point_bev_features.unsqueeze(dim=0)) point_bev_features_list.append(point_bev_features)
point_bev_features = torch.cat(point_bev_features_list, dim=0) # (B, N, C0) point_bev_features = torch.cat(point_bev_features_list, dim=0) # (N1 + N2 + ..., C)
return point_bev_features return point_bev_features
def sectorized_proposal_centric_sampling(self, roi_boxes, points):
    """
    Sample keypoints by first restricting the points to the neighborhood of the
    RoIs and then running sector-wise farthest point sampling on what remains.

    Args:
        roi_boxes: (M, 7 + C)
        points: (N, 3)
    Returns:
        sampled_points: (N_out, 3)
    """
    # Step 1: proposal-centric filtering, configured by model_cfg.SPC.
    points_near_rois, _ = sample_points_with_roi(
        rois=roi_boxes,
        points=points,
        sample_radius_with_roi=self.model_cfg.SPC.SAMPLE_RADIUS_WITH_ROI,
        num_max_points_of_part=self.model_cfg.SPC.get('NUM_POINTS_OF_EACH_SAMPLE_PART', 200000),
    )
    # Step 2: sectorized FPS over the filtered points.
    return sector_fps(
        points=points_near_rois,
        num_sampled_points=self.model_cfg.NUM_KEYPOINTS,
        num_sectors=self.model_cfg.SPC.NUM_SECTORS,
    )
def get_sampled_points(self, batch_dict): def get_sampled_points(self, batch_dict):
"""
Args:
batch_dict:
Returns:
keypoints: (N1 + N2 + ..., 4), where 4 indicates [bs_idx, x, y, z]
"""
batch_size = batch_dict['batch_size'] batch_size = batch_dict['batch_size']
if self.model_cfg.POINT_SOURCE == 'raw_points': if self.model_cfg.POINT_SOURCE == 'raw_points':
src_points = batch_dict['points'][:, 1:4] src_points = batch_dict['points'][:, 1:4]
...@@ -147,16 +269,75 @@ class VoxelSetAbstraction(nn.Module): ...@@ -147,16 +269,75 @@ class VoxelSetAbstraction(nn.Module):
keypoints = sampled_points[0][cur_pt_idxs[0]].unsqueeze(dim=0) keypoints = sampled_points[0][cur_pt_idxs[0]].unsqueeze(dim=0)
elif self.model_cfg.SAMPLE_METHOD == 'FastFPS': elif self.model_cfg.SAMPLE_METHOD == 'SPC':
raise NotImplementedError cur_keypoints = self.sectorized_proposal_centric_sampling(
roi_boxes=batch_dict['rois'][bs_idx], points=sampled_points
)
bs_idxs = cur_keypoints.new_ones(cur_keypoints.shape[0]) * bs_idx
keypoints = torch.cat((bs_idxs[:, None], cur_keypoints), dim=1)
else: else:
raise NotImplementedError raise NotImplementedError
keypoints_list.append(keypoints) keypoints_list.append(keypoints)
keypoints = torch.cat(keypoints_list, dim=0) # (B, M, 3) keypoints = torch.cat(keypoints_list, dim=0) # (B, M, 3) or (N1 + N2 + ..., 4)
if len(keypoints.shape) == 3:
batch_idx = torch.arange(batch_size, device=keypoints.device).view(-1, 1).repeat(1, keypoints.shape[1]).view(-1, 1)
keypoints = torch.cat((batch_idx.float(), keypoints.view(-1, 3)), dim=1)
return keypoints return keypoints
@staticmethod
def aggregate_keypoint_features_from_one_source(
        batch_size, aggregate_func, xyz, xyz_features, xyz_bs_idxs, new_xyz, new_xyz_batch_cnt,
        filter_neighbors_with_roi=False, radius_of_neighbor=None, num_max_points_of_part=None, rois=None
):
    """
    Pool features from one source (xyz / xyz_features) onto the keypoints new_xyz.

    When filter_neighbors_with_roi is set, the source points of every sample are
    first reduced to those selected by sample_points_with_roi for that sample's
    RoIs, and only the surviving points are handed to aggregate_func.

    Args:
        batch_size: int
        aggregate_func: callable taking xyz, xyz_batch_cnt, new_xyz,
            new_xyz_batch_cnt and features as keyword arguments and returning
            (pooled_points, pooled_features)
        xyz: (N, 3)
        xyz_features: (N, C), or None
        xyz_bs_idxs: (N)
        new_xyz: (M, 3)
        new_xyz_batch_cnt: (batch_size), [N1, N2, ...]
        filter_neighbors_with_roi: True/False
        radius_of_neighbor: float
        num_max_points_of_part: int, or None to use the default chunk size of
            sample_points_with_roi
        rois: (batch_size, num_rois, 7 + C)
    Returns:
        pooled_features: second element returned by aggregate_func
    """
    xyz_batch_cnt = xyz.new_zeros(batch_size).int()
    if filter_neighbors_with_roi:
        # Join along the channel axis so that each row is [x, y, z, features...];
        # the rows are later selected by per-point masks and split by column.
        point_features = torch.cat((xyz, xyz_features), dim=-1) if xyz_features is not None else xyz

        # Forward the chunk size only when one was given: sample_points_with_roi
        # compares it against the point count and cannot handle None.
        roi_filter_kwargs = {}
        if num_max_points_of_part is not None:
            roi_filter_kwargs['num_max_points_of_part'] = num_max_points_of_part

        point_features_list = []
        for bs_idx in range(batch_size):
            bs_mask = (xyz_bs_idxs == bs_idx)
            _, valid_mask = sample_points_with_roi(
                rois=rois[bs_idx], points=xyz[bs_mask],
                sample_radius_with_roi=radius_of_neighbor, **roi_filter_kwargs
            )
            point_features_list.append(point_features[bs_mask][valid_mask])
            xyz_batch_cnt[bs_idx] = valid_mask.sum()

        valid_point_features = torch.cat(point_features_list, dim=0)
        xyz = valid_point_features[:, 0:3]
        xyz_features = valid_point_features[:, 3:] if xyz_features is not None else None
    else:
        for bs_idx in range(batch_size):
            xyz_batch_cnt[bs_idx] = (xyz_bs_idxs == bs_idx).sum()

    pooled_points, pooled_features = aggregate_func(
        xyz=xyz.contiguous(),
        xyz_batch_cnt=xyz_batch_cnt,
        new_xyz=new_xyz,
        new_xyz_batch_cnt=new_xyz_batch_cnt,
        features=xyz_features,
    )
    return pooled_features
def forward(self, batch_dict): def forward(self, batch_dict):
""" """
Args: Args:
...@@ -185,56 +366,53 @@ class VoxelSetAbstraction(nn.Module): ...@@ -185,56 +366,53 @@ class VoxelSetAbstraction(nn.Module):
) )
point_features_list.append(point_bev_features) point_features_list.append(point_bev_features)
batch_size, num_keypoints, _ = keypoints.shape batch_size = batch_dict['batch_size']
new_xyz = keypoints.view(-1, 3)
new_xyz_batch_cnt = new_xyz.new_zeros(batch_size).int().fill_(num_keypoints) new_xyz = keypoints[:, 1:4].contiguous()
new_xyz_batch_cnt = new_xyz.new_zeros(batch_size).int()
for k in range(batch_size):
new_xyz_batch_cnt[k] = (keypoints[:, 0] == k).sum()
if 'raw_points' in self.model_cfg.FEATURES_SOURCE: if 'raw_points' in self.model_cfg.FEATURES_SOURCE:
raw_points = batch_dict['points'] raw_points = batch_dict['points']
xyz = raw_points[:, 1:4]
xyz_batch_cnt = xyz.new_zeros(batch_size).int() pooled_features = self.aggregate_keypoint_features_from_one_source(
for bs_idx in range(batch_size): batch_size=batch_size, aggregate_func=self.SA_rawpoints,
xyz_batch_cnt[bs_idx] = (raw_points[:, 0] == bs_idx).sum() xyz=raw_points[:, 1:4],
point_features = raw_points[:, 4:].contiguous() if raw_points.shape[1] > 4 else None xyz_features=raw_points[:, 4:].contiguous() if raw_points.shape[1] > 4 else None,
xyz_bs_idxs=raw_points[:, 0],
pooled_points, pooled_features = self.SA_rawpoints( new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt,
xyz=xyz.contiguous(), filter_neighbors_with_roi=self.model_cfg.SA_LAYER['raw_points'].get('FILTER_NEIGHBOR_WITH_ROI', False),
xyz_batch_cnt=xyz_batch_cnt, radius_of_neighbor=self.model_cfg.SA_LAYER['raw_points'].get('RADIUS_OF_NEIGHBOR_WITH_ROI', None),
new_xyz=new_xyz, rois=batch_dict.get('rois', None)
new_xyz_batch_cnt=new_xyz_batch_cnt,
features=point_features,
) )
point_features_list.append(pooled_features.view(batch_size, num_keypoints, -1)) point_features_list.append(pooled_features)
for k, src_name in enumerate(self.SA_layer_names): for k, src_name in enumerate(self.SA_layer_names):
cur_coords = batch_dict['multi_scale_3d_features'][src_name].indices cur_coords = batch_dict['multi_scale_3d_features'][src_name].indices
cur_features = batch_dict['multi_scale_3d_features'][src_name].features.contiguous()
xyz = common_utils.get_voxel_centers( xyz = common_utils.get_voxel_centers(
cur_coords[:, 1:4], cur_coords[:, 1:4], downsample_times=self.downsample_times_map[src_name],
downsample_times=self.downsample_times_map[src_name], voxel_size=self.voxel_size, point_cloud_range=self.point_cloud_range
voxel_size=self.voxel_size,
point_cloud_range=self.point_cloud_range
) )
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for bs_idx in range(batch_size): pooled_features = self.aggregate_keypoint_features_from_one_source(
xyz_batch_cnt[bs_idx] = (cur_coords[:, 0] == bs_idx).sum() batch_size=batch_size, aggregate_func=self.SA_layers[k],
xyz=xyz.contiguous(), xyz_features=cur_features, xyz_bs_idxs=cur_coords[:, 0],
pooled_points, pooled_features = self.SA_layers[k]( new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt,
xyz=xyz.contiguous(), filter_neighbors_with_roi=self.model_cfg.SA_LAYER[src_name].get('FILTER_NEIGHBOR_WITH_ROI', False),
xyz_batch_cnt=xyz_batch_cnt, radius_of_neighbor=self.model_cfg.SA_LAYER[src_name].get('RADIUS_OF_NEIGHBOR_WITH_ROI', None),
new_xyz=new_xyz, rois=batch_dict.get('rois', None)
new_xyz_batch_cnt=new_xyz_batch_cnt,
features=batch_dict['multi_scale_3d_features'][src_name].features.contiguous(),
) )
point_features_list.append(pooled_features.view(batch_size, num_keypoints, -1))
point_features = torch.cat(point_features_list, dim=2) point_features_list.append(pooled_features)
batch_idx = torch.arange(batch_size, device=keypoints.device).view(-1, 1).repeat(1, keypoints.shape[1]).view(-1) point_features = torch.cat(point_features_list, dim=-1)
point_coords = torch.cat((batch_idx.view(-1, 1).float(), keypoints.view(-1, 3)), dim=1)
batch_dict['point_features_before_fusion'] = point_features.view(-1, point_features.shape[-1]) batch_dict['point_features_before_fusion'] = point_features.view(-1, point_features.shape[-1])
point_features = self.vsa_point_feature_fusion(point_features.view(-1, point_features.shape[-1])) point_features = self.vsa_point_feature_fusion(point_features.view(-1, point_features.shape[-1]))
batch_dict['point_features'] = point_features # (BxN, C) batch_dict['point_features'] = point_features # (BxN, C)
batch_dict['point_coords'] = point_coords # (BxN, 4) batch_dict['point_coords'] = keypoints # (BxN, 4)
return batch_dict return batch_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment