Commit 13789796 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

Support PV-RCNN++ frameworks, support VectorPool aggregation

parent 183d353a
......@@ -139,38 +139,31 @@ class VoxelSetAbstraction(nn.Module):
if src_name in ['bev', 'raw_points']:
continue
self.downsample_times_map[src_name] = SA_cfg[src_name].DOWNSAMPLE_FACTOR
mlps = SA_cfg[src_name].MLPS
for k in range(len(mlps)):
mlps[k] = [mlps[k][0]] + mlps[k]
cur_layer = pointnet2_stack_modules.StackSAModuleMSG(
radii=SA_cfg[src_name].POOL_RADIUS,
nsamples=SA_cfg[src_name].NSAMPLE,
mlps=mlps,
use_xyz=True,
pool_method='max_pool',
if SA_cfg[src_name].get('INPUT_CHANNELS', None) is None:
input_channels = SA_cfg[src_name].MLPS[0][0] \
if isinstance(SA_cfg[src_name].MLPS[0], list) else SA_cfg[src_name].MLPS[0]
else:
input_channels = SA_cfg[src_name]['INPUT_CHANNELS']
cur_layer, cur_num_c_out = pointnet2_stack_modules.build_local_aggregation_module(
input_channels=input_channels, config=SA_cfg[src_name]
)
self.SA_layers.append(cur_layer)
self.SA_layer_names.append(src_name)
c_in += sum([x[-1] for x in mlps])
c_in += cur_num_c_out
if 'bev' in self.model_cfg.FEATURES_SOURCE:
c_bev = num_bev_features
c_in += c_bev
if 'raw_points' in self.model_cfg.FEATURES_SOURCE:
mlps = SA_cfg['raw_points'].MLPS
for k in range(len(mlps)):
mlps[k] = [num_rawpoint_features - 3] + mlps[k]
self.SA_rawpoints = pointnet2_stack_modules.StackSAModuleMSG(
radii=SA_cfg['raw_points'].POOL_RADIUS,
nsamples=SA_cfg['raw_points'].NSAMPLE,
mlps=mlps,
use_xyz=True,
pool_method='max_pool'
self.SA_rawpoints, cur_num_c_out = pointnet2_stack_modules.build_local_aggregation_module(
input_channels=num_rawpoint_features - 3, config=SA_cfg['raw_points']
)
c_in += sum([x[-1] for x in mlps])
c_in += cur_num_c_out
self.vsa_point_feature_fusion = nn.Sequential(
nn.Linear(c_in, self.model_cfg.NUM_OUTPUT_FEATURES, bias=False),
......
......@@ -8,6 +8,7 @@ from .second_net_iou import SECONDNetIoU
from .caddn import CaDDN
from .voxel_rcnn import VoxelRCNN
from .centerpoint import CenterPoint
from .pv_rcnn_plusplus import PVRCNNPlusPlus
__all__ = {
'Detector3DTemplate': Detector3DTemplate,
......@@ -19,7 +20,8 @@ __all__ = {
'SECONDNetIoU': SECONDNetIoU,
'CaDDN': CaDDN,
'VoxelRCNN': VoxelRCNN,
'CenterPoint': CenterPoint
'CenterPoint': CenterPoint,
'PVRCNNPlusPlus': PVRCNNPlusPlus
}
......
......@@ -10,21 +10,12 @@ class PVRCNNHead(RoIHeadTemplate):
super().__init__(num_class=num_class, model_cfg=model_cfg)
self.model_cfg = model_cfg
mlps = self.model_cfg.ROI_GRID_POOL.MLPS
for k in range(len(mlps)):
mlps[k] = [input_channels] + mlps[k]
self.roi_grid_pool_layer = pointnet2_stack_modules.StackSAModuleMSG(
radii=self.model_cfg.ROI_GRID_POOL.POOL_RADIUS,
nsamples=self.model_cfg.ROI_GRID_POOL.NSAMPLE,
mlps=mlps,
use_xyz=True,
pool_method=self.model_cfg.ROI_GRID_POOL.POOL_METHOD,
self.roi_grid_pool_layer, num_c_out = pointnet2_stack_modules.build_local_aggregation_module(
input_channels=input_channels, config=self.model_cfg.ROI_GRID_POOL
)
GRID_SIZE = self.model_cfg.ROI_GRID_POOL.GRID_SIZE
c_out = sum([x[-1] for x in mlps])
pre_channel = GRID_SIZE * GRID_SIZE * GRID_SIZE * c_out
pre_channel = GRID_SIZE * GRID_SIZE * GRID_SIZE * num_c_out
shared_fc_list = []
for k in range(0, self.model_cfg.SHARED_FC.__len__()):
......@@ -150,9 +141,11 @@ class PVRCNNHead(RoIHeadTemplate):
batch_dict, nms_config=self.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST']
)
if self.training:
targets_dict = self.assign_targets(batch_dict)
batch_dict['rois'] = targets_dict['rois']
batch_dict['roi_labels'] = targets_dict['roi_labels']
targets_dict = batch_dict.get('roi_targets_dict', None)
if targets_dict is None:
targets_dict = self.assign_targets(batch_dict)
batch_dict['rois'] = targets_dict['rois']
batch_dict['roi_labels'] = targets_dict['roi_labels']
# RoI aware pooling
pooled_features = self.roi_grid_pool(batch_dict) # (BxN, 6x6x6, C)
......
......@@ -7,6 +7,26 @@ import torch.nn.functional as F
from . import pointnet2_utils
def build_local_aggregation_module(input_channels, config):
local_aggregation_name = config.get('NAME', 'StackSAModuleMSG')
if local_aggregation_name == 'StackSAModuleMSG':
mlps = config.MLPS
for k in range(len(mlps)):
mlps[k] = [input_channels] + mlps[k]
cur_layer = StackSAModuleMSG(
radii=config.POOL_RADIUS, nsamples=config.NSAMPLE, mlps=mlps, use_xyz=True, pool_method='max_pool',
)
num_c_out = sum([x[-1] for x in mlps])
elif local_aggregation_name == 'VectorPoolAggregationModuleMSG':
cur_layer = VectorPoolAggregationModuleMSG(input_channels=input_channels, config=config)
num_c_out = config.MSG_POST_MLPS[-1]
else:
raise NotImplementedError
return cur_layer, num_c_out
class StackSAModuleMSG(nn.Module):
def __init__(self, *, radii: List[float], nsamples: List[int], mlps: List[List[int]],
......@@ -135,3 +155,316 @@ class StackPointnetFPModule(nn.Module):
new_features = new_features.squeeze(dim=0).squeeze(dim=-1).permute(1, 0) # (N1 + N2 ..., C)
return new_features
class VectorPoolLocalInterpolateModule(nn.Module):
def __init__(self, mlp, num_voxels, max_neighbour_distance, nsample, neighbor_type, use_xyz=True,
neighbour_distance_multiplier=1.0, xyz_encoding_type='concat'):
"""
Args:
mlp:
num_voxels:
max_neighbour_distance:
neighbor_type: 1: ball, others: cube
nsample: find all (-1), find limited number(>0)
use_xyz:
neighbour_distance_multiplier:
xyz_encoding_type:
"""
super().__init__()
self.num_voxels = num_voxels # [num_grid_x, num_grid_y, num_grid_z]: number of grids in each local area centered at new_xyz
self.num_total_grids = self.num_voxels[0] * self.num_voxels[1] * self.num_voxels[2]
self.max_neighbour_distance = max_neighbour_distance
self.neighbor_distance_multiplier = neighbour_distance_multiplier
self.nsample = nsample
self.neighbor_type = neighbor_type
self.use_xyz = use_xyz
self.xyz_encoding_type = xyz_encoding_type
if mlp is not None:
if self.use_xyz:
mlp[0] += 9 if self.xyz_encoding_type == 'concat' else 0
shared_mlps = []
for k in range(len(mlp) - 1):
shared_mlps.extend([
nn.Conv2d(mlp[k], mlp[k + 1], kernel_size=1, bias=False),
nn.BatchNorm2d(mlp[k + 1]),
nn.ReLU()
])
self.mlp = nn.Sequential(*shared_mlps)
else:
self.mlp = None
self.num_avg_length_of_neighbor_idxs = 1000
def forward(self, support_xyz, support_features, xyz_batch_cnt, new_xyz, new_xyz_grid_centers, new_xyz_batch_cnt):
"""
Args:
support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
support_features: (N1 + N2 ..., C) point-wise features
xyz_batch_cnt: (batch_size), [N1, N2, ...]
new_xyz: (M1 + M2 ..., 3) centers of the ball query
new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grids centers of each grid
new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
Returns:
new_features: (N1 + N2 ..., C_out)
"""
with torch.no_grad():
dist, idx, num_avg_length_of_neighbor_idxs = pointnet2_utils.three_nn_for_vector_pool_by_two_step(
support_xyz, xyz_batch_cnt, new_xyz, new_xyz_grid_centers, new_xyz_batch_cnt,
self.max_neighbour_distance, self.nsample, self.neighbor_type,
self.num_avg_length_of_neighbor_idxs, self.num_total_grids, self.neighbor_distance_multiplier
)
self.num_avg_length_of_neighbor_idxs = max(self.num_avg_length_of_neighbor_idxs, num_avg_length_of_neighbor_idxs.item())
dist_recip = 1.0 / (dist + 1e-8)
norm = torch.sum(dist_recip, dim=-1, keepdim=True)
weight = dist_recip / torch.clamp_min(norm, min=1e-8)
empty_mask = (idx.view(-1, 3)[:, 0] == -1)
idx.view(-1, 3)[empty_mask] = 0
interpolated_feats = pointnet2_utils.three_interpolate(support_features, idx.view(-1, 3), weight.view(-1, 3))
interpolated_feats = interpolated_feats.view(idx.shape[0], idx.shape[1], -1) # (M1 + M2 ..., num_total_grids, C)
if self.use_xyz:
near_known_xyz = support_xyz[idx.view(-1, 3).long()].view(-1, 3, 3) # ( (M1 + M2 ...)*num_total_grids, 3)
local_xyz = (new_xyz_grid_centers.view(-1, 1, 3) - near_known_xyz).view(-1, idx.shape[1], 9)
if self.xyz_encoding_type == 'concat':
interpolated_feats = torch.cat((interpolated_feats, local_xyz), dim=-1) # ( M1 + M2 ..., num_total_grids, 9+C)
else:
raise NotImplementedError
new_features = interpolated_feats.view(-1, interpolated_feats.shape[-1]) # ((M1 + M2 ...) * num_total_grids, C)
new_features[empty_mask, :] = 0
if self.mlp is not None:
new_features = new_features.permute(1, 0)[None, :, :, None] # (1, C, N1 + N2 ..., 1)
new_features = self.mlp(new_features)
new_features = new_features.squeeze(dim=0).squeeze(dim=-1).permute(1, 0) # (N1 + N2 ..., C)
return new_features
class VectorPoolAggregationModule(nn.Module):
def __init__(
self, input_channels, num_local_voxel=(3, 3, 3), local_aggregation_type='local_interpolation',
num_reduced_channels=30, num_channels_of_local_aggregation=32, post_mlps=(128,),
max_neighbor_distance=None, neighbor_nsample=-1, neighbor_type=0, neighbor_distance_multiplier=2.0):
super().__init__()
self.num_local_voxel = num_local_voxel
self.total_voxels = self.num_local_voxel[0] * self.num_local_voxel[1] * self.num_local_voxel[2]
self.local_aggregation_type = local_aggregation_type
assert self.local_aggregation_type in ['local_interpolation', 'voxel_avg_pool', 'voxel_random_choice']
self.input_channels = input_channels
self.num_reduced_channels = input_channels if num_reduced_channels is None else num_reduced_channels
self.num_channels_of_local_aggregation = num_channels_of_local_aggregation
self.max_neighbour_distance = max_neighbor_distance
self.neighbor_nsample = neighbor_nsample
self.neighbor_type = neighbor_type # 1: ball, others: cube
if self.local_aggregation_type == 'local_interpolation':
self.local_interpolate_module = VectorPoolLocalInterpolateModule(
mlp=None, num_voxels=self.num_local_voxel,
max_neighbour_distance=self.max_neighbour_distance,
nsample=self.neighbor_nsample,
neighbor_type=self.neighbor_type,
neighbour_distance_multiplier=neighbor_distance_multiplier,
)
num_c_in = (self.num_reduced_channels + 9) * self.total_voxels
else:
self.local_interpolate_module = None
num_c_in = (self.num_reduced_channels + 3) * self.total_voxels
num_c_out = self.total_voxels * self.num_channels_of_local_aggregation
self.separate_local_aggregation_layer = nn.Sequential(
nn.Conv1d(num_c_in, num_c_out, kernel_size=1, groups=self.total_voxels, bias=False),
nn.BatchNorm1d(num_c_out),
nn.ReLU()
)
post_mlp_list = []
c_in = num_c_out
for cur_num_c in post_mlps:
post_mlp_list.extend([
nn.Conv1d(c_in, cur_num_c, kernel_size=1, bias=False),
nn.BatchNorm1d(cur_num_c),
nn.ReLU()
])
c_in = cur_num_c
self.post_mlps = nn.Sequential(*post_mlp_list)
self.num_mean_points_per_grid = 20
self.init_weights()
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0)
def extra_repr(self) -> str:
ret = f'radius={self.max_neighbour_distance}, local_voxels=({self.num_local_voxel}, ' \
f'local_aggregation_type={self.local_aggregation_type}, ' \
f'num_c_reduction={self.input_channels}->{self.num_reduced_channels}, ' \
f'num_c_local_aggregation={self.num_channels_of_local_aggregation}'
return ret
def vector_pool_with_voxel_query(self, xyz, xyz_batch_cnt, features, new_xyz, new_xyz_batch_cnt):
use_xyz = 1
pooling_type = 0 if self.local_aggregation_type == 'voxel_avg_pool' else 1
new_features, new_local_xyz, num_mean_points_per_grid, point_cnt_of_grid = pointnet2_utils.vector_pool_with_voxel_query_op(
xyz, xyz_batch_cnt, features, new_xyz, new_xyz_batch_cnt,
self.num_local_voxel[0], self.num_local_voxel[1], self.num_local_voxel[2],
self.max_neighbour_distance, self.num_reduced_channels, use_xyz,
self.num_mean_points_per_grid, self.neighbor_nsample, self.neighbor_type,
pooling_type
)
self.num_mean_points_per_grid = max(self.num_mean_points_per_grid, num_mean_points_per_grid.item())
num_new_pts = new_features.shape[0]
new_local_xyz = new_local_xyz.view(num_new_pts, -1, 3) # (N, num_voxel, 3)
new_features = new_features.view(num_new_pts, -1, self.num_reduced_channels) # (N, num_voxel, C)
new_features = torch.cat((new_local_xyz, new_features), dim=-1).view(num_new_pts, -1)
return new_features, point_cnt_of_grid
@staticmethod
def get_dense_voxels_by_center(point_centers, max_neighbour_distance, num_voxels):
"""
Args:
point_centers: (N, 3)
max_neighbour_distance: float
num_voxels: [num_x, num_y, num_z]
Returns:
voxel_centers: (N, total_voxels, 3)
"""
R = max_neighbour_distance
device = point_centers.device
x_grids = torch.arange(-R + R / num_voxels[0], R - R / num_voxels[0] + 1e-5, 2 * R / num_voxels[0], device=device)
y_grids = torch.arange(-R + R / num_voxels[1], R - R / num_voxels[1] + 1e-5, 2 * R / num_voxels[1], device=device)
z_grids = torch.arange(-R + R / num_voxels[2], R - R / num_voxels[2] + 1e-5, 2 * R / num_voxels[2], device=device)
x_offset, y_offset, z_offset = torch.meshgrid(x_grids, y_grids, z_grids) # shape: [num_x, num_y, num_z]
xyz_offset = torch.cat((
x_offset.contiguous().view(-1, 1),
y_offset.contiguous().view(-1, 1),
z_offset.contiguous().view(-1, 1)), dim=-1
)
voxel_centers = point_centers[:, None, :] + xyz_offset[None, :, :]
return voxel_centers
def vector_pool_with_local_interpolate(self, xyz, xyz_batch_cnt, features, new_xyz, new_xyz_batch_cnt):
"""
Args:
xyz: (N, 3)
xyz_batch_cnt: (batch_size)
features: (N, C)
new_xyz: (M, 3)
new_xyz_batch_cnt: (batch_size)
Returns:
new_features: (M, total_voxels * C)
"""
voxel_centers = self.get_dense_voxels_by_center(
point_centers=new_xyz, max_neighbour_distance=self.max_neighbour_distance, num_voxels=self.num_local_voxel
) # (M1 + M2 + ..., total_voxels, 3)
voxel_features = self.local_interpolate_module.forward(
support_xyz=xyz, support_features=features, xyz_batch_cnt=xyz_batch_cnt,
new_xyz=new_xyz, new_xyz_grid_centers=voxel_centers, new_xyz_batch_cnt=new_xyz_batch_cnt
) # ((M1 + M2 ...) * total_voxels, C)
voxel_features = voxel_features.contiguous().view(-1, self.total_voxels * voxel_features.shape[-1])
return voxel_features
def forward(self, xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, features, **kwargs):
"""
:param xyz: (N1 + N2 ..., 3) tensor of the xyz coordinates of the features
:param xyz_batch_cnt: (batch_size), [N1, N2, ...]
:param new_xyz: (M1 + M2 ..., 3)
:param new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
:param features: (N1 + N2 ..., C) tensor of the descriptors of the the features
:return:
new_xyz: (M1 + M2 ..., 3) tensor of the new features' xyz
new_features: (M1 + M2 ..., \sum_k(mlps[k][-1])) tensor of the new_features descriptors
"""
N, C = features.shape
assert C % self.num_reduced_channels == 0, \
f'the input channels ({C}) should be an integral multiple of num_reduced_channels({self.num_reduced_channels})'
features = features.view(N, -1, self.num_reduced_channels).sum(dim=1)
if self.local_aggregation_type in ['voxel_avg_pool', 'voxel_random_choice']:
vector_features, point_cnt_of_grid = self.vector_pool_with_voxel_query(
xyz=xyz, xyz_batch_cnt=xyz_batch_cnt, features=features,
new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt
)
elif self.local_aggregation_type == 'local_interpolation':
vector_features = self.vector_pool_with_local_interpolate(
xyz=xyz, xyz_batch_cnt=xyz_batch_cnt, features=features,
new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt
) # (M1 + M2 + ..., total_voxels * C)
else:
raise NotImplementedError
vector_features = vector_features.permute(1, 0)[None, :, :] # (1, num_voxels * C, M1 + M2 ...)
new_features = self.separate_local_aggregation_layer(vector_features)
new_features = self.post_mlps(new_features)
new_features = new_features.squeeze(dim=0).permute(1, 0)
return new_xyz, new_features
class VectorPoolAggregationModuleMSG(nn.Module):
def __init__(self, input_channels, config):
super().__init__()
self.model_cfg = config
self.num_groups = self.model_cfg.NUM_GROUPS
self.layers = []
c_in = 0
for k in range(self.num_groups):
cur_config = self.model_cfg[f'GROUP_CFG_{k}']
cur_vector_pool_module = VectorPoolAggregationModule(
input_channels=input_channels, num_local_voxel=cur_config.NUM_LOCAL_VOXEL,
post_mlps=cur_config.POST_MLPS,
max_neighbor_distance=cur_config.MAX_NEIGHBOR_DISTANCE,
neighbor_nsample=cur_config.NEIGHBOR_NSAMPLE,
local_aggregation_type=self.model_cfg.LOCAL_AGGREGATION_TYPE,
num_reduced_channels=self.model_cfg.get('NUM_REDUCED_CHANNELS', None),
num_channels_of_local_aggregation=self.model_cfg.NUM_CHANNELS_OF_LOCAL_AGGREGATION,
neighbor_distance_multiplier=2.0
)
self.__setattr__(f'layer_{k}', cur_vector_pool_module)
c_in += cur_config.POST_MLPS[-1]
c_in += 3 # use_xyz
shared_mlps = []
for cur_num_c in self.model_cfg.MSG_POST_MLPS:
shared_mlps.extend([
nn.Conv1d(c_in, cur_num_c, kernel_size=1, bias=False),
nn.BatchNorm1d(cur_num_c),
nn.ReLU()
])
c_in = cur_num_c
self.msg_post_mlps = nn.Sequential(*shared_mlps)
def forward(self, **kwargs):
features_list = []
for k in range(self.num_groups):
cur_xyz, cur_features = self.__getattr__(f'layer_{k}')(**kwargs)
features_list.append(cur_features)
features = torch.cat(features_list, dim=-1)
features = torch.cat((cur_xyz, features), dim=-1)
features = features.permute(1, 0)[None, :, :] # (1, C, N)
new_features = self.msg_post_mlps(features)
new_features = new_features.squeeze(dim=0).permute(1, 0) # (N, C)
return cur_xyz, new_features
......@@ -299,5 +299,154 @@ class ThreeInterpolate(Function):
three_interpolate = ThreeInterpolate.apply
class ThreeNNForVectorPoolByTwoStep(Function):
@staticmethod
def forward(ctx, support_xyz, xyz_batch_cnt, new_xyz, new_xyz_grid_centers, new_xyz_batch_cnt,
max_neighbour_distance, nsample, neighbor_type, avg_length_of_neighbor_idxs, num_total_grids,
neighbor_distance_multiplier):
"""
Args:
ctx:
// support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
// xyz_batch_cnt: (batch_size), [N1, N2, ...]
// new_xyz: (M1 + M2 ..., 3) centers of the ball query
// new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grids centers of each grid
// new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
// nsample: find all (-1), find limited number(>0)
// neighbor_type: 1: ball, others: cube
// neighbor_distance_multiplier: query_distance = neighbor_distance_multiplier * max_neighbour_distance
Returns:
// new_xyz_grid_idxs: (M1 + M2 ..., num_total_grids, 3) three-nn
// new_xyz_grid_dist2: (M1 + M2 ..., num_total_grids, 3) square of dist of three-nn
"""
num_new_xyz = new_xyz.shape[0]
new_xyz_grid_dist2 = new_xyz_grid_centers.new_zeros(new_xyz_grid_centers.shape)
new_xyz_grid_idxs = new_xyz_grid_centers.new_zeros(new_xyz_grid_centers.shape).int().fill_(-1)
while True:
num_max_sum_points = avg_length_of_neighbor_idxs * num_new_xyz
stack_neighbor_idxs = new_xyz_grid_idxs.new_zeros(num_max_sum_points)
start_len = new_xyz_grid_idxs.new_zeros(num_new_xyz, 2).int()
cumsum = new_xyz_grid_idxs.new_zeros(1)
pointnet2.query_stacked_local_neighbor_idxs_wrapper_stack(
support_xyz.contiguous(), xyz_batch_cnt.contiguous(),
new_xyz.contiguous(), new_xyz_batch_cnt.contiguous(),
stack_neighbor_idxs.contiguous(), start_len.contiguous(), cumsum,
avg_length_of_neighbor_idxs, max_neighbour_distance * neighbor_distance_multiplier,
nsample, neighbor_type
)
avg_length_of_neighbor_idxs = cumsum[0] // num_new_xyz + int(cumsum[0] % num_new_xyz > 0)
if cumsum[0] <= num_max_sum_points:
break
stack_neighbor_idxs = stack_neighbor_idxs[:cumsum[0]]
pointnet2.query_three_nn_by_stacked_local_idxs_wrapper_stack(
support_xyz, new_xyz, new_xyz_grid_centers, new_xyz_grid_idxs, new_xyz_grid_dist2,
stack_neighbor_idxs, start_len, num_new_xyz, num_total_grids
)
return torch.sqrt(new_xyz_grid_dist2), new_xyz_grid_idxs, avg_length_of_neighbor_idxs
three_nn_for_vector_pool_by_two_step = ThreeNNForVectorPoolByTwoStep.apply
class VectorPoolWithVoxelQuery(Function):
@staticmethod
def forward(ctx, support_xyz: torch.Tensor, xyz_batch_cnt: torch.Tensor, support_features: torch.Tensor,
new_xyz: torch.Tensor, new_xyz_batch_cnt: torch.Tensor, num_grid_x, num_grid_y, num_grid_z,
max_neighbour_distance, num_c_out_each_grid, use_xyz,
num_mean_points_per_grid=100, nsample=-1, neighbor_type=0, pooling_type=0):
"""
Args:
ctx:
support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
xyz_batch_cnt: (batch_size), [N1, N2, ...]
support_features: (N1 + N2 ..., C)
new_xyz: (M1 + M2 ..., 3) centers of new positions
new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
num_grid_x: number of grids in each local area centered at new_xyz
num_grid_y:
num_grid_z:
max_neighbour_distance:
num_c_out_each_grid:
use_xyz:
neighbor_type: 1: ball, others: cube:
pooling_type: 0: avg_pool, 1: random choice
Returns:
new_features: (M1 + M2 ..., num_c_out)
"""
assert support_xyz.is_contiguous()
assert support_features.is_contiguous()
assert xyz_batch_cnt.is_contiguous()
assert new_xyz.is_contiguous()
assert new_xyz_batch_cnt.is_contiguous()
num_total_grids = num_grid_x * num_grid_y * num_grid_z
num_c_out = num_c_out_each_grid * num_total_grids
N, num_c_in = support_features.shape
M = new_xyz.shape[0]
assert num_c_in % num_c_out_each_grid == 0, \
f'the input channels ({num_c_in}) should be an integral multiple of num_c_out_each_grid({num_c_out_each_grid})'
while True:
new_features = support_features.new_zeros((M, num_c_out))
new_local_xyz = support_features.new_zeros((M, 3 * num_total_grids))
point_cnt_of_grid = xyz_batch_cnt.new_zeros((M, num_total_grids))
num_max_sum_points = num_mean_points_per_grid * M
grouped_idxs = xyz_batch_cnt.new_zeros((num_max_sum_points, 3))
num_cum_sum = pointnet2.vector_pool_wrapper(
support_xyz, xyz_batch_cnt, support_features, new_xyz, new_xyz_batch_cnt,
new_features, new_local_xyz, point_cnt_of_grid, grouped_idxs,
num_grid_x, num_grid_y, num_grid_z, max_neighbour_distance, use_xyz,
num_max_sum_points, nsample, neighbor_type, pooling_type
)
num_mean_points_per_grid = num_cum_sum // M + int(num_cum_sum % M > 0)
if num_cum_sum <= num_max_sum_points:
break
grouped_idxs = grouped_idxs[:num_cum_sum]
normalizer = torch.clamp_min(point_cnt_of_grid[:, :, None].float(), min=1e-6)
new_features = (new_features.view(-1, num_total_grids, num_c_out_each_grid) / normalizer).view(-1, num_c_out)
if use_xyz:
new_local_xyz = (new_local_xyz.view(-1, num_total_grids, 3) / normalizer).view(-1, num_total_grids * 3)
num_mean_points_per_grid = torch.Tensor([num_mean_points_per_grid]).int()
nsample = torch.Tensor([nsample]).int()
ctx.vector_pool_for_backward = (point_cnt_of_grid, grouped_idxs, N, num_c_in)
ctx.mark_non_differentiable(new_local_xyz, num_mean_points_per_grid, nsample, point_cnt_of_grid)
return new_features, new_local_xyz, num_mean_points_per_grid, point_cnt_of_grid
@staticmethod
def backward(ctx, grad_new_features: torch.Tensor, grad_local_xyz: torch.Tensor, grad_num_cum_sum, grad_point_cnt_of_grid):
"""
Args:
ctx:
grad_new_features: (M1 + M2 ..., num_c_out), num_c_out = num_c_out_each_grid * num_total_grids
Returns:
grad_support_features: (N1 + N2 ..., C_in)
"""
point_cnt_of_grid, grouped_idxs, N, num_c_in = ctx.vector_pool_for_backward
grad_support_features = grad_new_features.new_zeros((N, num_c_in))
pointnet2.vector_pool_grad_wrapper(
grad_new_features.contiguous(), point_cnt_of_grid, grouped_idxs,
grad_support_features
)
return None, None, grad_support_features, None, None, None, None, None, None, None, None, None, None, None, None
vector_pool_with_voxel_query_op = VectorPoolWithVoxelQuery.apply
if __name__ == '__main__':
pass
......@@ -6,6 +6,7 @@
#include "sampling_gpu.h"
#include "interpolate_gpu.h"
#include "voxel_query_gpu.h"
#include "vector_pool_gpu.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
......@@ -21,4 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("three_nn_wrapper", &three_nn_wrapper_stack, "three_nn_wrapper_stack");
m.def("three_interpolate_wrapper", &three_interpolate_wrapper_stack, "three_interpolate_wrapper_stack");
m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_stack, "three_interpolate_grad_wrapper_stack");
m.def("query_stacked_local_neighbor_idxs_wrapper_stack", &query_stacked_local_neighbor_idxs_wrapper_stack, "query_stacked_local_neighbor_idxs_wrapper_stack");
m.def("query_three_nn_by_stacked_local_idxs_wrapper_stack", &query_three_nn_by_stacked_local_idxs_wrapper_stack, "query_three_nn_by_stacked_local_idxs_wrapper_stack");
m.def("vector_pool_wrapper", &vector_pool_wrapper_stack, "vector_pool_grad_wrapper_stack");
m.def("vector_pool_grad_wrapper", &vector_pool_grad_wrapper_stack, "vector_pool_grad_wrapper_stack");
}
/*
Vector-pool aggregation based local feature aggregation for point cloud.
PV-RCNN++: Point-Voxel Feature Set Abstraction With Local Vector Representation for 3D Object Detection
https://arxiv.org/abs/2102.00463
Written by Shaoshuai Shi
All Rights Reserved 2020.
*/
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "vector_pool_gpu.h"
extern THCState *state;
#define CHECK_CUDA(x) do { \
if (!x.type().is_cuda()) { \
fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
exit(-1); \
} \
} while (0)
#define CHECK_CONTIGUOUS(x) do { \
if (!x.is_contiguous()) { \
fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \
exit(-1); \
} \
} while (0)
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
int query_stacked_local_neighbor_idxs_wrapper_stack(at::Tensor support_xyz_tensor, at::Tensor xyz_batch_cnt_tensor,
at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor,
at::Tensor stack_neighbor_idxs_tensor, at::Tensor start_len_tensor, at::Tensor cumsum_tensor,
int avg_length_of_neighbor_idxs, float max_neighbour_distance, int nsample, int neighbor_type){
// support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
// xyz_batch_cnt: (batch_size), [N1, N2, ...]
// new_xyz: (M1 + M2 ..., 3) centers of the ball query
// new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grids centers of each grid
// new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
// new_xyz_grid_idxs: (M1 + M2 ..., num_total_grids, 3) three-nn
// new_xyz_grid_dist2: (M1 + M2 ..., num_total_grids, 3) square of dist of three-nn
// num_grid_x, num_grid_y, num_grid_z: number of grids in each local area centered at new_xyz
// nsample: find all (-1), find limited number(>0)
// neighbor_type: 1: ball, others: cube
CHECK_INPUT(support_xyz_tensor);
CHECK_INPUT(xyz_batch_cnt_tensor);
CHECK_INPUT(new_xyz_tensor);
CHECK_INPUT(new_xyz_batch_cnt_tensor);
CHECK_INPUT(stack_neighbor_idxs_tensor);
CHECK_INPUT(start_len_tensor);
CHECK_INPUT(cumsum_tensor);
const float *support_xyz = support_xyz_tensor.data<float>();
const int *xyz_batch_cnt = xyz_batch_cnt_tensor.data<int>();
const float *new_xyz = new_xyz_tensor.data<float>();
const int *new_xyz_batch_cnt = new_xyz_batch_cnt_tensor.data<int>();
int *stack_neighbor_idxs = stack_neighbor_idxs_tensor.data<int>();
int *start_len = start_len_tensor.data<int>();
int *cumsum = cumsum_tensor.data<int>();
int batch_size = xyz_batch_cnt_tensor.size(0);
int M = new_xyz_tensor.size(0);
query_stacked_local_neighbor_idxs_kernel_launcher_stack(
support_xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt,
stack_neighbor_idxs, start_len, cumsum, avg_length_of_neighbor_idxs,
max_neighbour_distance, batch_size, M, nsample, neighbor_type
);
return 0;
}
int query_three_nn_by_stacked_local_idxs_wrapper_stack(at::Tensor support_xyz_tensor,
at::Tensor new_xyz_tensor, at::Tensor new_xyz_grid_centers_tensor,
at::Tensor new_xyz_grid_idxs_tensor, at::Tensor new_xyz_grid_dist2_tensor,
at::Tensor stack_neighbor_idxs_tensor, at::Tensor start_len_tensor,
int M, int num_total_grids){
// support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
// new_xyz: (M1 + M2 ..., 3) centers of the ball query
// new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grids centers of each grid
// new_xyz_grid_idxs: (M1 + M2 ..., num_total_grids, 3) three-nn
// new_xyz_grid_dist2: (M1 + M2 ..., num_total_grids, 3) square of dist of three-nn
// stack_neighbor_idxs: (max_length_of_neighbor_idxs)
// start_len: (M1 + M2, 2) [start_offset, neighbor_length]
CHECK_INPUT(support_xyz_tensor);
CHECK_INPUT(new_xyz_tensor);
CHECK_INPUT(new_xyz_grid_centers_tensor);
CHECK_INPUT(new_xyz_grid_idxs_tensor);
CHECK_INPUT(new_xyz_grid_dist2_tensor);
CHECK_INPUT(stack_neighbor_idxs_tensor);
CHECK_INPUT(start_len_tensor);
const float *support_xyz = support_xyz_tensor.data<float>();
const float *new_xyz = new_xyz_tensor.data<float>();
const float *new_xyz_grid_centers = new_xyz_grid_centers_tensor.data<float>();
int *new_xyz_grid_idxs = new_xyz_grid_idxs_tensor.data<int>();
float *new_xyz_grid_dist2 = new_xyz_grid_dist2_tensor.data<float>();
int *stack_neighbor_idxs = stack_neighbor_idxs_tensor.data<int>();
int *start_len = start_len_tensor.data<int>();
query_three_nn_by_stacked_local_idxs_kernel_launcher_stack(
support_xyz, new_xyz, new_xyz_grid_centers,
new_xyz_grid_idxs, new_xyz_grid_dist2, stack_neighbor_idxs, start_len,
M, num_total_grids
);
return 0;
}
int vector_pool_wrapper_stack(at::Tensor support_xyz_tensor, at::Tensor xyz_batch_cnt_tensor,
at::Tensor support_features_tensor, at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor,
at::Tensor new_features_tensor, at::Tensor new_local_xyz_tensor,
at::Tensor point_cnt_of_grid_tensor, at::Tensor grouped_idxs_tensor,
int num_grid_x, int num_grid_y, int num_grid_z, float max_neighbour_distance, int use_xyz,
int num_max_sum_points, int nsample, int neighbor_type, int pooling_type){
// support_xyz_tensor: (N1 + N2 ..., 3) xyz coordinates of the features
// support_features_tensor: (N1 + N2 ..., C)
// xyz_batch_cnt: (batch_size), [N1, N2, ...]
// new_xyz_tensor: (M1 + M2 ..., 3) centers of new positions
// new_features_tensor: (M1 + M2 ..., C)
// new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
// point_cnt_of_grid: (M1 + M2 ..., num_total_grids)
// grouped_idxs_tensor: (num_max_sum_points, 3)
// num_grid_x, num_grid_y, num_grid_z: number of grids in each local area centered at new_xyz
// use_xyz: whether to calculate new_local_xyz
// neighbor_type: 1: ball, others: cube
// pooling_type: 0: avg_pool, 1: random choice
CHECK_INPUT(support_xyz_tensor);
CHECK_INPUT(support_features_tensor);
CHECK_INPUT(xyz_batch_cnt_tensor);
CHECK_INPUT(new_xyz_tensor);
CHECK_INPUT(new_xyz_batch_cnt_tensor);
CHECK_INPUT(new_features_tensor);
CHECK_INPUT(new_local_xyz_tensor);
CHECK_INPUT(point_cnt_of_grid_tensor);
CHECK_INPUT(grouped_idxs_tensor);
const float *support_xyz = support_xyz_tensor.data<float>();
const float *support_features = support_features_tensor.data<float>();
const int *xyz_batch_cnt = xyz_batch_cnt_tensor.data<int>();
const float *new_xyz = new_xyz_tensor.data<float>();
const int *new_xyz_batch_cnt = new_xyz_batch_cnt_tensor.data<int>();
float *new_features = new_features_tensor.data<float>();
float *new_local_xyz = new_local_xyz_tensor.data<float>();
int *point_cnt_of_grid = point_cnt_of_grid_tensor.data<int>();
int *grouped_idxs = grouped_idxs_tensor.data<int>();
int N = support_xyz_tensor.size(0);
int batch_size = xyz_batch_cnt_tensor.size(0);
int M = new_xyz_tensor.size(0);
int num_c_out = new_features_tensor.size(1);
int num_c_in = support_features_tensor.size(1);
int num_total_grids = point_cnt_of_grid_tensor.size(1);
int cum_sum = vector_pool_kernel_launcher_stack(
support_xyz, support_features, xyz_batch_cnt,
new_xyz, new_features, new_local_xyz, new_xyz_batch_cnt,
point_cnt_of_grid, grouped_idxs,
num_grid_x, num_grid_y, num_grid_z, max_neighbour_distance,
batch_size, N, M, num_c_in, num_c_out, num_total_grids, use_xyz, num_max_sum_points, nsample, neighbor_type, pooling_type
);
return cum_sum;
}
int vector_pool_grad_wrapper_stack(at::Tensor grad_new_features_tensor,
at::Tensor point_cnt_of_grid_tensor, at::Tensor grouped_idxs_tensor,
at::Tensor grad_support_features_tensor) {
// grad_new_features_tensor: (M1 + M2 ..., C_out)
// point_cnt_of_grid_tensor: (M1 + M2 ..., num_total_grids)
// grouped_idxs_tensor: (num_max_sum_points, 3) [idx of support_xyz, idx of new_xyz, idx of grid_idx in new_xyz]
// grad_support_features_tensor: (N1 + N2 ..., C_in)
CHECK_INPUT(grad_new_features_tensor);
CHECK_INPUT(point_cnt_of_grid_tensor);
CHECK_INPUT(grouped_idxs_tensor);
CHECK_INPUT(grad_support_features_tensor);
int M = grad_new_features_tensor.size(0);
int num_c_out = grad_new_features_tensor.size(1);
int N = grad_support_features_tensor.size(0);
int num_c_in = grad_support_features_tensor.size(1);
int num_total_grids = point_cnt_of_grid_tensor.size(1);
int num_max_sum_points = grouped_idxs_tensor.size(0);
const float *grad_new_features = grad_new_features_tensor.data<float>();
const int *point_cnt_of_grid = point_cnt_of_grid_tensor.data<int>();
const int *grouped_idxs = grouped_idxs_tensor.data<int>();
float *grad_support_features = grad_support_features_tensor.data<float>();
vector_pool_grad_kernel_launcher_stack(
grad_new_features, point_cnt_of_grid, grouped_idxs, grad_support_features,
N, M, num_c_out, num_c_in, num_total_grids, num_max_sum_points
);
return 1;
}
/*
Vector-pool aggregation based local feature aggregation for point cloud.
PV-RCNN++: Point-Voxel Feature Set Abstraction With Local Vector Representation for 3D Object Detection
https://arxiv.org/abs/2102.00463
Written by Shaoshuai Shi
All Rights Reserved 2020.
*/
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "vector_pool_gpu.h"
#include "cuda_utils.h"
__global__ void query_three_nn_by_stacked_local_idxs_kernel(
const float *support_xyz, const float *new_xyz, const float *new_xyz_grid_centers,
int *new_xyz_grid_idxs, float *new_xyz_grid_dist2,
const int *stack_neighbor_idxs, const int *start_len,
int M, int num_total_grids){
// support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
// new_xyz: (M1 + M2 ..., 3) centers of the ball query
// new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grids centers of each grid
// new_xyz_grid_idxs: (M1 + M2 ..., num_total_grids, 3) three-nn
// new_xyz_grid_dist2: (M1 + M2 ..., num_total_grids, 3) square of dist of three-nn
// stack_neighbor_idxs: (max_length_of_neighbor_idxs)
// start_len: (M1 + M2, 2) [start_offset, neighbor_length]
int grid_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (pt_idx >= M || grid_idx >= num_total_grids) return;
new_xyz += pt_idx * 3;
new_xyz_grid_centers += pt_idx * num_total_grids * 3 + grid_idx * 3;
new_xyz_grid_idxs += pt_idx * num_total_grids * 3 + grid_idx * 3;
new_xyz_grid_dist2 += pt_idx * num_total_grids * 3 + grid_idx * 3;
start_len += pt_idx * 2;
stack_neighbor_idxs += start_len[0];
int neighbor_length = start_len[1];
float center_x = new_xyz_grid_centers[0];
float center_y = new_xyz_grid_centers[1];
float center_z = new_xyz_grid_centers[2];
double best1 = 1e40, best2 = 1e40, best3 = 1e40;
int besti1 = -1, besti2 = -1, besti3 = -1;
for (int k = 0; k < neighbor_length; k++){
int cur_neighbor_idx = stack_neighbor_idxs[k];
float x = support_xyz[cur_neighbor_idx * 3 + 0];
float y = support_xyz[cur_neighbor_idx * 3 + 1];
float z = support_xyz[cur_neighbor_idx * 3 + 2];
float d = (center_x - x) * (center_x - x) + (center_y - y) * (center_y - y) + (center_z - z) * (center_z - z);
if (d < best1) {
best3 = best2; besti3 = besti2;
best2 = best1; besti2 = besti1;
best1 = d; besti1 = cur_neighbor_idx;
}
else if (d < best2) {
best3 = best2; besti3 = besti2;
best2 = d; besti2 = cur_neighbor_idx;
}
else if (d < best3) {
best3 = d; besti3 = cur_neighbor_idx;
}
}
if (besti2 == -1){
besti2 = besti1; best2 = best1;
}
if (besti3 == -1){
besti3 = besti1; best3 = best1;
}
new_xyz_grid_dist2[0] = best1;
new_xyz_grid_dist2[1] = best2;
new_xyz_grid_dist2[2] = best3;
new_xyz_grid_idxs[0] = besti1;
new_xyz_grid_idxs[1] = besti2;
new_xyz_grid_idxs[2] = besti3;
}
int query_three_nn_by_stacked_local_idxs_kernel_launcher_stack(
const float *support_xyz, const float *new_xyz, const float *new_xyz_grid_centers,
int *new_xyz_grid_idxs, float *new_xyz_grid_dist2,
const int *stack_neighbor_idxs, const int *start_len,
int M, int num_total_grids){
// support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
// new_xyz: (M1 + M2 ..., 3) centers of the ball query
// new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grids centers of each grid
// new_xyz_grid_idxs: (M1 + M2 ..., num_total_grids, 3) three-nn
// new_xyz_grid_dist2: (M1 + M2 ..., num_total_grids, 3) square of dist of three-nn
// stack_neighbor_idxs: (max_length_of_neighbor_idxs)
// start_len: (M1 + M2, 2) [start_offset, neighbor_length]
cudaError_t err;
dim3 blocks(DIVUP(M, THREADS_PER_BLOCK), num_total_grids); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
query_three_nn_by_stacked_local_idxs_kernel<<<blocks, threads>>>(
support_xyz, new_xyz, new_xyz_grid_centers,
new_xyz_grid_idxs, new_xyz_grid_dist2, stack_neighbor_idxs, start_len,
M, num_total_grids
);
// cudaDeviceSynchronize(); // for using printf in kernel function
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
return 0;
}
__global__ void query_stacked_local_neighbor_idxs_kernel(
const float *support_xyz, const int *xyz_batch_cnt, const float *new_xyz, const int *new_xyz_batch_cnt,
int *stack_neighbor_idxs, int *start_len, int *cumsum, int avg_length_of_neighbor_idxs,
float max_neighbour_distance, int batch_size, int M, int nsample, int neighbor_type){
// support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
// xyz_batch_cnt: (batch_size), [N1, N2, ...]
// new_xyz: (M1 + M2 ..., 3) centers of the ball query
// new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
// stack_neighbor_idxs: (max_length_of_neighbor_idxs)
// start_len: (M1 + M2, 2) [start_offset, neighbor_length]
// cumsum: (1), max offset of current data in stack_neighbor_idxs
// max_neighbour_distance: float
// nsample: find all (-1), find limited number(>0)
// neighbor_type: 1: ball, others: cube
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (pt_idx >= M) return;
int bs_idx = 0, pt_cnt = new_xyz_batch_cnt[0];
for (int k = 1; k < batch_size; k++){
if (pt_idx < pt_cnt) break;
pt_cnt += new_xyz_batch_cnt[k];
bs_idx = k;
}
int xyz_batch_start_idx = 0;
for (int k = 0; k < bs_idx; k++) xyz_batch_start_idx += xyz_batch_cnt[k];
support_xyz += xyz_batch_start_idx * 3;
new_xyz += pt_idx * 3;
start_len += pt_idx * 2;
float new_x = new_xyz[0];
float new_y = new_xyz[1];
float new_z = new_xyz[2];
int n = xyz_batch_cnt[bs_idx];
float local_x, local_y, local_z;
float radius2 = max_neighbour_distance * max_neighbour_distance;
int temp_idxs[1000];
int sample_cnt = 0;
for (int k = 0; k < n; ++k) {
local_x = support_xyz[k * 3 + 0] - new_x;
local_y = support_xyz[k * 3 + 1] - new_y;
local_z = support_xyz[k * 3 + 2] - new_z;
if (neighbor_type == 1){
// ball
if (local_x * local_x + local_y * local_y + local_z * local_z > radius2){
continue;
}
}
else{
// voxel
if ((fabs(local_x) > max_neighbour_distance) |
(fabs(local_y) > max_neighbour_distance) |
(fabs(local_z) > max_neighbour_distance)){
continue;
}
}
if (sample_cnt < 1000){
temp_idxs[sample_cnt] = k;
}
else{
break;
}
sample_cnt++;
if (nsample > 0 && sample_cnt >= nsample) break;
}
start_len[0] = atomicAdd(cumsum, sample_cnt);
start_len[1] = sample_cnt;
int max_thresh = avg_length_of_neighbor_idxs * M;
if (start_len[0] >= max_thresh) return;
stack_neighbor_idxs += start_len[0];
if (start_len[0] + sample_cnt >= max_thresh) sample_cnt = max_thresh - start_len[0];
for (int k = 0; k < sample_cnt; k++){
stack_neighbor_idxs[k] = temp_idxs[k] + xyz_batch_start_idx;
}
}
int query_stacked_local_neighbor_idxs_kernel_launcher_stack(
const float *support_xyz, const int *xyz_batch_cnt, const float *new_xyz, const int *new_xyz_batch_cnt,
int *stack_neighbor_idxs, int *start_len, int *cumsum, int avg_length_of_neighbor_idxs,
float max_neighbour_distance, int batch_size, int M, int nsample, int neighbor_type){
// support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
// xyz_batch_cnt: (batch_size), [N1, N2, ...]
// new_xyz: (M1 + M2 ..., 3) centers of the ball query
// new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
// stack_neighbor_idxs: (max_length_of_neighbor_idxs)
// start_len: (M1 + M2, 2) [start_offset, neighbor_length]
// cumsum: (1), max offset of current data in stack_neighbor_idxs
// max_neighbour_distance: float
// nsample: find all (-1), find limited number(>0)
// neighbor_type: 1: ball, others: cube
cudaError_t err;
dim3 blocks(DIVUP(M, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
query_stacked_local_neighbor_idxs_kernel<<<blocks, threads>>>(
support_xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt,
stack_neighbor_idxs, start_len, cumsum, avg_length_of_neighbor_idxs,
max_neighbour_distance, batch_size, M, nsample, neighbor_type
);
// cudaDeviceSynchronize(); // for using printf in kernel function
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
return 0;
}
__global__ void vector_pool_kernel_stack(
const float *support_xyz, const float *support_features, const int *xyz_batch_cnt,
const float *new_xyz, float *new_features, float *new_local_xyz, const int *new_xyz_batch_cnt,
int num_grid_x, int num_grid_y, int num_grid_z, float max_neighbour_distance,
int batch_size, int M, int num_c_in, int num_c_out,
int num_c_each_grid, int num_total_grids, int *point_cnt_of_grid, int *grouped_idxs,
int use_xyz, float grid_size_x, float grid_size_y,
float grid_size_z, int *cum_sum, int num_max_sum_points, int nsample, int neighbor_type, int pooling_type){
// support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
// support_features: (N1 + N2 ..., C)
// xyz_batch_cnt: (batch_size), [N1, N2, ...]
// new_xyz: (M1 + M2 ..., 3) centers of the ball query
// new_features: (M1 + M2 ..., C), C = num_total_grids * num_c_each_grid
// new_local_xyz: (M1 + M2 ..., 3 * num_total_grids)
// new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
// num_grid_x, num_grid_y, num_grid_z: number of grids in each local area centered at new_xyz
// point_cnt_of_grid: (M1 + M2 ..., num_total_grids)
// grouped_idxs: (num_max_sum_points, 3)[idx of support_xyz, idx of new_xyz, idx of grid_idx in new_xyz]
// use_xyz: whether to calculate new_local_xyz
// neighbor_type: 1: ball, others: cube
// pooling_type: 0: avg_pool, 1: random choice
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (pt_idx >= M) return;
int bs_idx = 0, pt_cnt = new_xyz_batch_cnt[0];
for (int k = 1; k < batch_size; k++){
if (pt_idx < pt_cnt) break;
pt_cnt += new_xyz_batch_cnt[k];
bs_idx = k;
}
int xyz_batch_start_idx = 0;
for (int k = 0; k < bs_idx; k++) xyz_batch_start_idx += xyz_batch_cnt[k];
support_xyz += xyz_batch_start_idx * 3;
support_features += xyz_batch_start_idx * num_c_in;
new_xyz += pt_idx * 3;
new_features += pt_idx * num_c_out;
point_cnt_of_grid += pt_idx * num_total_grids;
new_local_xyz += pt_idx * 3 * num_total_grids;
float new_x = new_xyz[0];
float new_y = new_xyz[1];
float new_z = new_xyz[2];
int n = xyz_batch_cnt[bs_idx], grid_idx_x, grid_idx_y, grid_idx_z, grid_idx;
float local_x, local_y, local_z;
float radius2 = max_neighbour_distance * max_neighbour_distance;
int sample_cnt = 0;
for (int k = 0; k < n; ++k) {
local_x = support_xyz[k * 3 + 0] - new_x;
local_y = support_xyz[k * 3 + 1] - new_y;
local_z = support_xyz[k * 3 + 2] - new_z;
if (neighbor_type == 1){
// ball
if (local_x * local_x + local_y * local_y + local_z * local_z > radius2){
continue;
}
}
else{
// voxel
if ((fabs(local_x) > max_neighbour_distance) |
(fabs(local_y) > max_neighbour_distance) |
(fabs(local_z) > max_neighbour_distance)){
continue;
}
}
grid_idx_x = floorf((local_x + max_neighbour_distance) / grid_size_x);
grid_idx_y = floorf((local_y + max_neighbour_distance) / grid_size_y);
grid_idx_z = floorf((local_z + max_neighbour_distance) / grid_size_z);
grid_idx = grid_idx_x * num_grid_y * num_grid_z + grid_idx_y * num_grid_z + grid_idx_z;
grid_idx = min(max(grid_idx, 0), num_total_grids - 1);
if (pooling_type == 0){
// avg pooling
point_cnt_of_grid[grid_idx] ++;
for (int i = 0; i < num_c_in; i++){
new_features[grid_idx * num_c_each_grid + i % num_c_each_grid] += support_features[k * num_c_in + i];
}
if (use_xyz){
new_local_xyz[grid_idx * 3 + 0] += local_x;
new_local_xyz[grid_idx * 3 + 1] += local_y;
new_local_xyz[grid_idx * 3 + 2] += local_z;
}
int cnt = atomicAdd(cum_sum, 1);
if (cnt >= num_max_sum_points) continue; // continue to statistics the max number of points
grouped_idxs[cnt * 3 + 0] = xyz_batch_start_idx + k;
grouped_idxs[cnt * 3 + 1] = pt_idx;
grouped_idxs[cnt * 3 + 2] = grid_idx;
sample_cnt++;
if(nsample > 0 && sample_cnt >= nsample) break;
}
else if (pooling_type == 1){
// random choose one within sub-voxel
// printf("new_xyz=(%.2f, %.2f, %.2f, ), find neighbor k=%d: support_xyz=(%.2f, %.2f, %.2f), local_xyz=(%.2f, %.2f, %.2f), neighbor=%.2f, grid_idx=%d, point_cnt_of_grid_idx=%d\n",
// new_x, new_y, new_z, k, support_xyz[k * 3 + 0], support_xyz[k * 3 + 1], support_xyz[k * 3 + 2], local_x, local_y, local_z, max_neighbour_distance, grid_idx, point_cnt_of_grid[grid_idx]);
if (point_cnt_of_grid[grid_idx] == 0){
point_cnt_of_grid[grid_idx] ++;
for (int i = 0; i < num_c_in; i++){
new_features[grid_idx * num_c_each_grid + i % num_c_each_grid] = support_features[k * num_c_in + i];
}
if (use_xyz){
new_local_xyz[grid_idx * 3 + 0] = local_x;
new_local_xyz[grid_idx * 3 + 1] = local_y;
new_local_xyz[grid_idx * 3 + 2] = local_z;
}
int cnt = atomicAdd(cum_sum, 1);
if (cnt >= num_max_sum_points) continue; // continue to statistics the max number of points
grouped_idxs[cnt * 3 + 0] = xyz_batch_start_idx + k;
grouped_idxs[cnt * 3 + 1] = pt_idx;
grouped_idxs[cnt * 3 + 2] = grid_idx;
sample_cnt++;
if(nsample > 0 && sample_cnt >= nsample || sample_cnt >= num_total_grids) break;
}
}
}
}
int vector_pool_kernel_launcher_stack(
const float *support_xyz, const float *support_features, const int *xyz_batch_cnt,
const float *new_xyz, float *new_features, float *new_local_xyz, const int *new_xyz_batch_cnt,
int *point_cnt_of_grid, int *grouped_idxs,
int num_grid_x, int num_grid_y, int num_grid_z, float max_neighbour_distance,
int batch_size, int N, int M, int num_c_in, int num_c_out, int num_total_grids,
int use_xyz, int num_max_sum_points, int nsample, int neighbor_type, int pooling_type){
// support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
// support_features: (N1 + N2 ..., C)
// xyz_batch_cnt: (batch_size), [N1, N2, ...]
// new_xyz: (M1 + M2 ..., 3) centers of the ball query
// new_features: (M1 + M2 ..., C)
// new_local_xyz: (M1 + M2 ..., 3)
// new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
// num_grid_x, num_grid_y, num_grid_z: number of grids in each local area centered at new_xyz
// use_xyz: whether to calculate new_local_xyz
// grouped_idxs: (num_max_sum_points, 3)[idx of support_xyz, idx of new_xyz, idx of grid_idx in new_xyz]
// neighbor_type: 1: ball, others: cube
// pooling_type: 0: avg_pool, 1: random choice
cudaError_t err;
int num_c_each_grid = num_c_out / num_total_grids;
float grid_size_x = max_neighbour_distance * 2 / num_grid_x;
float grid_size_y = max_neighbour_distance * 2 / num_grid_y;
float grid_size_z = max_neighbour_distance * 2 / num_grid_z;
dim3 blocks(DIVUP(M, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
int cum_sum = 0;
int *p_cum_sum;
cudaMalloc((void**)&p_cum_sum, sizeof(int));
cudaMemcpy(p_cum_sum, &cum_sum, sizeof(int), cudaMemcpyHostToDevice);
vector_pool_kernel_stack<<<blocks, threads>>>(
support_xyz, support_features, xyz_batch_cnt,
new_xyz, new_features, new_local_xyz, new_xyz_batch_cnt,
num_grid_x, num_grid_y, num_grid_z, max_neighbour_distance,
batch_size, M, num_c_in, num_c_out,
num_c_each_grid, num_total_grids, point_cnt_of_grid, grouped_idxs,
use_xyz, grid_size_x, grid_size_y, grid_size_z, p_cum_sum, num_max_sum_points,
nsample, neighbor_type, pooling_type
);
cudaMemcpy(&cum_sum, p_cum_sum, sizeof(int), cudaMemcpyDeviceToHost);
// cudaDeviceSynchronize(); // for using printf in kernel function
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
return cum_sum;
}
__global__ void vector_pool_grad_kernel_stack(const float *grad_new_features,
const int *point_cnt_of_grid, const int *grouped_idxs,
float *grad_support_features, int N, int M, int num_c_out, int num_c_in,
int num_c_each_grid, int num_total_grids, int num_max_sum_points){
// grad_new_features: (M1 + M2 ..., C_out)
// point_cnt_of_grid: (M1 + M2 ..., num_total_grids)
// grouped_idxs: (num_max_sum_points, 3) [idx of support_xyz, idx of new_xyz, idx of grid_idx in new_xyz]
// grad_support_features: (N1 + N2 ..., C_in)
int channel_idx = blockIdx.y;
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index >= num_max_sum_points || channel_idx >= num_c_in) return;
int idx_of_support_xyz = grouped_idxs[index * 3 + 0];
int idx_of_new_xyz = grouped_idxs[index * 3 + 1];
int idx_of_grid_idx = grouped_idxs[index * 3 + 2];
int num_total_pts = point_cnt_of_grid[idx_of_new_xyz * num_total_grids + idx_of_grid_idx];
grad_support_features += idx_of_support_xyz * num_c_in + channel_idx;
grad_new_features += idx_of_new_xyz * num_c_out + idx_of_grid_idx * num_c_each_grid;
int channel_idx_of_cin = channel_idx % num_c_each_grid;
float cur_grad = 1 / fmaxf(float(num_total_pts), 1.0);
atomicAdd(grad_support_features, grad_new_features[channel_idx_of_cin] * cur_grad);
}
void vector_pool_grad_kernel_launcher_stack(
const float *grad_new_features, const int *point_cnt_of_grid, const int *grouped_idxs,
float *grad_support_features, int N, int M, int num_c_out, int num_c_in, int num_total_grids,
int num_max_sum_points){
// grad_new_features: (M1 + M2 ..., C_out)
// point_cnt_of_grid: (M1 + M2 ..., num_total_grids)
// grouped_idxs: (num_max_sum_points, 3) [idx of support_xyz, idx of new_xyz, idx of grid_idx in new_xyz]
// grad_support_features: (N1 + N2 ..., C_in)
int num_c_each_grid = num_c_out / num_total_grids;
cudaError_t err;
dim3 blocks(DIVUP(num_max_sum_points, THREADS_PER_BLOCK), num_c_in); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
vector_pool_grad_kernel_stack<<<blocks, threads>>>(
grad_new_features, point_cnt_of_grid, grouped_idxs, grad_support_features,
N, M, num_c_out, num_c_in, num_c_each_grid, num_total_grids, num_max_sum_points
);
// cudaDeviceSynchronize(); // for using printf in kernel function
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
\ No newline at end of file
/*
Vector-pool aggregation based local feature aggregation for point cloud.
PV-RCNN++: Point-Voxel Feature Set Abstraction With Local Vector Representation for 3D Object Detection
https://arxiv.org/abs/2102.00463
Written by Shaoshuai Shi
All Rights Reserved 2020.
*/
#ifndef _STACK_VECTOR_POOL_GPU_H
#define _STACK_VECTOR_POOL_GPU_H
#include <torch/serialize/tensor.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>
int query_stacked_local_neighbor_idxs_kernel_launcher_stack(
const float *support_xyz, const int *xyz_batch_cnt, const float *new_xyz, const int *new_xyz_batch_cnt,
int *stack_neighbor_idxs, int *start_len, int *cumsum, int avg_length_of_neighbor_idxs,
float max_neighbour_distance, int batch_size, int M, int nsample, int neighbor_type);
int query_stacked_local_neighbor_idxs_wrapper_stack(at::Tensor support_xyz_tensor, at::Tensor xyz_batch_cnt_tensor,
at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor,
at::Tensor stack_neighbor_idxs_tensor, at::Tensor start_len_tensor, at::Tensor cumsum_tensor,
int avg_length_of_neighbor_idxs, float max_neighbour_distance, int nsample, int neighbor_type);
int query_three_nn_by_stacked_local_idxs_kernel_launcher_stack(
const float *support_xyz, const float *new_xyz, const float *new_xyz_grid_centers,
int *new_xyz_grid_idxs, float *new_xyz_grid_dist2,
const int *stack_neighbor_idxs, const int *start_len,
int M, int num_total_grids);
int query_three_nn_by_stacked_local_idxs_wrapper_stack(at::Tensor support_xyz_tensor,
at::Tensor new_xyz_tensor, at::Tensor new_xyz_grid_centers_tensor,
at::Tensor new_xyz_grid_idxs_tensor, at::Tensor new_xyz_grid_dist2_tensor,
at::Tensor stack_neighbor_idxs_tensor, at::Tensor start_len_tensor,
int M, int num_total_grids);
int vector_pool_wrapper_stack(at::Tensor support_xyz_tensor, at::Tensor xyz_batch_cnt_tensor,
at::Tensor support_features_tensor, at::Tensor new_xyz_tensor, at::Tensor new_xyz_batch_cnt_tensor,
at::Tensor new_features_tensor, at::Tensor new_local_xyz,
at::Tensor point_cnt_of_grid_tensor, at::Tensor grouped_idxs_tensor,
int num_grid_x, int num_grid_y, int num_grid_z, float max_neighbour_distance, int use_xyz,
int num_max_sum_points, int nsample, int neighbor_type, int pooling_type);
int vector_pool_kernel_launcher_stack(
const float *support_xyz, const float *support_features, const int *xyz_batch_cnt,
const float *new_xyz, float *new_features, float * new_local_xyz, const int *new_xyz_batch_cnt,
int *point_cnt_of_grid, int *grouped_idxs,
int num_grid_x, int num_grid_y, int num_grid_z, float max_neighbour_distance,
int batch_size, int N, int M, int num_c_in, int num_c_out, int num_total_grids, int use_xyz,
int num_max_sum_points, int nsample, int neighbor_type, int pooling_type);
int vector_pool_grad_wrapper_stack(at::Tensor grad_new_features_tensor,
at::Tensor point_cnt_of_grid_tensor, at::Tensor grouped_idxs_tensor,
at::Tensor grad_support_features_tensor);
void vector_pool_grad_kernel_launcher_stack(
const float *grad_new_features, const int *point_cnt_of_grid, const int *grouped_idxs,
float *grad_support_features, int N, int M, int num_c_out, int num_c_in, int num_total_grids,
int num_max_sum_points);
#endif
......@@ -97,6 +97,8 @@ if __name__ == '__main__':
'src/interpolate_gpu.cu',
'src/voxel_query.cpp',
'src/voxel_query_gpu.cu',
'src/vector_pool.cpp',
'src/vector_pool_gpu.cu'
],
),
make_cuda_ext(
......
CLASS_NAMES: ['Vehicle', 'Pedestrian', 'Cyclist']
DATA_CONFIG:
_BASE_CONFIG_: cfgs/dataset_configs/waymo_dataset.yaml
MODEL:
NAME: PVRCNNPlusPlus
VFE:
NAME: MeanVFE
BACKBONE_3D:
NAME: VoxelBackBone8x
MAP_TO_BEV:
NAME: HeightCompression
NUM_BEV_FEATURES: 256
BACKBONE_2D:
NAME: BaseBEVBackbone
LAYER_NUMS: [5, 5]
LAYER_STRIDES: [1, 2]
NUM_FILTERS: [128, 256]
UPSAMPLE_STRIDES: [1, 2]
NUM_UPSAMPLE_FILTERS: [256, 256]
DENSE_HEAD:
NAME: CenterHead
CLASS_AGNOSTIC: False
CLASS_NAMES_EACH_HEAD: [
[ 'Vehicle', 'Pedestrian', 'Cyclist' ]
]
SHARED_CONV_CHANNEL: 64
USE_BIAS_BEFORE_NORM: True
NUM_HM_CONV: 2
SEPARATE_HEAD_CFG:
HEAD_ORDER: [ 'center', 'center_z', 'dim', 'rot' ]
HEAD_DICT: {
'center': { 'out_channels': 2, 'num_conv': 2 },
'center_z': { 'out_channels': 1, 'num_conv': 2 },
'dim': { 'out_channels': 3, 'num_conv': 2 },
'rot': { 'out_channels': 2, 'num_conv': 2 },
}
TARGET_ASSIGNER_CONFIG:
FEATURE_MAP_STRIDE: 8
NUM_MAX_OBJS: 500
GAUSSIAN_OVERLAP: 0.1
MIN_RADIUS: 2
LOSS_CONFIG:
LOSS_WEIGHTS: {
'cls_weight': 1.0,
'loc_weight': 2.0,
'code_weights': [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ]
}
POST_PROCESSING:
SCORE_THRESH: 0.1
POST_CENTER_LIMIT_RANGE: [ -75.2, -75.2, -2, 75.2, 75.2, 4 ]
MAX_OBJ_PER_SAMPLE: 500
NMS_CONFIG:
NMS_TYPE: nms_gpu
NMS_THRESH: 0.7
NMS_PRE_MAXSIZE: 4096
NMS_POST_MAXSIZE: 500
PFE:
NAME: VoxelSetAbstraction
POINT_SOURCE: raw_points
NUM_KEYPOINTS: 4096
NUM_OUTPUT_FEATURES: 90
SAMPLE_METHOD: SPC
SPC_SAMPLING:
NUM_SECTORS: 6
SAMPLE_RADIUS_WITH_ROI: 1.6
FEATURES_SOURCE: ['bev', 'x_conv3', 'x_conv4', 'raw_points']
SA_LAYER:
raw_points:
NAME: VectorPoolAggregationModuleMSG
NUM_GROUPS: 2
LOCAL_AGGREGATION_TYPE: local_interpolation
NUM_REDUCED_CHANNELS: 2
NUM_CHANNELS_OF_LOCAL_AGGREGATION: 32
MSG_POST_MLPS: [ 32 ]
FILTER_NEIGHBOR_WITH_ROI: True
RADIUS_OF_NEIGHBOR_WITH_ROI: 2.4
GROUP_CFG_0:
NUM_LOCAL_VOXEL: [ 2, 2, 2 ]
MAX_NEIGHBOR_DISTANCE: 0.2
NEIGHBOR_NSAMPLE: -1
POST_MLPS: [ 32, 32 ]
GROUP_CFG_1:
NUM_LOCAL_VOXEL: [ 3, 3, 3 ]
MAX_NEIGHBOR_DISTANCE: 0.4
NEIGHBOR_NSAMPLE: -1
POST_MLPS: [ 32, 32 ]
x_conv3:
DOWNSAMPLE_FACTOR: 4
INPUT_CHANNELS: 64
NAME: VectorPoolAggregationModuleMSG
NUM_GROUPS: 2
LOCAL_AGGREGATION_TYPE: local_interpolation
NUM_REDUCED_CHANNELS: 32
NUM_CHANNELS_OF_LOCAL_AGGREGATION: 32
MSG_POST_MLPS: [128]
FILTER_NEIGHBOR_WITH_ROI: True
RADIUS_OF_NEIGHBOR_WITH_ROI: 4.0
GROUP_CFG_0:
NUM_LOCAL_VOXEL: [3, 3, 3]
MAX_NEIGHBOR_DISTANCE: 1.2
NEIGHBOR_NSAMPLE: -1
POST_MLPS: [64, 64]
GROUP_CFG_1:
NUM_LOCAL_VOXEL: [ 3, 3, 3 ]
MAX_NEIGHBOR_DISTANCE: 2.4
NEIGHBOR_NSAMPLE: -1
POST_MLPS: [ 64, 64 ]
x_conv4:
DOWNSAMPLE_FACTOR: 8
INPUT_CHANNELS: 64
NAME: VectorPoolAggregationModuleMSG
NUM_GROUPS: 2
LOCAL_AGGREGATION_TYPE: local_interpolation
NUM_REDUCED_CHANNELS: 32
NUM_CHANNELS_OF_LOCAL_AGGREGATION: 32
MSG_POST_MLPS: [ 128 ]
FILTER_NEIGHBOR_WITH_ROI: True
RADIUS_OF_NEIGHBOR_WITH_ROI: 6.4
GROUP_CFG_0:
NUM_LOCAL_VOXEL: [ 3, 3, 3 ]
MAX_NEIGHBOR_DISTANCE: 2.4
NEIGHBOR_NSAMPLE: -1
POST_MLPS: [ 64, 64 ]
GROUP_CFG_1:
NUM_LOCAL_VOXEL: [ 3, 3, 3 ]
MAX_NEIGHBOR_DISTANCE: 4.8
NEIGHBOR_NSAMPLE: -1
POST_MLPS: [ 64, 64 ]
POINT_HEAD:
NAME: PointHeadSimple
CLS_FC: [256, 256]
CLASS_AGNOSTIC: True
USE_POINT_FEATURES_BEFORE_FUSION: True
TARGET_CONFIG:
GT_EXTRA_WIDTH: [0.2, 0.2, 0.2]
LOSS_CONFIG:
LOSS_REG: smooth-l1
LOSS_WEIGHTS: {
'point_cls_weight': 1.0,
}
ROI_HEAD:
NAME: PVRCNNHead
CLASS_AGNOSTIC: True
SHARED_FC: [256, 256]
CLS_FC: [256, 256]
REG_FC: [256, 256]
DP_RATIO: 0.3
NMS_CONFIG:
TRAIN:
NMS_TYPE: nms_gpu
MULTI_CLASSES_NMS: False
NMS_PRE_MAXSIZE: 9000
NMS_POST_MAXSIZE: 512
NMS_THRESH: 0.8
TEST:
NMS_TYPE: nms_gpu
MULTI_CLASSES_NMS: False
NMS_PRE_MAXSIZE: 1024
NMS_POST_MAXSIZE: 100
NMS_THRESH: 0.7
SCORE_THRESH: 0.1
# NMS_PRE_MAXSIZE: 4096
# NMS_POST_MAXSIZE: 500
# NMS_THRESH: 0.85
ROI_GRID_POOL:
GRID_SIZE: 6
NAME: VectorPoolAggregationModuleMSG
NUM_GROUPS: 2
LOCAL_AGGREGATION_TYPE: voxel_random_choice
NUM_REDUCED_CHANNELS: 30
NUM_CHANNELS_OF_LOCAL_AGGREGATION: 32
MSG_POST_MLPS: [ 128 ]
GROUP_CFG_0:
NUM_LOCAL_VOXEL: [ 3, 3, 3 ]
MAX_NEIGHBOR_DISTANCE: 0.8
NEIGHBOR_NSAMPLE: 32
POST_MLPS: [ 64, 64 ]
GROUP_CFG_1:
NUM_LOCAL_VOXEL: [ 3, 3, 3 ]
MAX_NEIGHBOR_DISTANCE: 1.6
NEIGHBOR_NSAMPLE: 32
POST_MLPS: [ 64, 64 ]
TARGET_CONFIG:
BOX_CODER: ResidualCoder
ROI_PER_IMAGE: 128
FG_RATIO: 0.5
SAMPLE_ROI_BY_EACH_CLASS: True
CLS_SCORE_TYPE: roi_iou
CLS_FG_THRESH: 0.75
CLS_BG_THRESH: 0.25
CLS_BG_THRESH_LO: 0.1
HARD_BG_RATIO: 0.8
REG_FG_THRESH: 0.55
LOSS_CONFIG:
CLS_LOSS: BinaryCrossEntropy
REG_LOSS: smooth-l1
CORNER_LOSS_REGULARIZATION: True
LOSS_WEIGHTS: {
'rcnn_cls_weight': 1.0,
'rcnn_reg_weight': 1.0,
'rcnn_corner_weight': 1.0,
'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
}
POST_PROCESSING:
RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
SCORE_THRESH: 0.1
OUTPUT_RAW_SCORE: False
EVAL_METRIC: waymo
NMS_CONFIG:
MULTI_CLASSES_NMS: False
NMS_TYPE: nms_gpu
NMS_THRESH: 0.7
NMS_PRE_MAXSIZE: 4096
NMS_POST_MAXSIZE: 500
OPTIMIZATION:
BATCH_SIZE_PER_GPU: 2
NUM_EPOCHS: 30
OPTIMIZER: adam_onecycle
LR: 0.01
WEIGHT_DECAY: 0.001
MOMENTUM: 0.9
MOMS: [0.95, 0.85]
PCT_START: 0.4
DIV_FACTOR: 10
DECAY_STEP_LIST: [35, 45]
LR_DECAY: 0.1
LR_CLIP: 0.0000001
LR_WARMUP: False
WARMUP_EPOCH: 1
GRAD_NORM_CLIP: 10
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment