Commit a991105c authored by Shaoshuai Shi, committed by GitHub

Release the codes of PV-RCNN++, update OpenPCDet to v0.5.2

parents 1483517a b6fbf07f
......@@ -4,9 +4,11 @@
`OpenPCDet` is a clear, simple, self-contained open source project for LiDAR-based 3D object detection.
It is also the official code release of [`[PointRCNN]`](https://arxiv.org/abs/1812.04244), [`[Part-A2-Net]`](https://arxiv.org/abs/1907.03670), [`[PV-RCNN]`](https://arxiv.org/abs/1912.13192) and [`[Voxel R-CNN]`](https://arxiv.org/abs/2012.15712).
It is also the official code release of [`[PointRCNN]`](https://arxiv.org/abs/1812.04244), [`[Part-A2-Net]`](https://arxiv.org/abs/1907.03670), [`[PV-RCNN]`](https://arxiv.org/abs/1912.13192), [`[Voxel R-CNN]`](https://arxiv.org/abs/2012.15712) and [`[PV-RCNN++]`](https://arxiv.org/abs/2102.00463).
**NEW**: `OpenPCDet` has been updated to `v0.5.0` (Dec. 2021).
**Highlights**:
* `OpenPCDet` has been updated to `v0.5.2` (Jan. 2022).
* The code of PV-RCNN++ is now supported.
## Overview
- [Changelog](#changelog)
......@@ -19,6 +21,12 @@ It is also the official code release of [`[PointRCNN]`](https://arxiv.org/abs/18
## Changelog
[2022-01-05] **NEW:** Update `OpenPCDet` to v0.5.2:
* The code of [PV-RCNN++](https://arxiv.org/abs/2102.00463) has been released to this repo, with higher performance, faster training/inference speed and less memory consumption than PV-RCNN.
* Add performance of several models trained with the full training set of [Waymo Open Dataset](#waymo-open-dataset-baselines).
* Support Lyft dataset, see the pull request [here](https://github.com/open-mmlab/OpenPCDet/pull/720).
[2021-12-09] **NEW:** Update `OpenPCDet` to v0.5.1:
* Add PointPillar related baseline configs/results on [Waymo Open Dataset](#waymo-open-dataset-baselines).
* Support Pandaset dataloader, see the pull request [here](https://github.com/open-mmlab/OpenPCDet/pull/396).
......@@ -108,7 +116,7 @@ Contributions are also welcomed.
### KITTI 3D Object Detection Baselines
Selected supported methods are shown in the table below. The results are the 3D detection performance at moderate difficulty on the *val* set of the KITTI dataset.
* All models are trained with 8 GTX 1080Ti GPUs and are available for download.
* All LiDAR-based models are trained with 8 GTX 1080Ti GPUs and are available for download.
* The training time is measured with 8 TITAN XP GPUs and PyTorch 1.5.
| | training time | Car@R11 | Pedestrian@R11 | Cyclist@R11 | download |
......@@ -129,7 +137,7 @@ Selected supported methods are shown in the below table. The results are the 3D
We provide the [`DATA_CONFIG.SAMPLED_INTERVAL`](tools/cfgs/dataset_configs/waymo_dataset.yaml) setting for the Waymo Open Dataset (WOD) to subsample the frames used for training and evaluation,
so you can still experiment with WOD under limited GPU resources by adjusting `DATA_CONFIG.SAMPLED_INTERVAL`.
By default, all models are trained with **20% data (~32k frames)** of all the training samples on 8 GTX 1080Ti GPUs, and the results of each cell here are mAP/mAPH calculated by the official Waymo evaluation metrics on the **whole** validation set (version 1.2).
By default, all models are trained with **single-frame input** on **20% of the training samples (~32k frames)** using 8 GTX 1080Ti GPUs, and each cell reports mAP/mAPH calculated by the official Waymo evaluation metrics on the **whole** validation set (version 1.2).
| Performance@(train with 20\% Data) | Vec_L1 | Vec_L2 | Ped_L1 | Ped_L2 | Cyc_L1 | Cyc_L2 |
|---------------------------------------------|----------:|:-------:|:-------:|:-------:|:-------:|:-------:|
......@@ -141,6 +149,22 @@ By default, all models are trained with **20% data (~32k frames)** of all the tr
| [Part-A2-Anchor](tools/cfgs/waymo_models/PartA2.yaml) | 74.66/74.12 |65.82/65.32 |71.71/62.24 |62.46/54.06 |66.53/65.18 |64.05/62.75 |
| [PV-RCNN (AnchorHead)](tools/cfgs/waymo_models/pv_rcnn.yaml) | 75.41/74.74 |67.44/66.80 |71.98/61.24 |63.70/53.95 |65.88/64.25 |63.39/61.82 |
| [PV-RCNN (CenterHead)](tools/cfgs/waymo_models/pv_rcnn_with_centerhead_rpn.yaml) | 75.95/75.43 |68.02/67.54 |75.94/69.40 |67.66/61.62 |70.18/68.98 |67.73/66.57|
| [PV-RCNN++](tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml) | 77.82/77.32| 69.07/68.62| 77.99/71.36| 69.92/63.74| 71.80/70.71| 69.31/68.26|
| [PV-RCNN++ (ResNet)](tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet.yaml) |77.61/77.14| 69.18/68.75| 79.42/73.31| 70.88/65.21| 72.50/71.39| 69.84/68.77|
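As a rough illustration of what `DATA_CONFIG.SAMPLED_INTERVAL` controls, the sketch below keeps every k-th entry of a frame list. The helper name `subsample_frames` and the frame count are hypothetical, and it assumes plain stride-based subsampling; the actual dataset loader may differ in detail. Under that assumption, an interval of 5 keeps one fifth of the frames, which matches the 20% setting used in the table above.

```python
def subsample_frames(frames, sampled_interval):
    """Keep every `sampled_interval`-th frame.

    Hypothetical helper for illustration; not the OpenPCDet loader.
    """
    if sampled_interval < 1:
        raise ValueError("sampled_interval must be >= 1")
    return frames[::sampled_interval]


frames = list(range(160))            # stand-in for per-frame info records
subset = subsample_frames(frames, 5)
print(len(subset), len(frames))      # number of frames kept vs. total
```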
Here we also provide the performance of several models trained on the full training set (refer to the paper of [PV-RCNN++](https://arxiv.org/abs/2102.00463)):
| Performance@(train with 100\% Data) | Vec_L1 | Vec_L2 | Ped_L1 | Ped_L2 | Cyc_L1 | Cyc_L2 |
|---------------------------------------------|----------:|:-------:|:-------:|:-------:|:-------:|:-------:|
| [SECOND](tools/cfgs/waymo_models/second.yaml) | 72.27/71.69 | 63.85/63.33 | 68.70/58.18 | 60.72/51.31 | 60.62/59.28 | 58.34/57.05|
| [Part-A2-Anchor](tools/cfgs/waymo_models/PartA2.yaml) | 77.05/76.51 | 68.47/67.97 | 75.24/66.87 | 66.18/58.62 | 68.60/67.36 | 66.13/64.93 |
| [PV-RCNN (CenterHead)](tools/cfgs/waymo_models/pv_rcnn_with_centerhead_rpn.yaml) | 78.00/77.50 | 69.43/68.98 | 79.21/73.03 | 70.42/64.72 | 71.46/70.27 | 68.95/67.79|
| [PV-RCNN++](tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml) | 79.10/78.63 | 70.34/69.91 | 80.62/74.62 | 71.86/66.30 | 73.49/72.38 | 70.70/69.62 |
| [PV-RCNN++ (ResNet)](tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet.yaml) |79.25/78.78 | 70.61/70.18 | 81.83/76.28 | 73.17/68.00 | 73.72/72.66 | 71.21/70.19|
We cannot provide the pretrained models above due to the [Waymo Dataset License Agreement](https://waymo.com/open/terms/),
but you can achieve similar performance by training with the default configs.
......
......@@ -86,7 +86,7 @@ OpenPCDet
```shell script
pip3 install --upgrade pip
# tf 2.0.0
pip3 install waymo-open-dataset-tf-2-0-0==1.2.0 --user
pip3 install waymo-open-dataset-tf-2-5-0 --user
```
* Extract point cloud data from tfrecord and generate data infos by running the following command (it takes several hours,
......
import math
import numpy as np
import torch
import torch.nn as nn
......@@ -40,6 +42,85 @@ def bilinear_interpolate_torch(im, x, y):
return ans
def sample_points_with_roi(rois, points, sample_radius_with_roi, num_max_points_of_part=200000):
"""
Args:
rois: (M, 7 + C)
points: (N, 3)
sample_radius_with_roi:
num_max_points_of_part:
Returns:
sampled_points: (N_out, 3)
"""
if points.shape[0] < num_max_points_of_part:
distance = (points[:, None, :] - rois[None, :, 0:3]).norm(dim=-1)
min_dis, min_dis_roi_idx = distance.min(dim=-1)
roi_max_dim = (rois[min_dis_roi_idx, 3:6] / 2).norm(dim=-1)
point_mask = min_dis < roi_max_dim + sample_radius_with_roi
else:
start_idx = 0
point_mask_list = []
while start_idx < points.shape[0]:
distance = (points[start_idx:start_idx + num_max_points_of_part, None, :] - rois[None, :, 0:3]).norm(dim=-1)
min_dis, min_dis_roi_idx = distance.min(dim=-1)
roi_max_dim = (rois[min_dis_roi_idx, 3:6] / 2).norm(dim=-1)
cur_point_mask = min_dis < roi_max_dim + sample_radius_with_roi
point_mask_list.append(cur_point_mask)
start_idx += num_max_points_of_part
point_mask = torch.cat(point_mask_list, dim=0)
sampled_points = points[:1] if point_mask.sum() == 0 else points[point_mask, :]
return sampled_points, point_mask
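The keep/drop rule in `sample_points_with_roi` can be read as: a point survives when its distance to the nearest RoI centre is smaller than that RoI's half-diagonal plus `sample_radius_with_roi`. The pure-Python sketch below restates that rule without tensors; the name `roi_point_mask` is hypothetical, and the chunking by `num_max_points_of_part` and the `points[:1]` fallback for an empty mask are left out.

```python
import math


def roi_point_mask(points, rois, sample_radius):
    """Pure-Python sketch of the RoI proximity test (illustrative only).

    points: list of (x, y, z)
    rois:   list of (cx, cy, cz, dx, dy, dz, ...) boxes
    Returns one boolean per point.
    """
    mask = []
    for p in points:
        # Distance from this point to every RoI centre.
        dists = [math.dist(p, roi[0:3]) for roi in rois]
        nearest = min(range(len(rois)), key=dists.__getitem__)
        # Half of the box diagonal of the nearest RoI.
        half_diag = math.sqrt(sum((d / 2) ** 2 for d in rois[nearest][3:6]))
        mask.append(dists[nearest] < half_diag + sample_radius)
    return mask


rois = [(0.0, 0.0, 0.0, 4.0, 2.0, 2.0, 0.0)]
points = [(3.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
mask = roi_point_mask(points, rois, sample_radius=1.0)
```

Here the half-diagonal is sqrt(6), roughly 2.45, so with a radius of 1.0 the first point (distance 3) is kept and the second (distance 10) is dropped.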
def sector_fps(points, num_sampled_points, num_sectors):
"""
Args:
points: (N, 3)
num_sampled_points: int
num_sectors: int
Returns:
sampled_points: (N_out, 3)
"""
sector_size = np.pi * 2 / num_sectors
point_angles = torch.atan2(points[:, 1], points[:, 0]) + np.pi
sector_idx = (point_angles / sector_size).floor().clamp(min=0, max=num_sectors)
xyz_points_list = []
xyz_batch_cnt = []
num_sampled_points_list = []
for k in range(num_sectors):
mask = (sector_idx == k)
cur_num_points = mask.sum().item()
if cur_num_points > 0:
xyz_points_list.append(points[mask])
xyz_batch_cnt.append(cur_num_points)
ratio = cur_num_points / points.shape[0]
num_sampled_points_list.append(
min(cur_num_points, math.ceil(ratio * num_sampled_points))
)
if len(xyz_batch_cnt) == 0:
xyz_points_list.append(points)
xyz_batch_cnt.append(len(points))
num_sampled_points_list.append(num_sampled_points)
print(f'Warning: empty sector points detected in SectorFPS: points.shape={points.shape}')
xyz = torch.cat(xyz_points_list, dim=0)
xyz_batch_cnt = torch.tensor(xyz_batch_cnt, device=points.device).int()
sampled_points_batch_cnt = torch.tensor(num_sampled_points_list, device=points.device).int()
sampled_pt_idxs = pointnet2_stack_utils.stack_farthest_point_sample(
xyz.contiguous(), xyz_batch_cnt, sampled_points_batch_cnt
).long()
sampled_points = xyz[sampled_pt_idxs]
return sampled_points
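The per-sector bookkeeping in `sector_fps` can be sketched in pure Python as follows. The function name `sector_quotas` is hypothetical and the farthest point sampling call itself is omitted. One deliberate difference, stated as an assumption of this sketch: the sector index is clamped to `num_sectors - 1`, so a point whose shifted angle is exactly 2π lands in the last sector.

```python
import math


def sector_quotas(points_xy, num_sampled_points, num_sectors):
    """Count points per angular sector and derive a per-sector sampling quota.

    Illustrative sketch only; `points_xy` is a list of (x, y) pairs.
    """
    sector_size = 2 * math.pi / num_sectors
    counts = [0] * num_sectors
    for x, y in points_xy:
        angle = math.atan2(y, x) + math.pi        # shifted into [0, 2*pi]
        idx = int(angle // sector_size)
        idx = min(max(idx, 0), num_sectors - 1)   # assumption: clamp into the last sector
        counts[idx] += 1
    total = len(points_xy)
    quotas = [
        min(c, math.ceil(c / total * num_sampled_points)) if c > 0 else 0
        for c in counts
    ]
    return counts, quotas


pts = [(1, 1), (2, 2), (3, 3), (-1, -1), (-1, 1), (1, -1)]
counts, quotas = sector_quotas(pts, num_sampled_points=3, num_sectors=4)
```

Because each quota is rounded up with `ceil`, the quotas can add up to more than `num_sampled_points`, so the number of returned keypoints is not guaranteed to equal the requested number.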
class VoxelSetAbstraction(nn.Module):
def __init__(self, model_cfg, voxel_size, point_cloud_range, num_bev_features=None,
num_rawpoint_features=None, **kwargs):
......@@ -58,38 +139,31 @@ class VoxelSetAbstraction(nn.Module):
if src_name in ['bev', 'raw_points']:
continue
self.downsample_times_map[src_name] = SA_cfg[src_name].DOWNSAMPLE_FACTOR
mlps = SA_cfg[src_name].MLPS
for k in range(len(mlps)):
mlps[k] = [mlps[k][0]] + mlps[k]
cur_layer = pointnet2_stack_modules.StackSAModuleMSG(
radii=SA_cfg[src_name].POOL_RADIUS,
nsamples=SA_cfg[src_name].NSAMPLE,
mlps=mlps,
use_xyz=True,
pool_method='max_pool',
if SA_cfg[src_name].get('INPUT_CHANNELS', None) is None:
input_channels = SA_cfg[src_name].MLPS[0][0] \
if isinstance(SA_cfg[src_name].MLPS[0], list) else SA_cfg[src_name].MLPS[0]
else:
input_channels = SA_cfg[src_name]['INPUT_CHANNELS']
cur_layer, cur_num_c_out = pointnet2_stack_modules.build_local_aggregation_module(
input_channels=input_channels, config=SA_cfg[src_name]
)
self.SA_layers.append(cur_layer)
self.SA_layer_names.append(src_name)
c_in += sum([x[-1] for x in mlps])
c_in += cur_num_c_out
if 'bev' in self.model_cfg.FEATURES_SOURCE:
c_bev = num_bev_features
c_in += c_bev
if 'raw_points' in self.model_cfg.FEATURES_SOURCE:
mlps = SA_cfg['raw_points'].MLPS
for k in range(len(mlps)):
mlps[k] = [num_rawpoint_features - 3] + mlps[k]
self.SA_rawpoints = pointnet2_stack_modules.StackSAModuleMSG(
radii=SA_cfg['raw_points'].POOL_RADIUS,
nsamples=SA_cfg['raw_points'].NSAMPLE,
mlps=mlps,
use_xyz=True,
pool_method='max_pool'
self.SA_rawpoints, cur_num_c_out = pointnet2_stack_modules.build_local_aggregation_module(
input_channels=num_rawpoint_features - 3, config=SA_cfg['raw_points']
)
c_in += sum([x[-1] for x in mlps])
c_in += cur_num_c_out
self.vsa_point_feature_fusion = nn.Sequential(
nn.Linear(c_in, self.model_cfg.NUM_OUTPUT_FEATURES, bias=False),
......@@ -100,23 +174,64 @@ class VoxelSetAbstraction(nn.Module):
self.num_point_features_before_fusion = c_in
def interpolate_from_bev_features(self, keypoints, bev_features, batch_size, bev_stride):
x_idxs = (keypoints[:, :, 0] - self.point_cloud_range[0]) / self.voxel_size[0]
y_idxs = (keypoints[:, :, 1] - self.point_cloud_range[1]) / self.voxel_size[1]
"""
Args:
keypoints: (N1 + N2 + ..., 4)
bev_features: (B, C, H, W)
batch_size:
bev_stride:
Returns:
point_bev_features: (N1 + N2 + ..., C)
"""
x_idxs = (keypoints[:, 1] - self.point_cloud_range[0]) / self.voxel_size[0]
y_idxs = (keypoints[:, 2] - self.point_cloud_range[1]) / self.voxel_size[1]
x_idxs = x_idxs / bev_stride
y_idxs = y_idxs / bev_stride
point_bev_features_list = []
for k in range(batch_size):
cur_x_idxs = x_idxs[k]
cur_y_idxs = y_idxs[k]
bs_mask = (keypoints[:, 0] == k)
cur_x_idxs = x_idxs[bs_mask]
cur_y_idxs = y_idxs[bs_mask]
cur_bev_features = bev_features[k].permute(1, 2, 0) # (H, W, C)
point_bev_features = bilinear_interpolate_torch(cur_bev_features, cur_x_idxs, cur_y_idxs)
point_bev_features_list.append(point_bev_features.unsqueeze(dim=0))
point_bev_features_list.append(point_bev_features)
point_bev_features = torch.cat(point_bev_features_list, dim=0) # (B, N, C0)
point_bev_features = torch.cat(point_bev_features_list, dim=0) # (N1 + N2 + ..., C)
return point_bev_features
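To make the index arithmetic in `interpolate_from_bev_features` concrete, the sketch below converts a metric coordinate into a fractional BEV index and then bilinearly interpolates a small scalar map. Both helper names are hypothetical, the map holds scalars rather than C-channel features, and the border handling (clamping the query into the map) is a choice of this sketch, not necessarily that of `bilinear_interpolate_torch`.

```python
import math


def bev_index(coord, range_min, voxel_size, bev_stride):
    """Metric coordinate -> fractional index on the downsampled BEV map."""
    return (coord - range_min) / voxel_size / bev_stride


def bilinear_interpolate(grid, x, y):
    """Bilinear lookup on grid[row][col]; x indexes columns, y indexes rows."""
    h, w = len(grid), len(grid[0])
    x = min(max(x, 0.0), w - 1)                  # clamp the query into the map
    y = min(max(y, 0.0), h - 1)
    x0 = min(int(math.floor(x)), max(w - 2, 0))
    y0 = min(int(math.floor(y)), max(h - 2, 0))
    x1 = min(x0 + 1, w - 1)
    y1 = min(y0 + 1, h - 1)
    tx, ty = x - x0, y - y0
    return ((1 - tx) * (1 - ty) * grid[y0][x0] + tx * (1 - ty) * grid[y0][x1]
            + (1 - tx) * ty * grid[y1][x0] + tx * ty * grid[y1][x1])


x_idx = bev_index(4.0, range_min=-4.0, voxel_size=0.5, bev_stride=8)
value = bilinear_interpolate([[0.0, 10.0], [20.0, 30.0]], 0.5, 0.5)
```

With these inputs, 8 m of offset at a 0.5 m voxel size and stride 8 gives index 2.0, and the centre of the 2x2 map interpolates to the mean of its four cells.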
def sectorized_proposal_centric_sampling(self, roi_boxes, points):
"""
Args:
roi_boxes: (M, 7 + C)
points: (N, 3)
Returns:
sampled_points: (N_out, 3)
"""
sampled_points, _ = sample_points_with_roi(
rois=roi_boxes, points=points,
sample_radius_with_roi=self.model_cfg.SPC_SAMPLING.SAMPLE_RADIUS_WITH_ROI,
num_max_points_of_part=self.model_cfg.SPC_SAMPLING.get('NUM_POINTS_OF_EACH_SAMPLE_PART', 200000)
)
sampled_points = sector_fps(
points=sampled_points, num_sampled_points=self.model_cfg.NUM_KEYPOINTS,
num_sectors=self.model_cfg.SPC_SAMPLING.NUM_SECTORS
)
return sampled_points
def get_sampled_points(self, batch_dict):
"""
Args:
batch_dict:
Returns:
keypoints: (N1 + N2 + ..., 4), where 4 indicates [bs_idx, x, y, z]
"""
batch_size = batch_dict['batch_size']
if self.model_cfg.POINT_SOURCE == 'raw_points':
src_points = batch_dict['points'][:, 1:4]
......@@ -136,7 +251,7 @@ class VoxelSetAbstraction(nn.Module):
bs_mask = (batch_indices == bs_idx)
sampled_points = src_points[bs_mask].unsqueeze(dim=0) # (1, N, 3)
if self.model_cfg.SAMPLE_METHOD == 'FPS':
cur_pt_idxs = pointnet2_stack_utils.furthest_point_sample(
cur_pt_idxs = pointnet2_stack_utils.farthest_point_sample(
sampled_points[:, :, 0:3].contiguous(), self.model_cfg.NUM_KEYPOINTS
).long()
......@@ -147,16 +262,75 @@ class VoxelSetAbstraction(nn.Module):
keypoints = sampled_points[0][cur_pt_idxs[0]].unsqueeze(dim=0)
elif self.model_cfg.SAMPLE_METHOD == 'FastFPS':
raise NotImplementedError
elif self.model_cfg.SAMPLE_METHOD == 'SPC':
cur_keypoints = self.sectorized_proposal_centric_sampling(
roi_boxes=batch_dict['rois'][bs_idx], points=sampled_points[0]
)
bs_idxs = cur_keypoints.new_ones(cur_keypoints.shape[0]) * bs_idx
keypoints = torch.cat((bs_idxs[:, None], cur_keypoints), dim=1)
else:
raise NotImplementedError
keypoints_list.append(keypoints)
keypoints = torch.cat(keypoints_list, dim=0) # (B, M, 3)
keypoints = torch.cat(keypoints_list, dim=0) # (B, M, 3) or (N1 + N2 + ..., 4)
if len(keypoints.shape) == 3:
batch_idx = torch.arange(batch_size, device=keypoints.device).view(-1, 1).repeat(1, keypoints.shape[1]).view(-1, 1)
keypoints = torch.cat((batch_idx.float(), keypoints.view(-1, 3)), dim=1)
return keypoints
@staticmethod
def aggregate_keypoint_features_from_one_source(
batch_size, aggregate_func, xyz, xyz_features, xyz_bs_idxs, new_xyz, new_xyz_batch_cnt,
filter_neighbors_with_roi=False, radius_of_neighbor=None, num_max_points_of_part=200000, rois=None
):
"""
Args:
aggregate_func:
xyz: (N, 3)
xyz_features: (N, C)
xyz_bs_idxs: (N)
new_xyz: (M, 3)
new_xyz_batch_cnt: (batch_size), [N1, N2, ...]
filter_neighbors_with_roi: True/False
radius_of_neighbor: float
num_max_points_of_part: int
rois: (batch_size, num_rois, 7 + C)
Returns:
"""
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
if filter_neighbors_with_roi:
point_features = torch.cat((xyz, xyz_features), dim=-1) if xyz_features is not None else xyz
point_features_list = []
for bs_idx in range(batch_size):
bs_mask = (xyz_bs_idxs == bs_idx)
_, valid_mask = sample_points_with_roi(
rois=rois[bs_idx], points=xyz[bs_mask],
sample_radius_with_roi=radius_of_neighbor, num_max_points_of_part=num_max_points_of_part,
)
point_features_list.append(point_features[bs_mask][valid_mask])
xyz_batch_cnt[bs_idx] = valid_mask.sum()
valid_point_features = torch.cat(point_features_list, dim=0)
xyz = valid_point_features[:, 0:3]
xyz_features = valid_point_features[:, 3:] if xyz_features is not None else None
else:
for bs_idx in range(batch_size):
xyz_batch_cnt[bs_idx] = (xyz_bs_idxs == bs_idx).sum()
pooled_points, pooled_features = aggregate_func(
xyz=xyz.contiguous(),
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=new_xyz,
new_xyz_batch_cnt=new_xyz_batch_cnt,
features=xyz_features.contiguous(),
)
return pooled_features
def forward(self, batch_dict):
"""
Args:
......@@ -185,56 +359,53 @@ class VoxelSetAbstraction(nn.Module):
)
point_features_list.append(point_bev_features)
batch_size, num_keypoints, _ = keypoints.shape
new_xyz = keypoints.view(-1, 3)
new_xyz_batch_cnt = new_xyz.new_zeros(batch_size).int().fill_(num_keypoints)
batch_size = batch_dict['batch_size']
new_xyz = keypoints[:, 1:4].contiguous()
new_xyz_batch_cnt = new_xyz.new_zeros(batch_size).int()
for k in range(batch_size):
new_xyz_batch_cnt[k] = (keypoints[:, 0] == k).sum()
if 'raw_points' in self.model_cfg.FEATURES_SOURCE:
raw_points = batch_dict['points']
xyz = raw_points[:, 1:4]
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for bs_idx in range(batch_size):
xyz_batch_cnt[bs_idx] = (raw_points[:, 0] == bs_idx).sum()
point_features = raw_points[:, 4:].contiguous() if raw_points.shape[1] > 4 else None
pooled_points, pooled_features = self.SA_rawpoints(
xyz=xyz.contiguous(),
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=new_xyz,
new_xyz_batch_cnt=new_xyz_batch_cnt,
features=point_features,
pooled_features = self.aggregate_keypoint_features_from_one_source(
batch_size=batch_size, aggregate_func=self.SA_rawpoints,
xyz=raw_points[:, 1:4],
xyz_features=raw_points[:, 4:].contiguous() if raw_points.shape[1] > 4 else None,
xyz_bs_idxs=raw_points[:, 0],
new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt,
filter_neighbors_with_roi=self.model_cfg.SA_LAYER['raw_points'].get('FILTER_NEIGHBOR_WITH_ROI', False),
radius_of_neighbor=self.model_cfg.SA_LAYER['raw_points'].get('RADIUS_OF_NEIGHBOR_WITH_ROI', None),
rois=batch_dict.get('rois', None)
)
point_features_list.append(pooled_features.view(batch_size, num_keypoints, -1))
point_features_list.append(pooled_features)
for k, src_name in enumerate(self.SA_layer_names):
cur_coords = batch_dict['multi_scale_3d_features'][src_name].indices
cur_features = batch_dict['multi_scale_3d_features'][src_name].features.contiguous()
xyz = common_utils.get_voxel_centers(
cur_coords[:, 1:4],
downsample_times=self.downsample_times_map[src_name],
voxel_size=self.voxel_size,
point_cloud_range=self.point_cloud_range
cur_coords[:, 1:4], downsample_times=self.downsample_times_map[src_name],
voxel_size=self.voxel_size, point_cloud_range=self.point_cloud_range
)
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for bs_idx in range(batch_size):
xyz_batch_cnt[bs_idx] = (cur_coords[:, 0] == bs_idx).sum()
pooled_points, pooled_features = self.SA_layers[k](
xyz=xyz.contiguous(),
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=new_xyz,
new_xyz_batch_cnt=new_xyz_batch_cnt,
features=batch_dict['multi_scale_3d_features'][src_name].features.contiguous(),
pooled_features = self.aggregate_keypoint_features_from_one_source(
batch_size=batch_size, aggregate_func=self.SA_layers[k],
xyz=xyz.contiguous(), xyz_features=cur_features, xyz_bs_idxs=cur_coords[:, 0],
new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt,
filter_neighbors_with_roi=self.model_cfg.SA_LAYER[src_name].get('FILTER_NEIGHBOR_WITH_ROI', False),
radius_of_neighbor=self.model_cfg.SA_LAYER[src_name].get('RADIUS_OF_NEIGHBOR_WITH_ROI', None),
rois=batch_dict.get('rois', None)
)
point_features_list.append(pooled_features.view(batch_size, num_keypoints, -1))
point_features = torch.cat(point_features_list, dim=2)
point_features_list.append(pooled_features)
batch_idx = torch.arange(batch_size, device=keypoints.device).view(-1, 1).repeat(1, keypoints.shape[1]).view(-1)
point_coords = torch.cat((batch_idx.view(-1, 1).float(), keypoints.view(-1, 3)), dim=1)
point_features = torch.cat(point_features_list, dim=-1)
batch_dict['point_features_before_fusion'] = point_features.view(-1, point_features.shape[-1])
point_features = self.vsa_point_feature_fusion(point_features.view(-1, point_features.shape[-1]))
batch_dict['point_features'] = point_features # (BxN, C)
batch_dict['point_coords'] = point_coords # (BxN, 4)
batch_dict['point_coords'] = keypoints # (BxN, 4)
return batch_dict
......@@ -174,7 +174,7 @@ class PointNet2Backbone(nn.Module):
else:
last_num_points = self.num_points_each_layer[i - 1]
cur_xyz = l_xyz[-1][k * last_num_points: (k + 1) * last_num_points]
cur_pt_idxs = pointnet2_utils_stack.furthest_point_sample(
cur_pt_idxs = pointnet2_utils_stack.farthest_point_sample(
cur_xyz[None, :, :].contiguous(), self.num_points_each_layer[i]
).long()[0]
if cur_xyz.shape[0] < self.num_points_each_layer[i]:
......
......@@ -8,6 +8,7 @@ from .second_net_iou import SECONDNetIoU
from .caddn import CaDDN
from .voxel_rcnn import VoxelRCNN
from .centerpoint import CenterPoint
from .pv_rcnn_plusplus import PVRCNNPlusPlus
__all__ = {
'Detector3DTemplate': Detector3DTemplate,
......@@ -19,7 +20,8 @@ __all__ = {
'SECONDNetIoU': SECONDNetIoU,
'CaDDN': CaDDN,
'VoxelRCNN': VoxelRCNN,
'CenterPoint': CenterPoint
'CenterPoint': CenterPoint,
'PVRCNNPlusPlus': PVRCNNPlusPlus
}
......
from .detector3d_template import Detector3DTemplate
class PVRCNNPlusPlus(Detector3DTemplate):
def __init__(self, model_cfg, num_class, dataset):
super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
self.module_list = self.build_networks()
def forward(self, batch_dict):
batch_dict = self.vfe(batch_dict)
batch_dict = self.backbone_3d(batch_dict)
batch_dict = self.map_to_bev_module(batch_dict)
batch_dict = self.backbone_2d(batch_dict)
batch_dict = self.dense_head(batch_dict)
batch_dict = self.roi_head.proposal_layer(
batch_dict, nms_config=self.roi_head.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST']
)
if self.training:
targets_dict = self.roi_head.assign_targets(batch_dict)
batch_dict['rois'] = targets_dict['rois']
batch_dict['roi_labels'] = targets_dict['roi_labels']
batch_dict['roi_targets_dict'] = targets_dict
num_rois_per_scene = targets_dict['rois'].shape[1]
if 'roi_valid_num' in batch_dict:
batch_dict['roi_valid_num'] = [num_rois_per_scene for _ in range(batch_dict['batch_size'])]
batch_dict = self.pfe(batch_dict)
batch_dict = self.point_head(batch_dict)
batch_dict = self.roi_head(batch_dict)
if self.training:
loss, tb_dict, disp_dict = self.get_training_loss()
ret_dict = {
'loss': loss
}
return ret_dict, tb_dict, disp_dict
else:
pred_dicts, recall_dicts = self.post_processing(batch_dict)
return pred_dicts, recall_dicts
def get_training_loss(self):
disp_dict = {}
loss_rpn, tb_dict = self.dense_head.get_loss()
if self.point_head is not None:
loss_point, tb_dict = self.point_head.get_loss(tb_dict)
else:
loss_point = 0
loss_rcnn, tb_dict = self.roi_head.get_loss(tb_dict)
loss = loss_rpn + loss_point + loss_rcnn
return loss, tb_dict, disp_dict
......@@ -10,21 +10,12 @@ class PVRCNNHead(RoIHeadTemplate):
super().__init__(num_class=num_class, model_cfg=model_cfg)
self.model_cfg = model_cfg
mlps = self.model_cfg.ROI_GRID_POOL.MLPS
for k in range(len(mlps)):
mlps[k] = [input_channels] + mlps[k]
self.roi_grid_pool_layer = pointnet2_stack_modules.StackSAModuleMSG(
radii=self.model_cfg.ROI_GRID_POOL.POOL_RADIUS,
nsamples=self.model_cfg.ROI_GRID_POOL.NSAMPLE,
mlps=mlps,
use_xyz=True,
pool_method=self.model_cfg.ROI_GRID_POOL.POOL_METHOD,
self.roi_grid_pool_layer, num_c_out = pointnet2_stack_modules.build_local_aggregation_module(
input_channels=input_channels, config=self.model_cfg.ROI_GRID_POOL
)
GRID_SIZE = self.model_cfg.ROI_GRID_POOL.GRID_SIZE
c_out = sum([x[-1] for x in mlps])
pre_channel = GRID_SIZE * GRID_SIZE * GRID_SIZE * c_out
pre_channel = GRID_SIZE * GRID_SIZE * GRID_SIZE * num_c_out
shared_fc_list = []
for k in range(0, self.model_cfg.SHARED_FC.__len__()):
......@@ -150,6 +141,8 @@ class PVRCNNHead(RoIHeadTemplate):
batch_dict, nms_config=self.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST']
)
if self.training:
targets_dict = batch_dict.get('roi_targets_dict', None)
if targets_dict is None:
targets_dict = self.assign_targets(batch_dict)
batch_dict['rois'] = targets_dict['rois']
batch_dict['roi_labels'] = targets_dict['roi_labels']
......
......@@ -31,7 +31,7 @@ class _PointnetSAModuleBase(nn.Module):
if new_xyz is None:
new_xyz = pointnet2_utils.gather_operation(
xyz_flipped,
pointnet2_utils.furthest_point_sample(xyz, self.npoint)
pointnet2_utils.farthest_point_sample(xyz, self.npoint)
).transpose(1, 2).contiguous() if self.npoint is not None else None
for i in range(len(self.groupers)):
......
......@@ -7,11 +7,11 @@ from torch.autograd import Function, Variable
from . import pointnet2_batch_cuda as pointnet2
class FurthestPointSampling(Function):
class FarthestPointSampling(Function):
@staticmethod
def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
"""
Uses iterative furthest point sampling to select a set of npoint features that have the largest
Uses iterative farthest point sampling to select a set of npoint features that have the largest
minimum distance
:param ctx:
:param xyz: (B, N, 3) where N > npoint
......@@ -25,7 +25,7 @@ class FurthestPointSampling(Function):
output = torch.cuda.IntTensor(B, npoint)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
pointnet2.farthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
return output
@staticmethod
......@@ -33,7 +33,7 @@ class FurthestPointSampling(Function):
return None, None
furthest_point_sample = FurthestPointSampling.apply
farthest_point_sample = furthest_point_sample = FarthestPointSampling.apply
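For readers unfamiliar with the algorithm behind the CUDA wrapper, here is a minimal pure-Python reference of iterative farthest point sampling for a single point set. The name `farthest_point_sample_py` and the `start_idx` parameter are hypothetical; starting from index 0 is an assumption of this sketch, and it makes no attempt to match the kernel's tie-breaking or performance.

```python
import math


def farthest_point_sample_py(points, npoint, start_idx=0):
    """Return indices of `npoint` points chosen by greedy farthest point sampling.

    Illustrative O(N * npoint) reference, not the CUDA implementation.
    """
    n = len(points)
    if not 0 < npoint <= n:
        raise ValueError("npoint must be in [1, len(points)]")
    # Distance from every point to the closest already-selected point.
    min_dist = [float("inf")] * n
    selected = []
    cur = start_idx
    for _ in range(npoint):
        selected.append(cur)
        for i, p in enumerate(points):
            d = math.dist(p, points[cur])
            if d < min_dist[i]:
                min_dist[i] = d
        # Next pick: the point farthest from the selected set.
        cur = max(range(n), key=min_dist.__getitem__)
    return selected


pts = [(0, 0, 0), (1, 0, 0), (10, 0, 0), (5, 0, 0)]
idxs = farthest_point_sample_py(pts, 3)
```

The `min_dist` list plays the role of the `temp` buffer that the wrapper above fills with `1e10` before launching the kernel.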
class GatherOperation(Function):
......
......@@ -16,7 +16,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast");
m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast");
m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");
m.def("farthest_point_sampling_wrapper", &farthest_point_sampling_wrapper, "farthest_point_sampling_wrapper");
m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast");
m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast");
......
......@@ -38,13 +38,13 @@ int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
}
int furthest_point_sampling_wrapper(int b, int n, int m,
int farthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {
const float *points = points_tensor.data<float>();
float *temp = temp_tensor.data<float>();
int *idx = idx_tensor.data<int>();
furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx);
farthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx);
return 1;
}
......@@ -98,7 +98,7 @@ __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, i
}
template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(int b, int n, int m,
__global__ void farthest_point_sampling_kernel(int b, int n, int m,
const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
// dataset: (B, N, 3)
// tmp: (B, N)
......@@ -215,7 +215,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m,
}
}
void furthest_point_sampling_kernel_launcher(int b, int n, int m,
void farthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs) {
// dataset: (B, N, 3)
// tmp: (B, N)
......@@ -227,29 +227,29 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
switch (n_threads) {
case 1024:
furthest_point_sampling_kernel<1024><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<1024><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 512:
furthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 256:
furthest_point_sampling_kernel<256><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<256><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 128:
furthest_point_sampling_kernel<128><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<128><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 64:
furthest_point_sampling_kernel<64><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<64><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 32:
furthest_point_sampling_kernel<32><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<32><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 16:
furthest_point_sampling_kernel<16><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<16><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 8:
furthest_point_sampling_kernel<8><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<8><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 4:
furthest_point_sampling_kernel<4><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<4><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 2:
furthest_point_sampling_kernel<2><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<2><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 1:
furthest_point_sampling_kernel<1><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<1><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
default:
furthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs);
farthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs);
}
err = cudaGetLastError();
......
......@@ -20,10 +20,10 @@ void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
const float *grad_out, const int *idx, float *grad_points);
int furthest_point_sampling_wrapper(int b, int n, int m,
int farthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);
void furthest_point_sampling_kernel_launcher(int b, int n, int m,
void farthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs);
#endif
......@@ -7,6 +7,26 @@ import torch.nn.functional as F
from . import pointnet2_utils
def build_local_aggregation_module(input_channels, config):
local_aggregation_name = config.get('NAME', 'StackSAModuleMSG')
if local_aggregation_name == 'StackSAModuleMSG':
mlps = config.MLPS
for k in range(len(mlps)):
mlps[k] = [input_channels] + mlps[k]
cur_layer = StackSAModuleMSG(
radii=config.POOL_RADIUS, nsamples=config.NSAMPLE, mlps=mlps, use_xyz=True, pool_method='max_pool',
)
num_c_out = sum([x[-1] for x in mlps])
elif local_aggregation_name == 'VectorPoolAggregationModuleMSG':
cur_layer = VectorPoolAggregationModuleMSG(input_channels=input_channels, config=config)
num_c_out = config.MSG_POST_MLPS[-1]
else:
raise NotImplementedError
return cur_layer, num_c_out
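Note that the `StackSAModuleMSG` branch above writes the prepended channel count back into `config.MLPS`, so building a second module from the same config object would prepend the input channels again. The sketch below shows the same channel bookkeeping on a copy instead; the helper name `msg_output_channels` is hypothetical and this is not how the repository does it.

```python
def msg_output_channels(input_channels, mlps):
    """Return (expanded_mlps, num_c_out) without mutating `mlps`.

    Each branch gets `input_channels` prepended; the output width of a
    multi-scale grouping module is the sum of the last layer of each branch.
    """
    expanded = [[input_channels] + list(branch) for branch in mlps]
    num_c_out = sum(branch[-1] for branch in expanded)
    return expanded, num_c_out


cfg_mlps = [[16, 16], [32, 32]]                  # stand-in for config.MLPS
expanded, num_c_out = msg_output_channels(4, cfg_mlps)
```

Working on a copy keeps the config reusable, at the cost of one extra list allocation per branch.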
class StackSAModuleMSG(nn.Module):
def __init__(self, *, radii: List[float], nsamples: List[int], mlps: List[List[int]],
......@@ -135,3 +155,316 @@ class StackPointnetFPModule(nn.Module):
new_features = new_features.squeeze(dim=0).squeeze(dim=-1).permute(1, 0) # (N1 + N2 ..., C)
return new_features
class VectorPoolLocalInterpolateModule(nn.Module):
def __init__(self, mlp, num_voxels, max_neighbour_distance, nsample, neighbor_type, use_xyz=True,
neighbour_distance_multiplier=1.0, xyz_encoding_type='concat'):
"""
Args:
mlp:
num_voxels:
max_neighbour_distance:
neighbor_type: 1: ball, others: cube
nsample: find all (-1), find limited number(>0)
use_xyz:
neighbour_distance_multiplier:
xyz_encoding_type:
"""
super().__init__()
self.num_voxels = num_voxels # [num_grid_x, num_grid_y, num_grid_z]: number of grids in each local area centered at new_xyz
self.num_total_grids = self.num_voxels[0] * self.num_voxels[1] * self.num_voxels[2]
self.max_neighbour_distance = max_neighbour_distance
self.neighbor_distance_multiplier = neighbour_distance_multiplier
self.nsample = nsample
self.neighbor_type = neighbor_type
self.use_xyz = use_xyz
self.xyz_encoding_type = xyz_encoding_type
if mlp is not None:
if self.use_xyz:
mlp[0] += 9 if self.xyz_encoding_type == 'concat' else 0
shared_mlps = []
for k in range(len(mlp) - 1):
shared_mlps.extend([
nn.Conv2d(mlp[k], mlp[k + 1], kernel_size=1, bias=False),
nn.BatchNorm2d(mlp[k + 1]),
nn.ReLU()
])
self.mlp = nn.Sequential(*shared_mlps)
else:
self.mlp = None
self.num_avg_length_of_neighbor_idxs = 1000
def forward(self, support_xyz, support_features, xyz_batch_cnt, new_xyz, new_xyz_grid_centers, new_xyz_batch_cnt):
"""
Args:
support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
support_features: (N1 + N2 ..., C) point-wise features
xyz_batch_cnt: (batch_size), [N1, N2, ...]
new_xyz: (M1 + M2 ..., 3) centers of the ball query
new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grids centers of each grid
new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
Returns:
new_features: ((M1 + M2 ...) * num_total_grids, C_out)
"""
with torch.no_grad():
dist, idx, num_avg_length_of_neighbor_idxs = pointnet2_utils.three_nn_for_vector_pool_by_two_step(
support_xyz, xyz_batch_cnt, new_xyz, new_xyz_grid_centers, new_xyz_batch_cnt,
self.max_neighbour_distance, self.nsample, self.neighbor_type,
self.num_avg_length_of_neighbor_idxs, self.num_total_grids, self.neighbor_distance_multiplier
)
self.num_avg_length_of_neighbor_idxs = max(self.num_avg_length_of_neighbor_idxs, num_avg_length_of_neighbor_idxs.item())
dist_recip = 1.0 / (dist + 1e-8)
norm = torch.sum(dist_recip, dim=-1, keepdim=True)
weight = dist_recip / torch.clamp_min(norm, min=1e-8)
empty_mask = (idx.view(-1, 3)[:, 0] == -1)
idx.view(-1, 3)[empty_mask] = 0
interpolated_feats = pointnet2_utils.three_interpolate(support_features, idx.view(-1, 3), weight.view(-1, 3))
interpolated_feats = interpolated_feats.view(idx.shape[0], idx.shape[1], -1) # (M1 + M2 ..., num_total_grids, C)
if self.use_xyz:
near_known_xyz = support_xyz[idx.view(-1, 3).long()].view(-1, 3, 3) # ((M1 + M2 ...) * num_total_grids, 3, 3)
local_xyz = (new_xyz_grid_centers.view(-1, 1, 3) - near_known_xyz).view(-1, idx.shape[1], 9)
if self.xyz_encoding_type == 'concat':
interpolated_feats = torch.cat((interpolated_feats, local_xyz), dim=-1) # ( M1 + M2 ..., num_total_grids, 9+C)
else:
raise NotImplementedError
new_features = interpolated_feats.view(-1, interpolated_feats.shape[-1]) # ((M1 + M2 ...) * num_total_grids, C)
new_features[empty_mask, :] = 0
if self.mlp is not None:
new_features = new_features.permute(1, 0)[None, :, :, None] # (1, C, N1 + N2 ..., 1)
new_features = self.mlp(new_features)
new_features = new_features.squeeze(dim=0).squeeze(dim=-1).permute(1, 0) # (N1 + N2 ..., C)
return new_features
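The weighting used in `forward` above is plain inverse-distance interpolation over the three nearest support points. The following is a minimal pure-Python sketch of that step for a single grid center, not the CUDA-backed implementation; the helper names `inverse_distance_weights` and `interpolate_three` are hypothetical and exist only for illustration.

```python
def inverse_distance_weights(dists, eps=1e-8):
    """Normalized inverse-distance weights for one grid center."""
    recip = [1.0 / (d + eps) for d in dists]
    norm = max(sum(recip), eps)  # same role as torch.clamp_min(norm, min=1e-8)
    return [r / norm for r in recip]


def interpolate_three(features, idx, weights):
    """Weighted sum of three neighbour feature vectors.

    A grid center whose first index is -1 has no neighbours, so it gets an
    all-zero feature, matching the empty_mask handling in the module.
    """
    num_c = len(features[0])
    if idx[0] == -1:
        return [0.0] * num_c
    return [sum(w * features[i][c] for i, w in zip(idx, weights)) for c in range(num_c)]


weights = inverse_distance_weights([1.0, 1.0, 2.0])
feat = interpolate_three([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0, 1, 2], weights)
```

Closer neighbours receive larger weights, and the weights of one grid center sum to one.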
class VectorPoolAggregationModule(nn.Module):
def __init__(
self, input_channels, num_local_voxel=(3, 3, 3), local_aggregation_type='local_interpolation',
num_reduced_channels=30, num_channels_of_local_aggregation=32, post_mlps=(128,),
max_neighbor_distance=None, neighbor_nsample=-1, neighbor_type=0, neighbor_distance_multiplier=2.0):
super().__init__()
self.num_local_voxel = num_local_voxel
self.total_voxels = self.num_local_voxel[0] * self.num_local_voxel[1] * self.num_local_voxel[2]
self.local_aggregation_type = local_aggregation_type
assert self.local_aggregation_type in ['local_interpolation', 'voxel_avg_pool', 'voxel_random_choice']
self.input_channels = input_channels
self.num_reduced_channels = input_channels if num_reduced_channels is None else num_reduced_channels
self.num_channels_of_local_aggregation = num_channels_of_local_aggregation
self.max_neighbour_distance = max_neighbor_distance
self.neighbor_nsample = neighbor_nsample
self.neighbor_type = neighbor_type # 1: ball, others: cube
if self.local_aggregation_type == 'local_interpolation':
self.local_interpolate_module = VectorPoolLocalInterpolateModule(
mlp=None, num_voxels=self.num_local_voxel,
max_neighbour_distance=self.max_neighbour_distance,
nsample=self.neighbor_nsample,
neighbor_type=self.neighbor_type,
neighbour_distance_multiplier=neighbor_distance_multiplier,
)
num_c_in = (self.num_reduced_channels + 9) * self.total_voxels
else:
self.local_interpolate_module = None
num_c_in = (self.num_reduced_channels + 3) * self.total_voxels
num_c_out = self.total_voxels * self.num_channels_of_local_aggregation
self.separate_local_aggregation_layer = nn.Sequential(
nn.Conv1d(num_c_in, num_c_out, kernel_size=1, groups=self.total_voxels, bias=False),
nn.BatchNorm1d(num_c_out),
nn.ReLU()
)
post_mlp_list = []
c_in = num_c_out
for cur_num_c in post_mlps:
post_mlp_list.extend([
nn.Conv1d(c_in, cur_num_c, kernel_size=1, bias=False),
nn.BatchNorm1d(cur_num_c),
nn.ReLU()
])
c_in = cur_num_c
self.post_mlps = nn.Sequential(*post_mlp_list)
self.num_mean_points_per_grid = 20
self.init_weights()
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0)
def extra_repr(self) -> str:
ret = f'radius={self.max_neighbour_distance}, local_voxels=({self.num_local_voxel}), ' \
f'local_aggregation_type={self.local_aggregation_type}, ' \
f'num_c_reduction={self.input_channels}->{self.num_reduced_channels}, ' \
f'num_c_local_aggregation={self.num_channels_of_local_aggregation}'
return ret
def vector_pool_with_voxel_query(self, xyz, xyz_batch_cnt, features, new_xyz, new_xyz_batch_cnt):
use_xyz = 1
pooling_type = 0 if self.local_aggregation_type == 'voxel_avg_pool' else 1
new_features, new_local_xyz, num_mean_points_per_grid, point_cnt_of_grid = pointnet2_utils.vector_pool_with_voxel_query_op(
xyz, xyz_batch_cnt, features, new_xyz, new_xyz_batch_cnt,
self.num_local_voxel[0], self.num_local_voxel[1], self.num_local_voxel[2],
self.max_neighbour_distance, self.num_reduced_channels, use_xyz,
self.num_mean_points_per_grid, self.neighbor_nsample, self.neighbor_type,
pooling_type
)
self.num_mean_points_per_grid = max(self.num_mean_points_per_grid, num_mean_points_per_grid.item())
num_new_pts = new_features.shape[0]
new_local_xyz = new_local_xyz.view(num_new_pts, -1, 3) # (N, num_voxel, 3)
new_features = new_features.view(num_new_pts, -1, self.num_reduced_channels) # (N, num_voxel, C)
new_features = torch.cat((new_local_xyz, new_features), dim=-1).view(num_new_pts, -1)
return new_features, point_cnt_of_grid
@staticmethod
def get_dense_voxels_by_center(point_centers, max_neighbour_distance, num_voxels):
"""
Args:
point_centers: (N, 3)
max_neighbour_distance: float
num_voxels: [num_x, num_y, num_z]
Returns:
voxel_centers: (N, total_voxels, 3)
"""
R = max_neighbour_distance
device = point_centers.device
x_grids = torch.arange(-R + R / num_voxels[0], R - R / num_voxels[0] + 1e-5, 2 * R / num_voxels[0], device=device)
y_grids = torch.arange(-R + R / num_voxels[1], R - R / num_voxels[1] + 1e-5, 2 * R / num_voxels[1], device=device)
z_grids = torch.arange(-R + R / num_voxels[2], R - R / num_voxels[2] + 1e-5, 2 * R / num_voxels[2], device=device)
x_offset, y_offset, z_offset = torch.meshgrid(x_grids, y_grids, z_grids) # shape: [num_x, num_y, num_z]
xyz_offset = torch.cat((
x_offset.contiguous().view(-1, 1),
y_offset.contiguous().view(-1, 1),
z_offset.contiguous().view(-1, 1)), dim=-1
)
voxel_centers = point_centers[:, None, :] + xyz_offset[None, :, :]
return voxel_centers
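The three `torch.arange` calls in `get_dense_voxels_by_center` produce the cell centers of a cube with half-width `R`, split into `num_voxels` cells along each axis. A small pure-Python sketch of the same arithmetic follows; `dense_voxel_offsets` is a hypothetical helper, and the ordering assumes the default `'ij'` indexing of `torch.meshgrid`.

```python
from itertools import product


def dense_voxel_offsets(radius, num_voxels):
    """Cell-center offsets for a cube of half-width `radius`.

    Along each axis the centers are -R + R/n, -R + 3R/n, ..., R - R/n,
    which is what arange(-R + R/n, R - R/n + 1e-5, 2R/n) enumerates.
    """
    axes = []
    for n in num_voxels:
        step = 2.0 * radius / n
        axes.append([-radius + step / 2 + k * step for k in range(n)])
    # product() varies the last axis fastest, matching the row-major
    # flattening of the meshgrid outputs (x slowest, z fastest).
    return [list(p) for p in product(*axes)]


offsets = dense_voxel_offsets(radius=1.5, num_voxels=(3, 3, 3))
```

Adding these offsets to each point center gives the `(N, total_voxels, 3)` tensor returned by the static method.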
def vector_pool_with_local_interpolate(self, xyz, xyz_batch_cnt, features, new_xyz, new_xyz_batch_cnt):
"""
Args:
xyz: (N, 3)
xyz_batch_cnt: (batch_size)
features: (N, C)
new_xyz: (M, 3)
new_xyz_batch_cnt: (batch_size)
Returns:
new_features: (M, total_voxels * C)
"""
voxel_centers = self.get_dense_voxels_by_center(
point_centers=new_xyz, max_neighbour_distance=self.max_neighbour_distance, num_voxels=self.num_local_voxel
) # (M1 + M2 + ..., total_voxels, 3)
voxel_features = self.local_interpolate_module.forward(
support_xyz=xyz, support_features=features, xyz_batch_cnt=xyz_batch_cnt,
new_xyz=new_xyz, new_xyz_grid_centers=voxel_centers, new_xyz_batch_cnt=new_xyz_batch_cnt
) # ((M1 + M2 ...) * total_voxels, C)
voxel_features = voxel_features.contiguous().view(-1, self.total_voxels * voxel_features.shape[-1])
return voxel_features
def forward(self, xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, features, **kwargs):
"""
:param xyz: (N1 + N2 ..., 3) tensor of the xyz coordinates of the features
:param xyz_batch_cnt: (batch_size), [N1, N2, ...]
:param new_xyz: (M1 + M2 ..., 3)
:param new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
:param features: (N1 + N2 ..., C) tensor of the descriptors of the features
:return:
new_xyz: (M1 + M2 ..., 3) tensor of the new features' xyz
new_features: (M1 + M2 ..., \sum_k(mlps[k][-1])) tensor of the new_features descriptors
"""
N, C = features.shape
assert C % self.num_reduced_channels == 0, \
f'the input channels ({C}) should be an integral multiple of num_reduced_channels({self.num_reduced_channels})'
features = features.view(N, -1, self.num_reduced_channels).sum(dim=1)
if self.local_aggregation_type in ['voxel_avg_pool', 'voxel_random_choice']:
vector_features, point_cnt_of_grid = self.vector_pool_with_voxel_query(
xyz=xyz, xyz_batch_cnt=xyz_batch_cnt, features=features,
new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt
)
elif self.local_aggregation_type == 'local_interpolation':
vector_features = self.vector_pool_with_local_interpolate(
xyz=xyz, xyz_batch_cnt=xyz_batch_cnt, features=features,
new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt
) # (M1 + M2 + ..., total_voxels * C)
else:
raise NotImplementedError
vector_features = vector_features.permute(1, 0)[None, :, :] # (1, num_voxels * C, M1 + M2 ...)
new_features = self.separate_local_aggregation_layer(vector_features)
new_features = self.post_mlps(new_features)
new_features = new_features.squeeze(dim=0).permute(1, 0)
return new_xyz, new_features
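The channel reduction at the start of `forward` (`features.view(N, -1, num_reduced_channels).sum(dim=1)`) folds the `C` input channels into consecutive groups of `num_reduced_channels` and sums the groups element-wise. A pure-Python sketch for one point's feature row, using a hypothetical helper `reduce_channels`:

```python
def reduce_channels(row, num_reduced_channels):
    """Sum consecutive channel groups of one feature row.

    Output channel j is the sum of input channels j, j + C_r, j + 2*C_r, ...
    where C_r = num_reduced_channels.
    """
    if len(row) % num_reduced_channels != 0:
        raise ValueError('input channels must be a multiple of num_reduced_channels')
    return [sum(row[j::num_reduced_channels]) for j in range(num_reduced_channels)]


reduced = reduce_channels([1, 2, 3, 4, 5, 6], num_reduced_channels=3)
```

This reduction has no learnable parameters, which is why the module asserts divisibility rather than projecting the channels.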
class VectorPoolAggregationModuleMSG(nn.Module):
def __init__(self, input_channels, config):
super().__init__()
self.model_cfg = config
self.num_groups = self.model_cfg.NUM_GROUPS
self.layers = []
c_in = 0
for k in range(self.num_groups):
cur_config = self.model_cfg[f'GROUP_CFG_{k}']
cur_vector_pool_module = VectorPoolAggregationModule(
input_channels=input_channels, num_local_voxel=cur_config.NUM_LOCAL_VOXEL,
post_mlps=cur_config.POST_MLPS,
max_neighbor_distance=cur_config.MAX_NEIGHBOR_DISTANCE,
neighbor_nsample=cur_config.NEIGHBOR_NSAMPLE,
local_aggregation_type=self.model_cfg.LOCAL_AGGREGATION_TYPE,
num_reduced_channels=self.model_cfg.get('NUM_REDUCED_CHANNELS', None),
num_channels_of_local_aggregation=self.model_cfg.NUM_CHANNELS_OF_LOCAL_AGGREGATION,
neighbor_distance_multiplier=2.0
)
self.__setattr__(f'layer_{k}', cur_vector_pool_module)
c_in += cur_config.POST_MLPS[-1]
c_in += 3 # use_xyz
shared_mlps = []
for cur_num_c in self.model_cfg.MSG_POST_MLPS:
shared_mlps.extend([
nn.Conv1d(c_in, cur_num_c, kernel_size=1, bias=False),
nn.BatchNorm1d(cur_num_c),
nn.ReLU()
])
c_in = cur_num_c
self.msg_post_mlps = nn.Sequential(*shared_mlps)
def forward(self, **kwargs):
features_list = []
for k in range(self.num_groups):
cur_xyz, cur_features = self.__getattr__(f'layer_{k}')(**kwargs)
features_list.append(cur_features)
features = torch.cat(features_list, dim=-1)
features = torch.cat((cur_xyz, features), dim=-1)
features = features.permute(1, 0)[None, :, :] # (1, C, N)
new_features = self.msg_post_mlps(features)
new_features = new_features.squeeze(dim=0).permute(1, 0) # (N, C)
return cur_xyz, new_features
......@@ -155,7 +155,7 @@ class QueryAndGroup(nn.Module):
return new_features, idx
class FurthestPointSampling(Function):
class FarthestPointSampling(Function):
@staticmethod
def forward(ctx, xyz: torch.Tensor, npoint: int):
"""
......@@ -173,7 +173,7 @@ class FurthestPointSampling(Function):
output = torch.cuda.IntTensor(B, npoint)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
pointnet2.farthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
return output
@staticmethod
......@@ -181,7 +181,44 @@ class FurthestPointSampling(Function):
return None, None
furthest_point_sample = FurthestPointSampling.apply
farthest_point_sample = furthest_point_sample = FarthestPointSampling.apply
class StackFarthestPointSampling(Function):
@staticmethod
def forward(ctx, xyz, xyz_batch_cnt, npoint):
"""
Args:
ctx:
xyz: (N1 + N2 + ..., 3) where N > npoint
xyz_batch_cnt: [N1, N2, ...]
npoint: int, number of features in the sampled set
Returns:
output: (npoint.sum()) tensor containing the set,
npoint: (M1, M2, ...)
"""
assert xyz.is_contiguous() and xyz.shape[1] == 3
batch_size = xyz_batch_cnt.__len__()
if not isinstance(npoint, torch.Tensor):
if not isinstance(npoint, list):
npoint = [npoint for i in range(batch_size)]
npoint = torch.tensor(npoint, device=xyz.device).int()
N, _ = xyz.size()
temp = torch.cuda.FloatTensor(N).fill_(1e10)
output = torch.cuda.IntTensor(npoint.sum().item())
pointnet2.stack_farthest_point_sampling_wrapper(xyz, temp, xyz_batch_cnt, output, npoint)
return output
@staticmethod
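def backward(ctx, grad_output=None):
# one gradient slot per forward input (xyz, xyz_batch_cnt, npoint); sampling is not differentiable
return None, None, None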
stack_farthest_point_sample = StackFarthestPointSampling.apply
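As a reading aid for the stacked variant, here is a CPU reference sketch in pure Python. It is not the CUDA kernel: `stack_fps_reference` is a hypothetical name, it assumes every sample requests at least one point, and tie-breaking may differ from the parallel reduction on the GPU. Like the kernel, it starts each sample at its first point and returns indices into the stacked `xyz` array.

```python
def stack_fps_reference(xyz, xyz_batch_cnt, npoints):
    """Farthest point sampling over stacked batches.

    xyz: list of (x, y, z) for all samples concatenated
    xyz_batch_cnt: [N1, N2, ...] points per sample
    npoints: [M1, M2, ...] points to draw per sample (each >= 1)
    """
    out = []
    start = 0
    for n, m in zip(xyz_batch_cnt, npoints):
        pts = xyz[start:start + n]
        min_d2 = [1e10] * n  # squared distance to the selected set
        old = 0
        picked = [0]
        for _ in range(1, m):
            ox, oy, oz = pts[old]
            best, besti = -1.0, 0
            for k, (x, y, z) in enumerate(pts):
                d2 = (x - ox) ** 2 + (y - oy) ** 2 + (z - oz) ** 2
                min_d2[k] = min(min_d2[k], d2)
                if min_d2[k] > best:
                    best, besti = min_d2[k], k
            old = besti
            picked.append(old)
        out.extend(start + i for i in picked)  # shift to global indices
        start += n
    return out


points = [(0, 0, 0), (1, 0, 0), (5, 0, 0), (0, 0, 0), (0, 2, 0)]
sampled = stack_fps_reference(points, xyz_batch_cnt=[3, 2], npoints=[2, 2])
```

A reference like this can be handy for checking the extension's output on small inputs where no ties occur.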
class ThreeNN(Function):
......@@ -262,5 +299,154 @@ class ThreeInterpolate(Function):
three_interpolate = ThreeInterpolate.apply
class ThreeNNForVectorPoolByTwoStep(Function):
@staticmethod
def forward(ctx, support_xyz, xyz_batch_cnt, new_xyz, new_xyz_grid_centers, new_xyz_batch_cnt,
max_neighbour_distance, nsample, neighbor_type, avg_length_of_neighbor_idxs, num_total_grids,
neighbor_distance_multiplier):
"""
Args:
ctx:
support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
xyz_batch_cnt: (batch_size), [N1, N2, ...]
new_xyz: (M1 + M2 ..., 3) centers of the ball query
new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grids centers of each grid
new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
nsample: find all (-1), find limited number(>0)
neighbor_type: 1: ball, others: cube
neighbor_distance_multiplier: query_distance = neighbor_distance_multiplier * max_neighbour_distance
Returns:
new_xyz_grid_idxs: (M1 + M2 ..., num_total_grids, 3) three-nn
new_xyz_grid_dist2: (M1 + M2 ..., num_total_grids, 3) square of dist of three-nn
"""
num_new_xyz = new_xyz.shape[0]
new_xyz_grid_dist2 = new_xyz_grid_centers.new_zeros(new_xyz_grid_centers.shape)
new_xyz_grid_idxs = new_xyz_grid_centers.new_zeros(new_xyz_grid_centers.shape).int().fill_(-1)
while True:
num_max_sum_points = avg_length_of_neighbor_idxs * num_new_xyz
stack_neighbor_idxs = new_xyz_grid_idxs.new_zeros(num_max_sum_points)
start_len = new_xyz_grid_idxs.new_zeros(num_new_xyz, 2).int()
cumsum = new_xyz_grid_idxs.new_zeros(1)
pointnet2.query_stacked_local_neighbor_idxs_wrapper_stack(
support_xyz.contiguous(), xyz_batch_cnt.contiguous(),
new_xyz.contiguous(), new_xyz_batch_cnt.contiguous(),
stack_neighbor_idxs.contiguous(), start_len.contiguous(), cumsum,
avg_length_of_neighbor_idxs, max_neighbour_distance * neighbor_distance_multiplier,
nsample, neighbor_type
)
avg_length_of_neighbor_idxs = cumsum[0].item() // num_new_xyz + int(cumsum[0].item() % num_new_xyz > 0)
if cumsum[0] <= num_max_sum_points:
break
stack_neighbor_idxs = stack_neighbor_idxs[:cumsum[0]]
pointnet2.query_three_nn_by_stacked_local_idxs_wrapper_stack(
support_xyz, new_xyz, new_xyz_grid_centers, new_xyz_grid_idxs, new_xyz_grid_dist2,
stack_neighbor_idxs, start_len, num_new_xyz, num_total_grids
)
return torch.sqrt(new_xyz_grid_dist2), new_xyz_grid_idxs, torch.tensor(avg_length_of_neighbor_idxs)
three_nn_for_vector_pool_by_two_step = ThreeNNForVectorPoolByTwoStep.apply
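The `while True` loop in `forward` above is a grow-and-retry pattern: allocate a buffer from an estimated average neighbour count, run the query, and rerun with a larger buffer if the reported total did not fit. The sketch below imitates that control flow in pure Python. It assumes the kernel always reports the total it needed through `cumsum`, even on overflow; `query_with_retry` and its list input are hypothetical stand-ins for the CUDA call.

```python
def ceil_div(a, b):
    """Same rounding as `a // b + int(a % b > 0)` in the loop."""
    return a // b + int(a % b > 0)


def query_with_retry(neighbor_counts, avg_len_guess):
    """neighbor_counts[i] stands in for the neighbours found for query i."""
    num_queries = len(neighbor_counts)
    attempts = 0
    while True:
        attempts += 1
        capacity = avg_len_guess * num_queries  # num_max_sum_points
        total = sum(neighbor_counts)            # value reported in cumsum
        avg_len_guess = ceil_div(total, num_queries)
        if total <= capacity:
            break                               # buffer was large enough
    return capacity, avg_len_guess, attempts


result = query_with_retry([5, 7, 9], avg_len_guess=2)
```

The module keeps the running maximum of the returned estimate, so the retry path would be expected to become rare after the first few iterations.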
class VectorPoolWithVoxelQuery(Function):
@staticmethod
def forward(ctx, support_xyz: torch.Tensor, xyz_batch_cnt: torch.Tensor, support_features: torch.Tensor,
new_xyz: torch.Tensor, new_xyz_batch_cnt: torch.Tensor, num_grid_x, num_grid_y, num_grid_z,
max_neighbour_distance, num_c_out_each_grid, use_xyz,
num_mean_points_per_grid=100, nsample=-1, neighbor_type=0, pooling_type=0):
"""
Args:
ctx:
support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
xyz_batch_cnt: (batch_size), [N1, N2, ...]
support_features: (N1 + N2 ..., C)
new_xyz: (M1 + M2 ..., 3) centers of new positions
new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
num_grid_x: number of grids in each local area centered at new_xyz
num_grid_y:
num_grid_z:
max_neighbour_distance:
num_c_out_each_grid:
use_xyz:
neighbor_type: 1: ball, others: cube
pooling_type: 0: avg_pool, 1: random choice
Returns:
new_features: (M1 + M2 ..., num_c_out)
"""
assert support_xyz.is_contiguous()
assert support_features.is_contiguous()
assert xyz_batch_cnt.is_contiguous()
assert new_xyz.is_contiguous()
assert new_xyz_batch_cnt.is_contiguous()
num_total_grids = num_grid_x * num_grid_y * num_grid_z
num_c_out = num_c_out_each_grid * num_total_grids
N, num_c_in = support_features.shape
M = new_xyz.shape[0]
assert num_c_in % num_c_out_each_grid == 0, \
f'the input channels ({num_c_in}) should be an integral multiple of num_c_out_each_grid({num_c_out_each_grid})'
while True:
new_features = support_features.new_zeros((M, num_c_out))
new_local_xyz = support_features.new_zeros((M, 3 * num_total_grids))
point_cnt_of_grid = xyz_batch_cnt.new_zeros((M, num_total_grids))
num_max_sum_points = num_mean_points_per_grid * M
grouped_idxs = xyz_batch_cnt.new_zeros((num_max_sum_points, 3))
num_cum_sum = pointnet2.vector_pool_wrapper(
support_xyz, xyz_batch_cnt, support_features, new_xyz, new_xyz_batch_cnt,
new_features, new_local_xyz, point_cnt_of_grid, grouped_idxs,
num_grid_x, num_grid_y, num_grid_z, max_neighbour_distance, use_xyz,
num_max_sum_points, nsample, neighbor_type, pooling_type
)
num_mean_points_per_grid = num_cum_sum // M + int(num_cum_sum % M > 0)
if num_cum_sum <= num_max_sum_points:
break
grouped_idxs = grouped_idxs[:num_cum_sum]
normalizer = torch.clamp_min(point_cnt_of_grid[:, :, None].float(), min=1e-6)
new_features = (new_features.view(-1, num_total_grids, num_c_out_each_grid) / normalizer).view(-1, num_c_out)
if use_xyz:
new_local_xyz = (new_local_xyz.view(-1, num_total_grids, 3) / normalizer).view(-1, num_total_grids * 3)
num_mean_points_per_grid = torch.Tensor([num_mean_points_per_grid]).int()
nsample = torch.Tensor([nsample]).int()
ctx.vector_pool_for_backward = (point_cnt_of_grid, grouped_idxs, N, num_c_in)
ctx.mark_non_differentiable(new_local_xyz, num_mean_points_per_grid, nsample, point_cnt_of_grid)
return new_features, new_local_xyz, num_mean_points_per_grid, point_cnt_of_grid
@staticmethod
def backward(ctx, grad_new_features: torch.Tensor, grad_local_xyz: torch.Tensor, grad_num_cum_sum, grad_point_cnt_of_grid):
"""
Args:
ctx:
grad_new_features: (M1 + M2 ..., num_c_out), num_c_out = num_c_out_each_grid * num_total_grids
Returns:
grad_support_features: (N1 + N2 ..., C_in)
"""
point_cnt_of_grid, grouped_idxs, N, num_c_in = ctx.vector_pool_for_backward
grad_support_features = grad_new_features.new_zeros((N, num_c_in))
pointnet2.vector_pool_grad_wrapper(
grad_new_features.contiguous(), point_cnt_of_grid, grouped_idxs,
grad_support_features
)
return None, None, grad_support_features, None, None, None, None, None, None, None, None, None, None, None, None
vector_pool_with_voxel_query_op = VectorPoolWithVoxelQuery.apply
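After the kernel returns, `forward` divides the accumulated per-grid features by the clamped point count of each grid. A pure-Python sketch of that normalization follows; `average_grid_features` is a hypothetical helper, and it assumes the kernel wrote per-grid sums.

```python
def average_grid_features(summed, counts, eps=1e-6):
    """Divide each grid's summed feature vector by its point count.

    Clamping the count to eps avoids division by zero; empty grids hold
    zero sums, so they stay zero after the division.
    """
    return [[v / max(float(c), eps) for v in feats] for feats, c in zip(summed, counts)]


averaged = average_grid_features([[2.0, 4.0], [0.0, 0.0]], counts=[2, 0])
```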
if __name__ == '__main__':
pass
......@@ -6,13 +6,15 @@
#include "sampling_gpu.h"
#include "interpolate_gpu.h"
#include "voxel_query_gpu.h"
#include "vector_pool_gpu.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ball_query_wrapper", &ball_query_wrapper_stack, "ball_query_wrapper_stack");
m.def("voxel_query_wrapper", &voxel_query_wrapper_stack, "voxel_query_wrapper_stack");
m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");
m.def("farthest_point_sampling_wrapper", &farthest_point_sampling_wrapper, "farthest_point_sampling_wrapper");
m.def("stack_farthest_point_sampling_wrapper", &stack_farthest_point_sampling_wrapper, "stack_farthest_point_sampling_wrapper");
m.def("group_points_wrapper", &group_points_wrapper_stack, "group_points_wrapper_stack");
m.def("group_points_grad_wrapper", &group_points_grad_wrapper_stack, "group_points_grad_wrapper_stack");
......@@ -20,4 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("three_nn_wrapper", &three_nn_wrapper_stack, "three_nn_wrapper_stack");
m.def("three_interpolate_wrapper", &three_interpolate_wrapper_stack, "three_interpolate_wrapper_stack");
m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_stack, "three_interpolate_grad_wrapper_stack");
m.def("query_stacked_local_neighbor_idxs_wrapper_stack", &query_stacked_local_neighbor_idxs_wrapper_stack, "query_stacked_local_neighbor_idxs_wrapper_stack");
m.def("query_three_nn_by_stacked_local_idxs_wrapper_stack", &query_three_nn_by_stacked_local_idxs_wrapper_stack, "query_three_nn_by_stacked_local_idxs_wrapper_stack");
m.def("vector_pool_wrapper", &vector_pool_wrapper_stack, "vector_pool_wrapper_stack");
m.def("vector_pool_grad_wrapper", &vector_pool_grad_wrapper_stack, "vector_pool_grad_wrapper_stack");
}
......@@ -21,7 +21,7 @@ extern THCState *state;
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
int furthest_point_sampling_wrapper(int b, int n, int m,
int farthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {
CHECK_INPUT(points_tensor);
......@@ -32,6 +32,29 @@ int furthest_point_sampling_wrapper(int b, int n, int m,
float *temp = temp_tensor.data<float>();
int *idx = idx_tensor.data<int>();
furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx);
farthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx);
return 1;
}
int stack_farthest_point_sampling_wrapper(at::Tensor points_tensor,
at::Tensor temp_tensor, at::Tensor xyz_batch_cnt_tensor, at::Tensor idx_tensor,
at::Tensor num_sampled_points_tensor) {
CHECK_INPUT(points_tensor);
CHECK_INPUT(temp_tensor);
CHECK_INPUT(idx_tensor);
CHECK_INPUT(xyz_batch_cnt_tensor);
CHECK_INPUT(num_sampled_points_tensor);
int batch_size = xyz_batch_cnt_tensor.size(0);
int N = points_tensor.size(0);
const float *points = points_tensor.data<float>();
float *temp = temp_tensor.data<float>();
int *xyz_batch_cnt = xyz_batch_cnt_tensor.data<int>();
int *idx = idx_tensor.data<int>();
int *num_sampled_points = num_sampled_points_tensor.data<int>();
stack_farthest_point_sampling_kernel_launcher(N, batch_size, points, temp, xyz_batch_cnt, idx, num_sampled_points);
return 1;
}
\ No newline at end of file
......@@ -22,7 +22,7 @@ __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, i
template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(int b, int n, int m,
__global__ void farthest_point_sampling_kernel(int b, int n, int m,
const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
// dataset: (B, N, 3)
// tmp: (B, N)
......@@ -139,7 +139,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m,
}
}
void furthest_point_sampling_kernel_launcher(int b, int n, int m,
void farthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs) {
// dataset: (B, N, 3)
// tmp: (B, N)
......@@ -151,29 +151,29 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
switch (n_threads) {
case 1024:
furthest_point_sampling_kernel<1024><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<1024><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 512:
furthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 256:
furthest_point_sampling_kernel<256><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<256><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 128:
furthest_point_sampling_kernel<128><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<128><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 64:
furthest_point_sampling_kernel<64><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<64><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 32:
furthest_point_sampling_kernel<32><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<32><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 16:
furthest_point_sampling_kernel<16><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<16><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 8:
furthest_point_sampling_kernel<8><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<8><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 4:
furthest_point_sampling_kernel<4><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<4><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 2:
furthest_point_sampling_kernel<2><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<2><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 1:
furthest_point_sampling_kernel<1><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<1><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
default:
furthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs);
farthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs);
}
err = cudaGetLastError();
......@@ -182,3 +182,168 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
exit(-1);
}
}
template <unsigned int block_size>
__global__ void stack_farthest_point_sampling_kernel(int batch_size, int N,
const float *dataset, float *temp, int *xyz_batch_cnt, int *idxs, int *num_sampled_points) {
// """
// Args:
// ctx:
// dataset: (N1 + N2 + ..., 3) where N > npoint
// temp: (N1 + N2 + ...) where N > npoint
// xyz_batch_cnt: [N1, N2, ...]
// num_sampled_points: [M1, M2, ...] int, number of features in the sampled set
// Returns:
// idxs: (npoint.sum()) tensor containing the set,
// npoint: (M1, M2, ...)
// """
__shared__ float dists[block_size];
__shared__ int dists_i[block_size];
int bs_idx = blockIdx.x;
int xyz_batch_start_idx = 0, idxs_start_idx = 0;
for (int k = 0; k < bs_idx; k++){
xyz_batch_start_idx += xyz_batch_cnt[k];
idxs_start_idx += num_sampled_points[k];
}
dataset += xyz_batch_start_idx * 3;
temp += xyz_batch_start_idx;
idxs += idxs_start_idx;
int n = xyz_batch_cnt[bs_idx];
int m = num_sampled_points[bs_idx];
int tid = threadIdx.x;
const int stride = block_size;
int old = 0;
if (threadIdx.x == 0) idxs[0] = xyz_batch_start_idx;
__syncthreads();
for (int j = 1; j < m; j++) {
int besti = 0;
float best = -1;
float x1 = dataset[old * 3 + 0];
float y1 = dataset[old * 3 + 1];
float z1 = dataset[old * 3 + 2];
for (int k = tid; k < n; k += stride) {
float x2, y2, z2;
x2 = dataset[k * 3 + 0];
y2 = dataset[k * 3 + 1];
z2 = dataset[k * 3 + 2];
// float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
// if (mag <= 1e-3)
// continue;
float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
float d2 = min(d, temp[k]);
temp[k] = d2;
besti = d2 > best ? k : besti;
best = d2 > best ? d2 : best;
}
dists[tid] = best;
dists_i[tid] = besti;
__syncthreads();
if (block_size >= 1024) {
if (tid < 512) {
__update(dists, dists_i, tid, tid + 512);
}
__syncthreads();
}
if (block_size >= 512) {
if (tid < 256) {
__update(dists, dists_i, tid, tid + 256);
}
__syncthreads();
}
if (block_size >= 256) {
if (tid < 128) {
__update(dists, dists_i, tid, tid + 128);
}
__syncthreads();
}
if (block_size >= 128) {
if (tid < 64) {
__update(dists, dists_i, tid, tid + 64);
}
__syncthreads();
}
if (block_size >= 64) {
if (tid < 32) {
__update(dists, dists_i, tid, tid + 32);
}
__syncthreads();
}
if (block_size >= 32) {
if (tid < 16) {
__update(dists, dists_i, tid, tid + 16);
}
__syncthreads();
}
if (block_size >= 16) {
if (tid < 8) {
__update(dists, dists_i, tid, tid + 8);
}
__syncthreads();
}
if (block_size >= 8) {
if (tid < 4) {
__update(dists, dists_i, tid, tid + 4);
}
__syncthreads();
}
if (block_size >= 4) {
if (tid < 2) {
__update(dists, dists_i, tid, tid + 2);
}
__syncthreads();
}
if (block_size >= 2) {
if (tid < 1) {
__update(dists, dists_i, tid, tid + 1);
}
__syncthreads();
}
old = dists_i[0];
if (tid == 0)
idxs[j] = old + xyz_batch_start_idx;
}
}
void stack_farthest_point_sampling_kernel_launcher(int N, int batch_size,
const float *dataset, float *temp, int *xyz_batch_cnt, int *idxs, int *num_sampled_points) {
// """
// Args:
// ctx:
// dataset: (N1 + N2 + ..., 3) where N > npoint
// temp: (N1 + N2 + ...) where N > npoint
// xyz_batch_cnt: [N1, N2, ...]
// npoint: int, number of features in the sampled set
// Returns:
// idxs: (npoint.sum()) tensor containing the set,
// npoint: (M1, M2, ...)
// """
cudaError_t err;
unsigned int n_threads = opt_n_threads(N);
stack_farthest_point_sampling_kernel<1024><<<batch_size, 1024>>>(
batch_size, N, dataset, temp, xyz_batch_cnt, idxs, num_sampled_points
);
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
\ No newline at end of file
......@@ -6,10 +6,18 @@
#include<vector>
int furthest_point_sampling_wrapper(int b, int n, int m,
int farthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);
void furthest_point_sampling_kernel_launcher(int b, int n, int m,
void farthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs);
int stack_farthest_point_sampling_wrapper(
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor xyz_batch_cnt_tensor,
at::Tensor idx_tensor, at::Tensor num_sampled_points_tensor);
void stack_farthest_point_sampling_kernel_launcher(int N, int batch_size,
const float *dataset, float *temp, int *xyz_batch_cnt, int *idxs, int *num_sampled_points);
#endif