"docs/vscode:/vscode.git/clone" did not exist on "8ece4f3c0b113b655781e0024dadc82e888a2e9e"
Unverified Commit a991105c authored by Shaoshuai Shi's avatar Shaoshuai Shi Committed by GitHub
Browse files

Release the codes of PV-RCNN++, update OpenPCDet to v0.5.2

Release the codes of PV-RCNN++, update OpenPCDet to v0.5.2
parents 1483517a b6fbf07f
...@@ -4,9 +4,11 @@ ...@@ -4,9 +4,11 @@
`OpenPCDet` is a clear, simple, self-contained open source project for LiDAR-based 3D object detection. `OpenPCDet` is a clear, simple, self-contained open source project for LiDAR-based 3D object detection.
It is also the official code release of [`[PointRCNN]`](https://arxiv.org/abs/1812.04244), [`[Part-A2-Net]`](https://arxiv.org/abs/1907.03670), [`[PV-RCNN]`](https://arxiv.org/abs/1912.13192) and [`[Voxel R-CNN]`](https://arxiv.org/abs/2012.15712). It is also the official code release of [`[PointRCNN]`](https://arxiv.org/abs/1812.04244), [`[Part-A2-Net]`](https://arxiv.org/abs/1907.03670), [`[PV-RCNN]`](https://arxiv.org/abs/1912.13192), [`[Voxel R-CNN]`](https://arxiv.org/abs/2012.15712) and [`[PV-RCNN++]`](https://arxiv.org/abs/2102.00463).
**NEW**: `OpenPCDet` has been updated to `v0.5.0` (Dec. 2021). **Highlights**:
* `OpenPCDet` has been updated to `v0.5.2` (Jan. 2022).
* The codes of PV-RCNN++ has been supported.
## Overview ## Overview
- [Changelog](#changelog) - [Changelog](#changelog)
...@@ -19,6 +21,12 @@ It is also the official code release of [`[PointRCNN]`](https://arxiv.org/abs/18 ...@@ -19,6 +21,12 @@ It is also the official code release of [`[PointRCNN]`](https://arxiv.org/abs/18
## Changelog ## Changelog
[2022-01-05] **NEW:** Update `OpenPCDet` to v0.5.2:
* The code of [PV-RCNN++](https://arxiv.org/abs/2102.00463) has been released to this repo, with higher performance, faster training/inference speed and less memory consumption than PV-RCNN.
* Add performance of several models trained with full training set of [Waymo Open Dataset](#waymo-open-dataset-baselines).
* Support Lyft dataset, see the pull request [here](https://github.com/open-mmlab/OpenPCDet/pull/720).
[2021-12-09] **NEW:** Update `OpenPCDet` to v0.5.1: [2021-12-09] **NEW:** Update `OpenPCDet` to v0.5.1:
* Add PointPillar related baseline configs/results on [Waymo Open Dataset](#waymo-open-dataset-baselines). * Add PointPillar related baseline configs/results on [Waymo Open Dataset](#waymo-open-dataset-baselines).
* Support Pandaset dataloader, see the pull request [here](https://github.com/open-mmlab/OpenPCDet/pull/396). * Support Pandaset dataloader, see the pull request [here](https://github.com/open-mmlab/OpenPCDet/pull/396).
...@@ -108,7 +116,7 @@ Contributions are also welcomed. ...@@ -108,7 +116,7 @@ Contributions are also welcomed.
### KITTI 3D Object Detection Baselines ### KITTI 3D Object Detection Baselines
Selected supported methods are shown in the below table. The results are the 3D detection performance of moderate difficulty on the *val* set of KITTI dataset. Selected supported methods are shown in the below table. The results are the 3D detection performance of moderate difficulty on the *val* set of KITTI dataset.
* All models are trained with 8 GTX 1080Ti GPUs and are available for download. * All LiDAR-based models are trained with 8 GTX 1080Ti GPUs and are available for download.
* The training time is measured with 8 TITAN XP GPUs and PyTorch 1.5. * The training time is measured with 8 TITAN XP GPUs and PyTorch 1.5.
| | training time | Car@R11 | Pedestrian@R11 | Cyclist@R11 | download | | | training time | Car@R11 | Pedestrian@R11 | Cyclist@R11 | download |
...@@ -129,7 +137,7 @@ Selected supported methods are shown in the below table. The results are the 3D ...@@ -129,7 +137,7 @@ Selected supported methods are shown in the below table. The results are the 3D
We provide the setting of [`DATA_CONFIG.SAMPLED_INTERVAL`](tools/cfgs/dataset_configs/waymo_dataset.yaml) on the Waymo Open Dataset (WOD) to subsample partial samples for training and evaluation, We provide the setting of [`DATA_CONFIG.SAMPLED_INTERVAL`](tools/cfgs/dataset_configs/waymo_dataset.yaml) on the Waymo Open Dataset (WOD) to subsample partial samples for training and evaluation,
so you could also play with WOD by setting a smaller `DATA_CONFIG.SAMPLED_INTERVAL` even if you only have limited GPU resources. so you could also play with WOD by setting a smaller `DATA_CONFIG.SAMPLED_INTERVAL` even if you only have limited GPU resources.
By default, all models are trained with **20% data (~32k frames)** of all the training samples on 8 GTX 1080Ti GPUs, and the results of each cell here are mAP/mAPH calculated by the official Waymo evaluation metrics on the **whole** validation set (version 1.2). By default, all models are trained with **a single frame** of **20% data (~32k frames)** of all the training samples on 8 GTX 1080Ti GPUs, and the results of each cell here are mAP/mAPH calculated by the official Waymo evaluation metrics on the **whole** validation set (version 1.2).
| Performance@(train with 20\% Data) | Vec_L1 | Vec_L2 | Ped_L1 | Ped_L2 | Cyc_L1 | Cyc_L2 | | Performance@(train with 20\% Data) | Vec_L1 | Vec_L2 | Ped_L1 | Ped_L2 | Cyc_L1 | Cyc_L2 |
|---------------------------------------------|----------:|:-------:|:-------:|:-------:|:-------:|:-------:| |---------------------------------------------|----------:|:-------:|:-------:|:-------:|:-------:|:-------:|
...@@ -141,6 +149,22 @@ By default, all models are trained with **20% data (~32k frames)** of all the tr ...@@ -141,6 +149,22 @@ By default, all models are trained with **20% data (~32k frames)** of all the tr
| [Part-A2-Anchor](tools/cfgs/waymo_models/PartA2.yaml) | 74.66/74.12 |65.82/65.32 |71.71/62.24 |62.46/54.06 |66.53/65.18 |64.05/62.75 | | [Part-A2-Anchor](tools/cfgs/waymo_models/PartA2.yaml) | 74.66/74.12 |65.82/65.32 |71.71/62.24 |62.46/54.06 |66.53/65.18 |64.05/62.75 |
| [PV-RCNN (AnchorHead)](tools/cfgs/waymo_models/pv_rcnn.yaml) | 75.41/74.74 |67.44/66.80 |71.98/61.24 |63.70/53.95 |65.88/64.25 |63.39/61.82 | | [PV-RCNN (AnchorHead)](tools/cfgs/waymo_models/pv_rcnn.yaml) | 75.41/74.74 |67.44/66.80 |71.98/61.24 |63.70/53.95 |65.88/64.25 |63.39/61.82 |
| [PV-RCNN (CenterHead)](tools/cfgs/waymo_models/pv_rcnn_with_centerhead_rpn.yaml) | 75.95/75.43 |68.02/67.54 |75.94/69.40 |67.66/61.62 |70.18/68.98 |67.73/66.57| | [PV-RCNN (CenterHead)](tools/cfgs/waymo_models/pv_rcnn_with_centerhead_rpn.yaml) | 75.95/75.43 |68.02/67.54 |75.94/69.40 |67.66/61.62 |70.18/68.98 |67.73/66.57|
| [PV-RCNN++](tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml) | 77.82/77.32| 69.07/68.62| 77.99/71.36| 69.92/63.74| 71.80/70.71| 69.31/68.26|
| [PV-RCNN++ (ResNet)](tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet.yaml) |77.61/77.14| 69.18/68.75| 79.42/73.31| 70.88/65.21| 72.50/71.39| 69.84/68.77|
Here we also provide the performance of several models trained on the full training set (refer to the paper of [PV-RCNN++](https://arxiv.org/abs/2102.00463)):
| Performance@(train with 100\% Data) | Vec_L1 | Vec_L2 | Ped_L1 | Ped_L2 | Cyc_L1 | Cyc_L2 |
|---------------------------------------------|----------:|:-------:|:-------:|:-------:|:-------:|:-------:|
| [SECOND](tools/cfgs/waymo_models/second.yaml) | 72.27/71.69 | 63.85/63.33 | 68.70/58.18 | 60.72/51.31 | 60.62/59.28 | 58.34/57.05|
| [Part-A2-Anchor](tools/cfgs/waymo_models/PartA2.yaml) | 77.05/76.51 | 68.47/67.97 | 75.24/66.87 | 66.18/58.62 | 68.60/67.36 | 66.13/64.93 |
| [PV-RCNN (CenterHead)](tools/cfgs/waymo_models/pv_rcnn_with_centerhead_rpn.yaml) | 78.00/77.50 | 69.43/68.98 | 79.21/73.03 | 70.42/64.72 | 71.46/70.27 | 68.95/67.79|
| [PV-RCNN++](tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml) | 79.10/78.63 | 70.34/69.91 | 80.62/74.62 | 71.86/66.30 | 73.49/72.38 | 70.70/69.62 |
| [PV-RCNN++ (ResNet)](tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet.yaml) |79.25/78.78 | 70.61/70.18 | 81.83/76.28 | 73.17/68.00 | 73.72/72.66 | 71.21/70.19|
We could not provide the above pretrained models due to [Waymo Dataset License Agreement](https://waymo.com/open/terms/), We could not provide the above pretrained models due to [Waymo Dataset License Agreement](https://waymo.com/open/terms/),
but you could easily achieve similar performance by training with the default configs. but you could easily achieve similar performance by training with the default configs.
......
...@@ -86,7 +86,7 @@ OpenPCDet ...@@ -86,7 +86,7 @@ OpenPCDet
```shell script ```shell script
pip3 install --upgrade pip pip3 install --upgrade pip
# tf 2.0.0 # tf 2.0.0
pip3 install waymo-open-dataset-tf-2-0-0==1.2.0 --user pip3 install waymo-open-dataset-tf-2-5-0 --user
``` ```
* Extract point cloud data from tfrecord and generate data infos by running the following command (it takes several hours, * Extract point cloud data from tfrecord and generate data infos by running the following command (it takes several hours,
......
import math
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -40,6 +42,85 @@ def bilinear_interpolate_torch(im, x, y): ...@@ -40,6 +42,85 @@ def bilinear_interpolate_torch(im, x, y):
return ans return ans
def sample_points_with_roi(rois, points, sample_radius_with_roi, num_max_points_of_part=200000):
"""
Args:
rois: (M, 7 + C)
points: (N, 3)
sample_radius_with_roi:
num_max_points_of_part:
Returns:
sampled_points: (N_out, 3)
"""
if points.shape[0] < num_max_points_of_part:
distance = (points[:, None, :] - rois[None, :, 0:3]).norm(dim=-1)
min_dis, min_dis_roi_idx = distance.min(dim=-1)
roi_max_dim = (rois[min_dis_roi_idx, 3:6] / 2).norm(dim=-1)
point_mask = min_dis < roi_max_dim + sample_radius_with_roi
else:
start_idx = 0
point_mask_list = []
while start_idx < points.shape[0]:
distance = (points[start_idx:start_idx + num_max_points_of_part, None, :] - rois[None, :, 0:3]).norm(dim=-1)
min_dis, min_dis_roi_idx = distance.min(dim=-1)
roi_max_dim = (rois[min_dis_roi_idx, 3:6] / 2).norm(dim=-1)
cur_point_mask = min_dis < roi_max_dim + sample_radius_with_roi
point_mask_list.append(cur_point_mask)
start_idx += num_max_points_of_part
point_mask = torch.cat(point_mask_list, dim=0)
sampled_points = points[:1] if point_mask.sum() == 0 else points[point_mask, :]
return sampled_points, point_mask
def sector_fps(points, num_sampled_points, num_sectors):
"""
Args:
points: (N, 3)
num_sampled_points: int
num_sectors: int
Returns:
sampled_points: (N_out, 3)
"""
sector_size = np.pi * 2 / num_sectors
point_angles = torch.atan2(points[:, 1], points[:, 0]) + np.pi
sector_idx = (point_angles / sector_size).floor().clamp(min=0, max=num_sectors)
xyz_points_list = []
xyz_batch_cnt = []
num_sampled_points_list = []
for k in range(num_sectors):
mask = (sector_idx == k)
cur_num_points = mask.sum().item()
if cur_num_points > 0:
xyz_points_list.append(points[mask])
xyz_batch_cnt.append(cur_num_points)
ratio = cur_num_points / points.shape[0]
num_sampled_points_list.append(
min(cur_num_points, math.ceil(ratio * num_sampled_points))
)
if len(xyz_batch_cnt) == 0:
xyz_points_list.append(points)
xyz_batch_cnt.append(len(points))
num_sampled_points_list.append(num_sampled_points)
print(f'Warning: empty sector points detected in SectorFPS: points.shape={points.shape}')
xyz = torch.cat(xyz_points_list, dim=0)
xyz_batch_cnt = torch.tensor(xyz_batch_cnt, device=points.device).int()
sampled_points_batch_cnt = torch.tensor(num_sampled_points_list, device=points.device).int()
sampled_pt_idxs = pointnet2_stack_utils.stack_farthest_point_sample(
xyz.contiguous(), xyz_batch_cnt, sampled_points_batch_cnt
).long()
sampled_points = xyz[sampled_pt_idxs]
return sampled_points
class VoxelSetAbstraction(nn.Module): class VoxelSetAbstraction(nn.Module):
def __init__(self, model_cfg, voxel_size, point_cloud_range, num_bev_features=None, def __init__(self, model_cfg, voxel_size, point_cloud_range, num_bev_features=None,
num_rawpoint_features=None, **kwargs): num_rawpoint_features=None, **kwargs):
...@@ -58,38 +139,31 @@ class VoxelSetAbstraction(nn.Module): ...@@ -58,38 +139,31 @@ class VoxelSetAbstraction(nn.Module):
if src_name in ['bev', 'raw_points']: if src_name in ['bev', 'raw_points']:
continue continue
self.downsample_times_map[src_name] = SA_cfg[src_name].DOWNSAMPLE_FACTOR self.downsample_times_map[src_name] = SA_cfg[src_name].DOWNSAMPLE_FACTOR
mlps = SA_cfg[src_name].MLPS
for k in range(len(mlps)): if SA_cfg[src_name].get('INPUT_CHANNELS', None) is None:
mlps[k] = [mlps[k][0]] + mlps[k] input_channels = SA_cfg[src_name].MLPS[0][0] \
cur_layer = pointnet2_stack_modules.StackSAModuleMSG( if isinstance(SA_cfg[src_name].MLPS[0], list) else SA_cfg[src_name].MLPS[0]
radii=SA_cfg[src_name].POOL_RADIUS, else:
nsamples=SA_cfg[src_name].NSAMPLE, input_channels = SA_cfg[src_name]['INPUT_CHANNELS']
mlps=mlps,
use_xyz=True, cur_layer, cur_num_c_out = pointnet2_stack_modules.build_local_aggregation_module(
pool_method='max_pool', input_channels=input_channels, config=SA_cfg[src_name]
) )
self.SA_layers.append(cur_layer) self.SA_layers.append(cur_layer)
self.SA_layer_names.append(src_name) self.SA_layer_names.append(src_name)
c_in += sum([x[-1] for x in mlps]) c_in += cur_num_c_out
if 'bev' in self.model_cfg.FEATURES_SOURCE: if 'bev' in self.model_cfg.FEATURES_SOURCE:
c_bev = num_bev_features c_bev = num_bev_features
c_in += c_bev c_in += c_bev
if 'raw_points' in self.model_cfg.FEATURES_SOURCE: if 'raw_points' in self.model_cfg.FEATURES_SOURCE:
mlps = SA_cfg['raw_points'].MLPS self.SA_rawpoints, cur_num_c_out = pointnet2_stack_modules.build_local_aggregation_module(
for k in range(len(mlps)): input_channels=num_rawpoint_features - 3, config=SA_cfg['raw_points']
mlps[k] = [num_rawpoint_features - 3] + mlps[k]
self.SA_rawpoints = pointnet2_stack_modules.StackSAModuleMSG(
radii=SA_cfg['raw_points'].POOL_RADIUS,
nsamples=SA_cfg['raw_points'].NSAMPLE,
mlps=mlps,
use_xyz=True,
pool_method='max_pool'
) )
c_in += sum([x[-1] for x in mlps])
c_in += cur_num_c_out
self.vsa_point_feature_fusion = nn.Sequential( self.vsa_point_feature_fusion = nn.Sequential(
nn.Linear(c_in, self.model_cfg.NUM_OUTPUT_FEATURES, bias=False), nn.Linear(c_in, self.model_cfg.NUM_OUTPUT_FEATURES, bias=False),
...@@ -100,23 +174,64 @@ class VoxelSetAbstraction(nn.Module): ...@@ -100,23 +174,64 @@ class VoxelSetAbstraction(nn.Module):
self.num_point_features_before_fusion = c_in self.num_point_features_before_fusion = c_in
def interpolate_from_bev_features(self, keypoints, bev_features, batch_size, bev_stride): def interpolate_from_bev_features(self, keypoints, bev_features, batch_size, bev_stride):
x_idxs = (keypoints[:, :, 0] - self.point_cloud_range[0]) / self.voxel_size[0] """
y_idxs = (keypoints[:, :, 1] - self.point_cloud_range[1]) / self.voxel_size[1] Args:
keypoints: (N1 + N2 + ..., 4)
bev_features: (B, C, H, W)
batch_size:
bev_stride:
Returns:
point_bev_features: (N1 + N2 + ..., C)
"""
x_idxs = (keypoints[:, 1] - self.point_cloud_range[0]) / self.voxel_size[0]
y_idxs = (keypoints[:, 2] - self.point_cloud_range[1]) / self.voxel_size[1]
x_idxs = x_idxs / bev_stride x_idxs = x_idxs / bev_stride
y_idxs = y_idxs / bev_stride y_idxs = y_idxs / bev_stride
point_bev_features_list = [] point_bev_features_list = []
for k in range(batch_size): for k in range(batch_size):
cur_x_idxs = x_idxs[k] bs_mask = (keypoints[:, 0] == k)
cur_y_idxs = y_idxs[k]
cur_x_idxs = x_idxs[bs_mask]
cur_y_idxs = y_idxs[bs_mask]
cur_bev_features = bev_features[k].permute(1, 2, 0) # (H, W, C) cur_bev_features = bev_features[k].permute(1, 2, 0) # (H, W, C)
point_bev_features = bilinear_interpolate_torch(cur_bev_features, cur_x_idxs, cur_y_idxs) point_bev_features = bilinear_interpolate_torch(cur_bev_features, cur_x_idxs, cur_y_idxs)
point_bev_features_list.append(point_bev_features.unsqueeze(dim=0)) point_bev_features_list.append(point_bev_features)
point_bev_features = torch.cat(point_bev_features_list, dim=0) # (B, N, C0) point_bev_features = torch.cat(point_bev_features_list, dim=0) # (N1 + N2 + ..., C)
return point_bev_features return point_bev_features
def sectorized_proposal_centric_sampling(self, roi_boxes, points):
"""
Args:
roi_boxes: (M, 7 + C)
points: (N, 3)
Returns:
sampled_points: (N_out, 3)
"""
sampled_points, _ = sample_points_with_roi(
rois=roi_boxes, points=points,
sample_radius_with_roi=self.model_cfg.SPC_SAMPLING.SAMPLE_RADIUS_WITH_ROI,
num_max_points_of_part=self.model_cfg.SPC_SAMPLING.get('NUM_POINTS_OF_EACH_SAMPLE_PART', 200000)
)
sampled_points = sector_fps(
points=sampled_points, num_sampled_points=self.model_cfg.NUM_KEYPOINTS,
num_sectors=self.model_cfg.SPC_SAMPLING.NUM_SECTORS
)
return sampled_points
def get_sampled_points(self, batch_dict): def get_sampled_points(self, batch_dict):
"""
Args:
batch_dict:
Returns:
keypoints: (N1 + N2 + ..., 4), where 4 indicates [bs_idx, x, y, z]
"""
batch_size = batch_dict['batch_size'] batch_size = batch_dict['batch_size']
if self.model_cfg.POINT_SOURCE == 'raw_points': if self.model_cfg.POINT_SOURCE == 'raw_points':
src_points = batch_dict['points'][:, 1:4] src_points = batch_dict['points'][:, 1:4]
...@@ -136,7 +251,7 @@ class VoxelSetAbstraction(nn.Module): ...@@ -136,7 +251,7 @@ class VoxelSetAbstraction(nn.Module):
bs_mask = (batch_indices == bs_idx) bs_mask = (batch_indices == bs_idx)
sampled_points = src_points[bs_mask].unsqueeze(dim=0) # (1, N, 3) sampled_points = src_points[bs_mask].unsqueeze(dim=0) # (1, N, 3)
if self.model_cfg.SAMPLE_METHOD == 'FPS': if self.model_cfg.SAMPLE_METHOD == 'FPS':
cur_pt_idxs = pointnet2_stack_utils.furthest_point_sample( cur_pt_idxs = pointnet2_stack_utils.farthest_point_sample(
sampled_points[:, :, 0:3].contiguous(), self.model_cfg.NUM_KEYPOINTS sampled_points[:, :, 0:3].contiguous(), self.model_cfg.NUM_KEYPOINTS
).long() ).long()
...@@ -147,16 +262,75 @@ class VoxelSetAbstraction(nn.Module): ...@@ -147,16 +262,75 @@ class VoxelSetAbstraction(nn.Module):
keypoints = sampled_points[0][cur_pt_idxs[0]].unsqueeze(dim=0) keypoints = sampled_points[0][cur_pt_idxs[0]].unsqueeze(dim=0)
elif self.model_cfg.SAMPLE_METHOD == 'FastFPS': elif self.model_cfg.SAMPLE_METHOD == 'SPC':
raise NotImplementedError cur_keypoints = self.sectorized_proposal_centric_sampling(
roi_boxes=batch_dict['rois'][bs_idx], points=sampled_points[0]
)
bs_idxs = cur_keypoints.new_ones(cur_keypoints.shape[0]) * bs_idx
keypoints = torch.cat((bs_idxs[:, None], cur_keypoints), dim=1)
else: else:
raise NotImplementedError raise NotImplementedError
keypoints_list.append(keypoints) keypoints_list.append(keypoints)
keypoints = torch.cat(keypoints_list, dim=0) # (B, M, 3) keypoints = torch.cat(keypoints_list, dim=0) # (B, M, 3) or (N1 + N2 + ..., 4)
if len(keypoints.shape) == 3:
batch_idx = torch.arange(batch_size, device=keypoints.device).view(-1, 1).repeat(1, keypoints.shape[1]).view(-1, 1)
keypoints = torch.cat((batch_idx.float(), keypoints.view(-1, 3)), dim=1)
return keypoints return keypoints
@staticmethod
def aggregate_keypoint_features_from_one_source(
batch_size, aggregate_func, xyz, xyz_features, xyz_bs_idxs, new_xyz, new_xyz_batch_cnt,
filter_neighbors_with_roi=False, radius_of_neighbor=None, num_max_points_of_part=200000, rois=None
):
"""
Args:
aggregate_func:
xyz: (N, 3)
xyz_features: (N, C)
xyz_bs_idxs: (N)
new_xyz: (M, 3)
new_xyz_batch_cnt: (batch_size), [N1, N2, ...]
filter_neighbors_with_roi: True/False
radius_of_neighbor: float
num_max_points_of_part: int
rois: (batch_size, num_rois, 7 + C)
Returns:
"""
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
if filter_neighbors_with_roi:
point_features = torch.cat((xyz, xyz_features), dim=-1) if xyz_features is not None else xyz
point_features_list = []
for bs_idx in range(batch_size):
bs_mask = (xyz_bs_idxs == bs_idx)
_, valid_mask = sample_points_with_roi(
rois=rois[bs_idx], points=xyz[bs_mask],
sample_radius_with_roi=radius_of_neighbor, num_max_points_of_part=num_max_points_of_part,
)
point_features_list.append(point_features[bs_mask][valid_mask])
xyz_batch_cnt[bs_idx] = valid_mask.sum()
valid_point_features = torch.cat(point_features_list, dim=0)
xyz = valid_point_features[:, 0:3]
xyz_features = valid_point_features[:, 3:] if xyz_features is not None else None
else:
for bs_idx in range(batch_size):
xyz_batch_cnt[bs_idx] = (xyz_bs_idxs == bs_idx).sum()
pooled_points, pooled_features = aggregate_func(
xyz=xyz.contiguous(),
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=new_xyz,
new_xyz_batch_cnt=new_xyz_batch_cnt,
features=xyz_features.contiguous(),
)
return pooled_features
def forward(self, batch_dict): def forward(self, batch_dict):
""" """
Args: Args:
...@@ -185,56 +359,53 @@ class VoxelSetAbstraction(nn.Module): ...@@ -185,56 +359,53 @@ class VoxelSetAbstraction(nn.Module):
) )
point_features_list.append(point_bev_features) point_features_list.append(point_bev_features)
batch_size, num_keypoints, _ = keypoints.shape batch_size = batch_dict['batch_size']
new_xyz = keypoints.view(-1, 3)
new_xyz_batch_cnt = new_xyz.new_zeros(batch_size).int().fill_(num_keypoints) new_xyz = keypoints[:, 1:4].contiguous()
new_xyz_batch_cnt = new_xyz.new_zeros(batch_size).int()
for k in range(batch_size):
new_xyz_batch_cnt[k] = (keypoints[:, 0] == k).sum()
if 'raw_points' in self.model_cfg.FEATURES_SOURCE: if 'raw_points' in self.model_cfg.FEATURES_SOURCE:
raw_points = batch_dict['points'] raw_points = batch_dict['points']
xyz = raw_points[:, 1:4]
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for bs_idx in range(batch_size):
xyz_batch_cnt[bs_idx] = (raw_points[:, 0] == bs_idx).sum()
point_features = raw_points[:, 4:].contiguous() if raw_points.shape[1] > 4 else None
pooled_points, pooled_features = self.SA_rawpoints( pooled_features = self.aggregate_keypoint_features_from_one_source(
xyz=xyz.contiguous(), batch_size=batch_size, aggregate_func=self.SA_rawpoints,
xyz_batch_cnt=xyz_batch_cnt, xyz=raw_points[:, 1:4],
new_xyz=new_xyz, xyz_features=raw_points[:, 4:].contiguous() if raw_points.shape[1] > 4 else None,
new_xyz_batch_cnt=new_xyz_batch_cnt, xyz_bs_idxs=raw_points[:, 0],
features=point_features, new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt,
filter_neighbors_with_roi=self.model_cfg.SA_LAYER['raw_points'].get('FILTER_NEIGHBOR_WITH_ROI', False),
radius_of_neighbor=self.model_cfg.SA_LAYER['raw_points'].get('RADIUS_OF_NEIGHBOR_WITH_ROI', None),
rois=batch_dict.get('rois', None)
) )
point_features_list.append(pooled_features.view(batch_size, num_keypoints, -1)) point_features_list.append(pooled_features)
for k, src_name in enumerate(self.SA_layer_names): for k, src_name in enumerate(self.SA_layer_names):
cur_coords = batch_dict['multi_scale_3d_features'][src_name].indices cur_coords = batch_dict['multi_scale_3d_features'][src_name].indices
cur_features = batch_dict['multi_scale_3d_features'][src_name].features.contiguous()
xyz = common_utils.get_voxel_centers( xyz = common_utils.get_voxel_centers(
cur_coords[:, 1:4], cur_coords[:, 1:4], downsample_times=self.downsample_times_map[src_name],
downsample_times=self.downsample_times_map[src_name], voxel_size=self.voxel_size, point_cloud_range=self.point_cloud_range
voxel_size=self.voxel_size,
point_cloud_range=self.point_cloud_range
) )
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
for bs_idx in range(batch_size):
xyz_batch_cnt[bs_idx] = (cur_coords[:, 0] == bs_idx).sum()
pooled_points, pooled_features = self.SA_layers[k]( pooled_features = self.aggregate_keypoint_features_from_one_source(
xyz=xyz.contiguous(), batch_size=batch_size, aggregate_func=self.SA_layers[k],
xyz_batch_cnt=xyz_batch_cnt, xyz=xyz.contiguous(), xyz_features=cur_features, xyz_bs_idxs=cur_coords[:, 0],
new_xyz=new_xyz, new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt,
new_xyz_batch_cnt=new_xyz_batch_cnt, filter_neighbors_with_roi=self.model_cfg.SA_LAYER[src_name].get('FILTER_NEIGHBOR_WITH_ROI', False),
features=batch_dict['multi_scale_3d_features'][src_name].features.contiguous(), radius_of_neighbor=self.model_cfg.SA_LAYER[src_name].get('RADIUS_OF_NEIGHBOR_WITH_ROI', None),
rois=batch_dict.get('rois', None)
) )
point_features_list.append(pooled_features.view(batch_size, num_keypoints, -1))
point_features = torch.cat(point_features_list, dim=2) point_features_list.append(pooled_features)
batch_idx = torch.arange(batch_size, device=keypoints.device).view(-1, 1).repeat(1, keypoints.shape[1]).view(-1) point_features = torch.cat(point_features_list, dim=-1)
point_coords = torch.cat((batch_idx.view(-1, 1).float(), keypoints.view(-1, 3)), dim=1)
batch_dict['point_features_before_fusion'] = point_features.view(-1, point_features.shape[-1]) batch_dict['point_features_before_fusion'] = point_features.view(-1, point_features.shape[-1])
point_features = self.vsa_point_feature_fusion(point_features.view(-1, point_features.shape[-1])) point_features = self.vsa_point_feature_fusion(point_features.view(-1, point_features.shape[-1]))
batch_dict['point_features'] = point_features # (BxN, C) batch_dict['point_features'] = point_features # (BxN, C)
batch_dict['point_coords'] = point_coords # (BxN, 4) batch_dict['point_coords'] = keypoints # (BxN, 4)
return batch_dict return batch_dict
...@@ -174,7 +174,7 @@ class PointNet2Backbone(nn.Module): ...@@ -174,7 +174,7 @@ class PointNet2Backbone(nn.Module):
else: else:
last_num_points = self.num_points_each_layer[i - 1] last_num_points = self.num_points_each_layer[i - 1]
cur_xyz = l_xyz[-1][k * last_num_points: (k + 1) * last_num_points] cur_xyz = l_xyz[-1][k * last_num_points: (k + 1) * last_num_points]
cur_pt_idxs = pointnet2_utils_stack.furthest_point_sample( cur_pt_idxs = pointnet2_utils_stack.farthest_point_sample(
cur_xyz[None, :, :].contiguous(), self.num_points_each_layer[i] cur_xyz[None, :, :].contiguous(), self.num_points_each_layer[i]
).long()[0] ).long()[0]
if cur_xyz.shape[0] < self.num_points_each_layer[i]: if cur_xyz.shape[0] < self.num_points_each_layer[i]:
......
...@@ -8,6 +8,7 @@ from .second_net_iou import SECONDNetIoU ...@@ -8,6 +8,7 @@ from .second_net_iou import SECONDNetIoU
from .caddn import CaDDN from .caddn import CaDDN
from .voxel_rcnn import VoxelRCNN from .voxel_rcnn import VoxelRCNN
from .centerpoint import CenterPoint from .centerpoint import CenterPoint
from .pv_rcnn_plusplus import PVRCNNPlusPlus
__all__ = { __all__ = {
'Detector3DTemplate': Detector3DTemplate, 'Detector3DTemplate': Detector3DTemplate,
...@@ -19,7 +20,8 @@ __all__ = { ...@@ -19,7 +20,8 @@ __all__ = {
'SECONDNetIoU': SECONDNetIoU, 'SECONDNetIoU': SECONDNetIoU,
'CaDDN': CaDDN, 'CaDDN': CaDDN,
'VoxelRCNN': VoxelRCNN, 'VoxelRCNN': VoxelRCNN,
'CenterPoint': CenterPoint 'CenterPoint': CenterPoint,
'PVRCNNPlusPlus': PVRCNNPlusPlus
} }
......
from .detector3d_template import Detector3DTemplate
class PVRCNNPlusPlus(Detector3DTemplate):
def __init__(self, model_cfg, num_class, dataset):
super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
self.module_list = self.build_networks()
def forward(self, batch_dict):
batch_dict = self.vfe(batch_dict)
batch_dict = self.backbone_3d(batch_dict)
batch_dict = self.map_to_bev_module(batch_dict)
batch_dict = self.backbone_2d(batch_dict)
batch_dict = self.dense_head(batch_dict)
batch_dict = self.roi_head.proposal_layer(
batch_dict, nms_config=self.roi_head.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST']
)
if self.training:
targets_dict = self.roi_head.assign_targets(batch_dict)
batch_dict['rois'] = targets_dict['rois']
batch_dict['roi_labels'] = targets_dict['roi_labels']
batch_dict['roi_targets_dict'] = targets_dict
num_rois_per_scene = targets_dict['rois'].shape[1]
if 'roi_valid_num' in batch_dict:
batch_dict['roi_valid_num'] = [num_rois_per_scene for _ in range(batch_dict['batch_size'])]
batch_dict = self.pfe(batch_dict)
batch_dict = self.point_head(batch_dict)
batch_dict = self.roi_head(batch_dict)
if self.training:
loss, tb_dict, disp_dict = self.get_training_loss()
ret_dict = {
'loss': loss
}
return ret_dict, tb_dict, disp_dict
else:
pred_dicts, recall_dicts = self.post_processing(batch_dict)
return pred_dicts, recall_dicts
def get_training_loss(self):
disp_dict = {}
loss_rpn, tb_dict = self.dense_head.get_loss()
if self.point_head is not None:
loss_point, tb_dict = self.point_head.get_loss(tb_dict)
else:
loss_point = 0
loss_rcnn, tb_dict = self.roi_head.get_loss(tb_dict)
loss = loss_rpn + loss_point + loss_rcnn
return loss, tb_dict, disp_dict
...@@ -10,21 +10,12 @@ class PVRCNNHead(RoIHeadTemplate): ...@@ -10,21 +10,12 @@ class PVRCNNHead(RoIHeadTemplate):
super().__init__(num_class=num_class, model_cfg=model_cfg) super().__init__(num_class=num_class, model_cfg=model_cfg)
self.model_cfg = model_cfg self.model_cfg = model_cfg
mlps = self.model_cfg.ROI_GRID_POOL.MLPS self.roi_grid_pool_layer, num_c_out = pointnet2_stack_modules.build_local_aggregation_module(
for k in range(len(mlps)): input_channels=input_channels, config=self.model_cfg.ROI_GRID_POOL
mlps[k] = [input_channels] + mlps[k]
self.roi_grid_pool_layer = pointnet2_stack_modules.StackSAModuleMSG(
radii=self.model_cfg.ROI_GRID_POOL.POOL_RADIUS,
nsamples=self.model_cfg.ROI_GRID_POOL.NSAMPLE,
mlps=mlps,
use_xyz=True,
pool_method=self.model_cfg.ROI_GRID_POOL.POOL_METHOD,
) )
GRID_SIZE = self.model_cfg.ROI_GRID_POOL.GRID_SIZE GRID_SIZE = self.model_cfg.ROI_GRID_POOL.GRID_SIZE
c_out = sum([x[-1] for x in mlps]) pre_channel = GRID_SIZE * GRID_SIZE * GRID_SIZE * num_c_out
pre_channel = GRID_SIZE * GRID_SIZE * GRID_SIZE * c_out
shared_fc_list = [] shared_fc_list = []
for k in range(0, self.model_cfg.SHARED_FC.__len__()): for k in range(0, self.model_cfg.SHARED_FC.__len__()):
...@@ -150,6 +141,8 @@ class PVRCNNHead(RoIHeadTemplate): ...@@ -150,6 +141,8 @@ class PVRCNNHead(RoIHeadTemplate):
batch_dict, nms_config=self.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST'] batch_dict, nms_config=self.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST']
) )
if self.training: if self.training:
targets_dict = batch_dict.get('roi_targets_dict', None)
if targets_dict is None:
targets_dict = self.assign_targets(batch_dict) targets_dict = self.assign_targets(batch_dict)
batch_dict['rois'] = targets_dict['rois'] batch_dict['rois'] = targets_dict['rois']
batch_dict['roi_labels'] = targets_dict['roi_labels'] batch_dict['roi_labels'] = targets_dict['roi_labels']
......
...@@ -31,7 +31,7 @@ class _PointnetSAModuleBase(nn.Module): ...@@ -31,7 +31,7 @@ class _PointnetSAModuleBase(nn.Module):
if new_xyz is None: if new_xyz is None:
new_xyz = pointnet2_utils.gather_operation( new_xyz = pointnet2_utils.gather_operation(
xyz_flipped, xyz_flipped,
pointnet2_utils.furthest_point_sample(xyz, self.npoint) pointnet2_utils.farthest_point_sample(xyz, self.npoint)
).transpose(1, 2).contiguous() if self.npoint is not None else None ).transpose(1, 2).contiguous() if self.npoint is not None else None
for i in range(len(self.groupers)): for i in range(len(self.groupers)):
......
...@@ -7,11 +7,11 @@ from torch.autograd import Function, Variable ...@@ -7,11 +7,11 @@ from torch.autograd import Function, Variable
from . import pointnet2_batch_cuda as pointnet2 from . import pointnet2_batch_cuda as pointnet2
class FurthestPointSampling(Function): class FarthestPointSampling(Function):
@staticmethod @staticmethod
def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor: def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
""" """
Uses iterative furthest point sampling to select a set of npoint features that have the largest Uses iterative farthest point sampling to select a set of npoint features that have the largest
minimum distance minimum distance
:param ctx: :param ctx:
:param xyz: (B, N, 3) where N > npoint :param xyz: (B, N, 3) where N > npoint
...@@ -25,7 +25,7 @@ class FurthestPointSampling(Function): ...@@ -25,7 +25,7 @@ class FurthestPointSampling(Function):
output = torch.cuda.IntTensor(B, npoint) output = torch.cuda.IntTensor(B, npoint)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10) temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output) pointnet2.farthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
return output return output
@staticmethod @staticmethod
...@@ -33,7 +33,7 @@ class FurthestPointSampling(Function): ...@@ -33,7 +33,7 @@ class FurthestPointSampling(Function):
return None, None return None, None
furthest_point_sample = FurthestPointSampling.apply farthest_point_sample = furthest_point_sample = FarthestPointSampling.apply
class GatherOperation(Function): class GatherOperation(Function):
......
...@@ -16,7 +16,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -16,7 +16,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast"); m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast");
m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast"); m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast");
m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper"); m.def("farthest_point_sampling_wrapper", &farthest_point_sampling_wrapper, "farthest_point_sampling_wrapper");
m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast"); m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast");
m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast"); m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast");
......
...@@ -38,13 +38,13 @@ int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, ...@@ -38,13 +38,13 @@ int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
} }
int furthest_point_sampling_wrapper(int b, int n, int m, int farthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) { at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {
const float *points = points_tensor.data<float>(); const float *points = points_tensor.data<float>();
float *temp = temp_tensor.data<float>(); float *temp = temp_tensor.data<float>();
int *idx = idx_tensor.data<int>(); int *idx = idx_tensor.data<int>();
furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx); farthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx);
return 1; return 1;
} }
...@@ -98,7 +98,7 @@ __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, i ...@@ -98,7 +98,7 @@ __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, i
} }
template <unsigned int block_size> template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(int b, int n, int m, __global__ void farthest_point_sampling_kernel(int b, int n, int m,
const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) { const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
// dataset: (B, N, 3) // dataset: (B, N, 3)
// tmp: (B, N) // tmp: (B, N)
...@@ -215,7 +215,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m, ...@@ -215,7 +215,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m,
} }
} }
void furthest_point_sampling_kernel_launcher(int b, int n, int m, void farthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs) { const float *dataset, float *temp, int *idxs) {
// dataset: (B, N, 3) // dataset: (B, N, 3)
// tmp: (B, N) // tmp: (B, N)
...@@ -227,29 +227,29 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m, ...@@ -227,29 +227,29 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
switch (n_threads) { switch (n_threads) {
case 1024: case 1024:
furthest_point_sampling_kernel<1024><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<1024><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 512: case 512:
furthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 256: case 256:
furthest_point_sampling_kernel<256><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<256><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 128: case 128:
furthest_point_sampling_kernel<128><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<128><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 64: case 64:
furthest_point_sampling_kernel<64><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<64><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 32: case 32:
furthest_point_sampling_kernel<32><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<32><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 16: case 16:
furthest_point_sampling_kernel<16><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<16><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 8: case 8:
furthest_point_sampling_kernel<8><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<8><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 4: case 4:
furthest_point_sampling_kernel<4><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<4><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 2: case 2:
furthest_point_sampling_kernel<2><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<2><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 1: case 1:
furthest_point_sampling_kernel<1><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<1><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
default: default:
furthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); farthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs);
} }
err = cudaGetLastError(); err = cudaGetLastError();
......
...@@ -20,10 +20,10 @@ void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, ...@@ -20,10 +20,10 @@ void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
const float *grad_out, const int *idx, float *grad_points); const float *grad_out, const int *idx, float *grad_points);
int furthest_point_sampling_wrapper(int b, int n, int m, int farthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor); at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);
void furthest_point_sampling_kernel_launcher(int b, int n, int m, void farthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs); const float *dataset, float *temp, int *idxs);
#endif #endif
...@@ -7,6 +7,26 @@ import torch.nn.functional as F ...@@ -7,6 +7,26 @@ import torch.nn.functional as F
from . import pointnet2_utils from . import pointnet2_utils
def build_local_aggregation_module(input_channels, config):
local_aggregation_name = config.get('NAME', 'StackSAModuleMSG')
if local_aggregation_name == 'StackSAModuleMSG':
mlps = config.MLPS
for k in range(len(mlps)):
mlps[k] = [input_channels] + mlps[k]
cur_layer = StackSAModuleMSG(
radii=config.POOL_RADIUS, nsamples=config.NSAMPLE, mlps=mlps, use_xyz=True, pool_method='max_pool',
)
num_c_out = sum([x[-1] for x in mlps])
elif local_aggregation_name == 'VectorPoolAggregationModuleMSG':
cur_layer = VectorPoolAggregationModuleMSG(input_channels=input_channels, config=config)
num_c_out = config.MSG_POST_MLPS[-1]
else:
raise NotImplementedError
return cur_layer, num_c_out
class StackSAModuleMSG(nn.Module): class StackSAModuleMSG(nn.Module):
def __init__(self, *, radii: List[float], nsamples: List[int], mlps: List[List[int]], def __init__(self, *, radii: List[float], nsamples: List[int], mlps: List[List[int]],
...@@ -135,3 +155,316 @@ class StackPointnetFPModule(nn.Module): ...@@ -135,3 +155,316 @@ class StackPointnetFPModule(nn.Module):
new_features = new_features.squeeze(dim=0).squeeze(dim=-1).permute(1, 0) # (N1 + N2 ..., C) new_features = new_features.squeeze(dim=0).squeeze(dim=-1).permute(1, 0) # (N1 + N2 ..., C)
return new_features return new_features
class VectorPoolLocalInterpolateModule(nn.Module):
def __init__(self, mlp, num_voxels, max_neighbour_distance, nsample, neighbor_type, use_xyz=True,
neighbour_distance_multiplier=1.0, xyz_encoding_type='concat'):
"""
Args:
mlp:
num_voxels:
max_neighbour_distance:
neighbor_type: 1: ball, others: cube
nsample: find all (-1), find limited number(>0)
use_xyz:
neighbour_distance_multiplier:
xyz_encoding_type:
"""
super().__init__()
self.num_voxels = num_voxels # [num_grid_x, num_grid_y, num_grid_z]: number of grids in each local area centered at new_xyz
self.num_total_grids = self.num_voxels[0] * self.num_voxels[1] * self.num_voxels[2]
self.max_neighbour_distance = max_neighbour_distance
self.neighbor_distance_multiplier = neighbour_distance_multiplier
self.nsample = nsample
self.neighbor_type = neighbor_type
self.use_xyz = use_xyz
self.xyz_encoding_type = xyz_encoding_type
if mlp is not None:
if self.use_xyz:
mlp[0] += 9 if self.xyz_encoding_type == 'concat' else 0
shared_mlps = []
for k in range(len(mlp) - 1):
shared_mlps.extend([
nn.Conv2d(mlp[k], mlp[k + 1], kernel_size=1, bias=False),
nn.BatchNorm2d(mlp[k + 1]),
nn.ReLU()
])
self.mlp = nn.Sequential(*shared_mlps)
else:
self.mlp = None
self.num_avg_length_of_neighbor_idxs = 1000
def forward(self, support_xyz, support_features, xyz_batch_cnt, new_xyz, new_xyz_grid_centers, new_xyz_batch_cnt):
"""
Args:
support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
support_features: (N1 + N2 ..., C) point-wise features
xyz_batch_cnt: (batch_size), [N1, N2, ...]
new_xyz: (M1 + M2 ..., 3) centers of the ball query
new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grids centers of each grid
new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
Returns:
new_features: (N1 + N2 ..., C_out)
"""
with torch.no_grad():
dist, idx, num_avg_length_of_neighbor_idxs = pointnet2_utils.three_nn_for_vector_pool_by_two_step(
support_xyz, xyz_batch_cnt, new_xyz, new_xyz_grid_centers, new_xyz_batch_cnt,
self.max_neighbour_distance, self.nsample, self.neighbor_type,
self.num_avg_length_of_neighbor_idxs, self.num_total_grids, self.neighbor_distance_multiplier
)
self.num_avg_length_of_neighbor_idxs = max(self.num_avg_length_of_neighbor_idxs, num_avg_length_of_neighbor_idxs.item())
dist_recip = 1.0 / (dist + 1e-8)
norm = torch.sum(dist_recip, dim=-1, keepdim=True)
weight = dist_recip / torch.clamp_min(norm, min=1e-8)
empty_mask = (idx.view(-1, 3)[:, 0] == -1)
idx.view(-1, 3)[empty_mask] = 0
interpolated_feats = pointnet2_utils.three_interpolate(support_features, idx.view(-1, 3), weight.view(-1, 3))
interpolated_feats = interpolated_feats.view(idx.shape[0], idx.shape[1], -1) # (M1 + M2 ..., num_total_grids, C)
if self.use_xyz:
near_known_xyz = support_xyz[idx.view(-1, 3).long()].view(-1, 3, 3) # ( (M1 + M2 ...)*num_total_grids, 3)
local_xyz = (new_xyz_grid_centers.view(-1, 1, 3) - near_known_xyz).view(-1, idx.shape[1], 9)
if self.xyz_encoding_type == 'concat':
interpolated_feats = torch.cat((interpolated_feats, local_xyz), dim=-1) # ( M1 + M2 ..., num_total_grids, 9+C)
else:
raise NotImplementedError
new_features = interpolated_feats.view(-1, interpolated_feats.shape[-1]) # ((M1 + M2 ...) * num_total_grids, C)
new_features[empty_mask, :] = 0
if self.mlp is not None:
new_features = new_features.permute(1, 0)[None, :, :, None] # (1, C, N1 + N2 ..., 1)
new_features = self.mlp(new_features)
new_features = new_features.squeeze(dim=0).squeeze(dim=-1).permute(1, 0) # (N1 + N2 ..., C)
return new_features
class VectorPoolAggregationModule(nn.Module):
def __init__(
self, input_channels, num_local_voxel=(3, 3, 3), local_aggregation_type='local_interpolation',
num_reduced_channels=30, num_channels_of_local_aggregation=32, post_mlps=(128,),
max_neighbor_distance=None, neighbor_nsample=-1, neighbor_type=0, neighbor_distance_multiplier=2.0):
super().__init__()
self.num_local_voxel = num_local_voxel
self.total_voxels = self.num_local_voxel[0] * self.num_local_voxel[1] * self.num_local_voxel[2]
self.local_aggregation_type = local_aggregation_type
assert self.local_aggregation_type in ['local_interpolation', 'voxel_avg_pool', 'voxel_random_choice']
self.input_channels = input_channels
self.num_reduced_channels = input_channels if num_reduced_channels is None else num_reduced_channels
self.num_channels_of_local_aggregation = num_channels_of_local_aggregation
self.max_neighbour_distance = max_neighbor_distance
self.neighbor_nsample = neighbor_nsample
self.neighbor_type = neighbor_type # 1: ball, others: cube
if self.local_aggregation_type == 'local_interpolation':
self.local_interpolate_module = VectorPoolLocalInterpolateModule(
mlp=None, num_voxels=self.num_local_voxel,
max_neighbour_distance=self.max_neighbour_distance,
nsample=self.neighbor_nsample,
neighbor_type=self.neighbor_type,
neighbour_distance_multiplier=neighbor_distance_multiplier,
)
num_c_in = (self.num_reduced_channels + 9) * self.total_voxels
else:
self.local_interpolate_module = None
num_c_in = (self.num_reduced_channels + 3) * self.total_voxels
num_c_out = self.total_voxels * self.num_channels_of_local_aggregation
self.separate_local_aggregation_layer = nn.Sequential(
nn.Conv1d(num_c_in, num_c_out, kernel_size=1, groups=self.total_voxels, bias=False),
nn.BatchNorm1d(num_c_out),
nn.ReLU()
)
post_mlp_list = []
c_in = num_c_out
for cur_num_c in post_mlps:
post_mlp_list.extend([
nn.Conv1d(c_in, cur_num_c, kernel_size=1, bias=False),
nn.BatchNorm1d(cur_num_c),
nn.ReLU()
])
c_in = cur_num_c
self.post_mlps = nn.Sequential(*post_mlp_list)
self.num_mean_points_per_grid = 20
self.init_weights()
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0)
def extra_repr(self) -> str:
ret = f'radius={self.max_neighbour_distance}, local_voxels=({self.num_local_voxel}, ' \
f'local_aggregation_type={self.local_aggregation_type}, ' \
f'num_c_reduction={self.input_channels}->{self.num_reduced_channels}, ' \
f'num_c_local_aggregation={self.num_channels_of_local_aggregation}'
return ret
def vector_pool_with_voxel_query(self, xyz, xyz_batch_cnt, features, new_xyz, new_xyz_batch_cnt):
use_xyz = 1
pooling_type = 0 if self.local_aggregation_type == 'voxel_avg_pool' else 1
new_features, new_local_xyz, num_mean_points_per_grid, point_cnt_of_grid = pointnet2_utils.vector_pool_with_voxel_query_op(
xyz, xyz_batch_cnt, features, new_xyz, new_xyz_batch_cnt,
self.num_local_voxel[0], self.num_local_voxel[1], self.num_local_voxel[2],
self.max_neighbour_distance, self.num_reduced_channels, use_xyz,
self.num_mean_points_per_grid, self.neighbor_nsample, self.neighbor_type,
pooling_type
)
self.num_mean_points_per_grid = max(self.num_mean_points_per_grid, num_mean_points_per_grid.item())
num_new_pts = new_features.shape[0]
new_local_xyz = new_local_xyz.view(num_new_pts, -1, 3) # (N, num_voxel, 3)
new_features = new_features.view(num_new_pts, -1, self.num_reduced_channels) # (N, num_voxel, C)
new_features = torch.cat((new_local_xyz, new_features), dim=-1).view(num_new_pts, -1)
return new_features, point_cnt_of_grid
@staticmethod
def get_dense_voxels_by_center(point_centers, max_neighbour_distance, num_voxels):
"""
Args:
point_centers: (N, 3)
max_neighbour_distance: float
num_voxels: [num_x, num_y, num_z]
Returns:
voxel_centers: (N, total_voxels, 3)
"""
R = max_neighbour_distance
device = point_centers.device
x_grids = torch.arange(-R + R / num_voxels[0], R - R / num_voxels[0] + 1e-5, 2 * R / num_voxels[0], device=device)
y_grids = torch.arange(-R + R / num_voxels[1], R - R / num_voxels[1] + 1e-5, 2 * R / num_voxels[1], device=device)
z_grids = torch.arange(-R + R / num_voxels[2], R - R / num_voxels[2] + 1e-5, 2 * R / num_voxels[2], device=device)
x_offset, y_offset, z_offset = torch.meshgrid(x_grids, y_grids, z_grids) # shape: [num_x, num_y, num_z]
xyz_offset = torch.cat((
x_offset.contiguous().view(-1, 1),
y_offset.contiguous().view(-1, 1),
z_offset.contiguous().view(-1, 1)), dim=-1
)
voxel_centers = point_centers[:, None, :] + xyz_offset[None, :, :]
return voxel_centers
def vector_pool_with_local_interpolate(self, xyz, xyz_batch_cnt, features, new_xyz, new_xyz_batch_cnt):
"""
Args:
xyz: (N, 3)
xyz_batch_cnt: (batch_size)
features: (N, C)
new_xyz: (M, 3)
new_xyz_batch_cnt: (batch_size)
Returns:
new_features: (M, total_voxels * C)
"""
voxel_centers = self.get_dense_voxels_by_center(
point_centers=new_xyz, max_neighbour_distance=self.max_neighbour_distance, num_voxels=self.num_local_voxel
) # (M1 + M2 + ..., total_voxels, 3)
voxel_features = self.local_interpolate_module.forward(
support_xyz=xyz, support_features=features, xyz_batch_cnt=xyz_batch_cnt,
new_xyz=new_xyz, new_xyz_grid_centers=voxel_centers, new_xyz_batch_cnt=new_xyz_batch_cnt
) # ((M1 + M2 ...) * total_voxels, C)
voxel_features = voxel_features.contiguous().view(-1, self.total_voxels * voxel_features.shape[-1])
return voxel_features
def forward(self, xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, features, **kwargs):
"""
:param xyz: (N1 + N2 ..., 3) tensor of the xyz coordinates of the features
:param xyz_batch_cnt: (batch_size), [N1, N2, ...]
:param new_xyz: (M1 + M2 ..., 3)
:param new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
:param features: (N1 + N2 ..., C) tensor of the descriptors of the the features
:return:
new_xyz: (M1 + M2 ..., 3) tensor of the new features' xyz
new_features: (M1 + M2 ..., \sum_k(mlps[k][-1])) tensor of the new_features descriptors
"""
N, C = features.shape
assert C % self.num_reduced_channels == 0, \
f'the input channels ({C}) should be an integral multiple of num_reduced_channels({self.num_reduced_channels})'
features = features.view(N, -1, self.num_reduced_channels).sum(dim=1)
if self.local_aggregation_type in ['voxel_avg_pool', 'voxel_random_choice']:
vector_features, point_cnt_of_grid = self.vector_pool_with_voxel_query(
xyz=xyz, xyz_batch_cnt=xyz_batch_cnt, features=features,
new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt
)
elif self.local_aggregation_type == 'local_interpolation':
vector_features = self.vector_pool_with_local_interpolate(
xyz=xyz, xyz_batch_cnt=xyz_batch_cnt, features=features,
new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt
) # (M1 + M2 + ..., total_voxels * C)
else:
raise NotImplementedError
vector_features = vector_features.permute(1, 0)[None, :, :] # (1, num_voxels * C, M1 + M2 ...)
new_features = self.separate_local_aggregation_layer(vector_features)
new_features = self.post_mlps(new_features)
new_features = new_features.squeeze(dim=0).permute(1, 0)
return new_xyz, new_features
class VectorPoolAggregationModuleMSG(nn.Module):
def __init__(self, input_channels, config):
super().__init__()
self.model_cfg = config
self.num_groups = self.model_cfg.NUM_GROUPS
self.layers = []
c_in = 0
for k in range(self.num_groups):
cur_config = self.model_cfg[f'GROUP_CFG_{k}']
cur_vector_pool_module = VectorPoolAggregationModule(
input_channels=input_channels, num_local_voxel=cur_config.NUM_LOCAL_VOXEL,
post_mlps=cur_config.POST_MLPS,
max_neighbor_distance=cur_config.MAX_NEIGHBOR_DISTANCE,
neighbor_nsample=cur_config.NEIGHBOR_NSAMPLE,
local_aggregation_type=self.model_cfg.LOCAL_AGGREGATION_TYPE,
num_reduced_channels=self.model_cfg.get('NUM_REDUCED_CHANNELS', None),
num_channels_of_local_aggregation=self.model_cfg.NUM_CHANNELS_OF_LOCAL_AGGREGATION,
neighbor_distance_multiplier=2.0
)
self.__setattr__(f'layer_{k}', cur_vector_pool_module)
c_in += cur_config.POST_MLPS[-1]
c_in += 3 # use_xyz
shared_mlps = []
for cur_num_c in self.model_cfg.MSG_POST_MLPS:
shared_mlps.extend([
nn.Conv1d(c_in, cur_num_c, kernel_size=1, bias=False),
nn.BatchNorm1d(cur_num_c),
nn.ReLU()
])
c_in = cur_num_c
self.msg_post_mlps = nn.Sequential(*shared_mlps)
def forward(self, **kwargs):
features_list = []
for k in range(self.num_groups):
cur_xyz, cur_features = self.__getattr__(f'layer_{k}')(**kwargs)
features_list.append(cur_features)
features = torch.cat(features_list, dim=-1)
features = torch.cat((cur_xyz, features), dim=-1)
features = features.permute(1, 0)[None, :, :] # (1, C, N)
new_features = self.msg_post_mlps(features)
new_features = new_features.squeeze(dim=0).permute(1, 0) # (N, C)
return cur_xyz, new_features
...@@ -155,7 +155,7 @@ class QueryAndGroup(nn.Module): ...@@ -155,7 +155,7 @@ class QueryAndGroup(nn.Module):
return new_features, idx return new_features, idx
class FurthestPointSampling(Function): class FarthestPointSampling(Function):
@staticmethod @staticmethod
def forward(ctx, xyz: torch.Tensor, npoint: int): def forward(ctx, xyz: torch.Tensor, npoint: int):
""" """
...@@ -173,7 +173,7 @@ class FurthestPointSampling(Function): ...@@ -173,7 +173,7 @@ class FurthestPointSampling(Function):
output = torch.cuda.IntTensor(B, npoint) output = torch.cuda.IntTensor(B, npoint)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10) temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output) pointnet2.farthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
return output return output
@staticmethod @staticmethod
...@@ -181,7 +181,44 @@ class FurthestPointSampling(Function): ...@@ -181,7 +181,44 @@ class FurthestPointSampling(Function):
return None, None return None, None
furthest_point_sample = FurthestPointSampling.apply farthest_point_sample = furthest_point_sample = FarthestPointSampling.apply
class StackFarthestPointSampling(Function):
@staticmethod
def forward(ctx, xyz, xyz_batch_cnt, npoint):
"""
Args:
ctx:
xyz: (N1 + N2 + ..., 3) where N > npoint
xyz_batch_cnt: [N1, N2, ...]
npoint: int, number of features in the sampled set
Returns:
output: (npoint.sum()) tensor containing the set,
npoint: (M1, M2, ...)
"""
assert xyz.is_contiguous() and xyz.shape[1] == 3
batch_size = xyz_batch_cnt.__len__()
if not isinstance(npoint, torch.Tensor):
if not isinstance(npoint, list):
npoint = [npoint for i in range(batch_size)]
npoint = torch.tensor(npoint, device=xyz.device).int()
N, _ = xyz.size()
temp = torch.cuda.FloatTensor(N).fill_(1e10)
output = torch.cuda.IntTensor(npoint.sum().item())
pointnet2.stack_farthest_point_sampling_wrapper(xyz, temp, xyz_batch_cnt, output, npoint)
return output
@staticmethod
def backward(xyz, a=None):
return None, None
stack_farthest_point_sample = StackFarthestPointSampling.apply
class ThreeNN(Function): class ThreeNN(Function):
...@@ -262,5 +299,154 @@ class ThreeInterpolate(Function): ...@@ -262,5 +299,154 @@ class ThreeInterpolate(Function):
three_interpolate = ThreeInterpolate.apply three_interpolate = ThreeInterpolate.apply
class ThreeNNForVectorPoolByTwoStep(Function):
@staticmethod
def forward(ctx, support_xyz, xyz_batch_cnt, new_xyz, new_xyz_grid_centers, new_xyz_batch_cnt,
max_neighbour_distance, nsample, neighbor_type, avg_length_of_neighbor_idxs, num_total_grids,
neighbor_distance_multiplier):
"""
Args:
ctx:
// support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
// xyz_batch_cnt: (batch_size), [N1, N2, ...]
// new_xyz: (M1 + M2 ..., 3) centers of the ball query
// new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grids centers of each grid
// new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
// nsample: find all (-1), find limited number(>0)
// neighbor_type: 1: ball, others: cube
// neighbor_distance_multiplier: query_distance = neighbor_distance_multiplier * max_neighbour_distance
Returns:
// new_xyz_grid_idxs: (M1 + M2 ..., num_total_grids, 3) three-nn
// new_xyz_grid_dist2: (M1 + M2 ..., num_total_grids, 3) square of dist of three-nn
"""
num_new_xyz = new_xyz.shape[0]
new_xyz_grid_dist2 = new_xyz_grid_centers.new_zeros(new_xyz_grid_centers.shape)
new_xyz_grid_idxs = new_xyz_grid_centers.new_zeros(new_xyz_grid_centers.shape).int().fill_(-1)
while True:
num_max_sum_points = avg_length_of_neighbor_idxs * num_new_xyz
stack_neighbor_idxs = new_xyz_grid_idxs.new_zeros(num_max_sum_points)
start_len = new_xyz_grid_idxs.new_zeros(num_new_xyz, 2).int()
cumsum = new_xyz_grid_idxs.new_zeros(1)
pointnet2.query_stacked_local_neighbor_idxs_wrapper_stack(
support_xyz.contiguous(), xyz_batch_cnt.contiguous(),
new_xyz.contiguous(), new_xyz_batch_cnt.contiguous(),
stack_neighbor_idxs.contiguous(), start_len.contiguous(), cumsum,
avg_length_of_neighbor_idxs, max_neighbour_distance * neighbor_distance_multiplier,
nsample, neighbor_type
)
avg_length_of_neighbor_idxs = cumsum[0].item() // num_new_xyz + int(cumsum[0].item() % num_new_xyz > 0)
if cumsum[0] <= num_max_sum_points:
break
stack_neighbor_idxs = stack_neighbor_idxs[:cumsum[0]]
pointnet2.query_three_nn_by_stacked_local_idxs_wrapper_stack(
support_xyz, new_xyz, new_xyz_grid_centers, new_xyz_grid_idxs, new_xyz_grid_dist2,
stack_neighbor_idxs, start_len, num_new_xyz, num_total_grids
)
return torch.sqrt(new_xyz_grid_dist2), new_xyz_grid_idxs, torch.tensor(avg_length_of_neighbor_idxs)
three_nn_for_vector_pool_by_two_step = ThreeNNForVectorPoolByTwoStep.apply
class VectorPoolWithVoxelQuery(Function):
@staticmethod
def forward(ctx, support_xyz: torch.Tensor, xyz_batch_cnt: torch.Tensor, support_features: torch.Tensor,
new_xyz: torch.Tensor, new_xyz_batch_cnt: torch.Tensor, num_grid_x, num_grid_y, num_grid_z,
max_neighbour_distance, num_c_out_each_grid, use_xyz,
num_mean_points_per_grid=100, nsample=-1, neighbor_type=0, pooling_type=0):
"""
Args:
ctx:
support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
xyz_batch_cnt: (batch_size), [N1, N2, ...]
support_features: (N1 + N2 ..., C)
new_xyz: (M1 + M2 ..., 3) centers of new positions
new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
num_grid_x: number of grids in each local area centered at new_xyz
num_grid_y:
num_grid_z:
max_neighbour_distance:
num_c_out_each_grid:
use_xyz:
neighbor_type: 1: ball, others: cube:
pooling_type: 0: avg_pool, 1: random choice
Returns:
new_features: (M1 + M2 ..., num_c_out)
"""
assert support_xyz.is_contiguous()
assert support_features.is_contiguous()
assert xyz_batch_cnt.is_contiguous()
assert new_xyz.is_contiguous()
assert new_xyz_batch_cnt.is_contiguous()
num_total_grids = num_grid_x * num_grid_y * num_grid_z
num_c_out = num_c_out_each_grid * num_total_grids
N, num_c_in = support_features.shape
M = new_xyz.shape[0]
assert num_c_in % num_c_out_each_grid == 0, \
f'the input channels ({num_c_in}) should be an integral multiple of num_c_out_each_grid({num_c_out_each_grid})'
while True:
new_features = support_features.new_zeros((M, num_c_out))
new_local_xyz = support_features.new_zeros((M, 3 * num_total_grids))
point_cnt_of_grid = xyz_batch_cnt.new_zeros((M, num_total_grids))
num_max_sum_points = num_mean_points_per_grid * M
grouped_idxs = xyz_batch_cnt.new_zeros((num_max_sum_points, 3))
num_cum_sum = pointnet2.vector_pool_wrapper(
support_xyz, xyz_batch_cnt, support_features, new_xyz, new_xyz_batch_cnt,
new_features, new_local_xyz, point_cnt_of_grid, grouped_idxs,
num_grid_x, num_grid_y, num_grid_z, max_neighbour_distance, use_xyz,
num_max_sum_points, nsample, neighbor_type, pooling_type
)
num_mean_points_per_grid = num_cum_sum // M + int(num_cum_sum % M > 0)
if num_cum_sum <= num_max_sum_points:
break
grouped_idxs = grouped_idxs[:num_cum_sum]
normalizer = torch.clamp_min(point_cnt_of_grid[:, :, None].float(), min=1e-6)
new_features = (new_features.view(-1, num_total_grids, num_c_out_each_grid) / normalizer).view(-1, num_c_out)
if use_xyz:
new_local_xyz = (new_local_xyz.view(-1, num_total_grids, 3) / normalizer).view(-1, num_total_grids * 3)
num_mean_points_per_grid = torch.Tensor([num_mean_points_per_grid]).int()
nsample = torch.Tensor([nsample]).int()
ctx.vector_pool_for_backward = (point_cnt_of_grid, grouped_idxs, N, num_c_in)
ctx.mark_non_differentiable(new_local_xyz, num_mean_points_per_grid, nsample, point_cnt_of_grid)
return new_features, new_local_xyz, num_mean_points_per_grid, point_cnt_of_grid
@staticmethod
def backward(ctx, grad_new_features: torch.Tensor, grad_local_xyz: torch.Tensor, grad_num_cum_sum, grad_point_cnt_of_grid):
"""
Args:
ctx:
grad_new_features: (M1 + M2 ..., num_c_out), num_c_out = num_c_out_each_grid * num_total_grids
Returns:
grad_support_features: (N1 + N2 ..., C_in)
"""
point_cnt_of_grid, grouped_idxs, N, num_c_in = ctx.vector_pool_for_backward
grad_support_features = grad_new_features.new_zeros((N, num_c_in))
pointnet2.vector_pool_grad_wrapper(
grad_new_features.contiguous(), point_cnt_of_grid, grouped_idxs,
grad_support_features
)
return None, None, grad_support_features, None, None, None, None, None, None, None, None, None, None, None, None
vector_pool_with_voxel_query_op = VectorPoolWithVoxelQuery.apply
if __name__ == '__main__': if __name__ == '__main__':
pass pass
...@@ -6,13 +6,15 @@ ...@@ -6,13 +6,15 @@
#include "sampling_gpu.h" #include "sampling_gpu.h"
#include "interpolate_gpu.h" #include "interpolate_gpu.h"
#include "voxel_query_gpu.h" #include "voxel_query_gpu.h"
#include "vector_pool_gpu.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ball_query_wrapper", &ball_query_wrapper_stack, "ball_query_wrapper_stack"); m.def("ball_query_wrapper", &ball_query_wrapper_stack, "ball_query_wrapper_stack");
m.def("voxel_query_wrapper", &voxel_query_wrapper_stack, "voxel_query_wrapper_stack"); m.def("voxel_query_wrapper", &voxel_query_wrapper_stack, "voxel_query_wrapper_stack");
m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper"); m.def("farthest_point_sampling_wrapper", &farthest_point_sampling_wrapper, "farthest_point_sampling_wrapper");
m.def("stack_farthest_point_sampling_wrapper", &stack_farthest_point_sampling_wrapper, "stack_farthest_point_sampling_wrapper");
m.def("group_points_wrapper", &group_points_wrapper_stack, "group_points_wrapper_stack"); m.def("group_points_wrapper", &group_points_wrapper_stack, "group_points_wrapper_stack");
m.def("group_points_grad_wrapper", &group_points_grad_wrapper_stack, "group_points_grad_wrapper_stack"); m.def("group_points_grad_wrapper", &group_points_grad_wrapper_stack, "group_points_grad_wrapper_stack");
...@@ -20,4 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -20,4 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("three_nn_wrapper", &three_nn_wrapper_stack, "three_nn_wrapper_stack"); m.def("three_nn_wrapper", &three_nn_wrapper_stack, "three_nn_wrapper_stack");
m.def("three_interpolate_wrapper", &three_interpolate_wrapper_stack, "three_interpolate_wrapper_stack"); m.def("three_interpolate_wrapper", &three_interpolate_wrapper_stack, "three_interpolate_wrapper_stack");
m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_stack, "three_interpolate_grad_wrapper_stack"); m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_stack, "three_interpolate_grad_wrapper_stack");
m.def("query_stacked_local_neighbor_idxs_wrapper_stack", &query_stacked_local_neighbor_idxs_wrapper_stack, "query_stacked_local_neighbor_idxs_wrapper_stack");
m.def("query_three_nn_by_stacked_local_idxs_wrapper_stack", &query_three_nn_by_stacked_local_idxs_wrapper_stack, "query_three_nn_by_stacked_local_idxs_wrapper_stack");
m.def("vector_pool_wrapper", &vector_pool_wrapper_stack, "vector_pool_grad_wrapper_stack");
m.def("vector_pool_grad_wrapper", &vector_pool_grad_wrapper_stack, "vector_pool_grad_wrapper_stack");
} }
...@@ -21,7 +21,7 @@ extern THCState *state; ...@@ -21,7 +21,7 @@ extern THCState *state;
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
int furthest_point_sampling_wrapper(int b, int n, int m, int farthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) { at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {
CHECK_INPUT(points_tensor); CHECK_INPUT(points_tensor);
...@@ -32,6 +32,29 @@ int furthest_point_sampling_wrapper(int b, int n, int m, ...@@ -32,6 +32,29 @@ int furthest_point_sampling_wrapper(int b, int n, int m,
float *temp = temp_tensor.data<float>(); float *temp = temp_tensor.data<float>();
int *idx = idx_tensor.data<int>(); int *idx = idx_tensor.data<int>();
furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx); farthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx);
return 1;
}
int stack_farthest_point_sampling_wrapper(at::Tensor points_tensor,
at::Tensor temp_tensor, at::Tensor xyz_batch_cnt_tensor, at::Tensor idx_tensor,
at::Tensor num_sampled_points_tensor) {
CHECK_INPUT(points_tensor);
CHECK_INPUT(temp_tensor);
CHECK_INPUT(idx_tensor);
CHECK_INPUT(xyz_batch_cnt_tensor);
CHECK_INPUT(num_sampled_points_tensor);
int batch_size = xyz_batch_cnt_tensor.size(0);
int N = points_tensor.size(0);
const float *points = points_tensor.data<float>();
float *temp = temp_tensor.data<float>();
int *xyz_batch_cnt = xyz_batch_cnt_tensor.data<int>();
int *idx = idx_tensor.data<int>();
int *num_sampled_points = num_sampled_points_tensor.data<int>();
stack_farthest_point_sampling_kernel_launcher(N, batch_size, points, temp, xyz_batch_cnt, idx, num_sampled_points);
return 1; return 1;
} }
\ No newline at end of file
...@@ -22,7 +22,7 @@ __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, i ...@@ -22,7 +22,7 @@ __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, i
template <unsigned int block_size> template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(int b, int n, int m, __global__ void farthest_point_sampling_kernel(int b, int n, int m,
const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) { const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
// dataset: (B, N, 3) // dataset: (B, N, 3)
// tmp: (B, N) // tmp: (B, N)
...@@ -139,7 +139,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m, ...@@ -139,7 +139,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m,
} }
} }
void furthest_point_sampling_kernel_launcher(int b, int n, int m, void farthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs) { const float *dataset, float *temp, int *idxs) {
// dataset: (B, N, 3) // dataset: (B, N, 3)
// tmp: (B, N) // tmp: (B, N)
...@@ -151,29 +151,29 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m, ...@@ -151,29 +151,29 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
switch (n_threads) { switch (n_threads) {
case 1024: case 1024:
furthest_point_sampling_kernel<1024><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<1024><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 512: case 512:
furthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 256: case 256:
furthest_point_sampling_kernel<256><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<256><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 128: case 128:
furthest_point_sampling_kernel<128><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<128><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 64: case 64:
furthest_point_sampling_kernel<64><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<64><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 32: case 32:
furthest_point_sampling_kernel<32><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<32><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 16: case 16:
furthest_point_sampling_kernel<16><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<16><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 8: case 8:
furthest_point_sampling_kernel<8><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<8><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 4: case 4:
furthest_point_sampling_kernel<4><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<4><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 2: case 2:
furthest_point_sampling_kernel<2><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<2><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 1: case 1:
furthest_point_sampling_kernel<1><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<1><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
default: default:
furthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); farthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs);
} }
err = cudaGetLastError(); err = cudaGetLastError();
...@@ -182,3 +182,168 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m, ...@@ -182,3 +182,168 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
exit(-1); exit(-1);
} }
} }
template <unsigned int block_size>
__global__ void stack_farthest_point_sampling_kernel(int batch_size, int N,
const float *dataset, float *temp, int *xyz_batch_cnt, int *idxs, int *num_sampled_points) {
// """
// Args:
// ctx:
// dataset: (N1 + N2 + ..., 3) where N > npoint
// temp: (N1 + N2 + ...) where N > npoint
// xyz_batch_cnt: [N1, N2, ...]
// num_sampled_points: [M1, M2, ...] int, number of features in the sampled set
// Returns:
// idxs: (npoint.sum()) tensor containing the set,
// npoint: (M1, M2, ...)
// """
__shared__ float dists[block_size];
__shared__ int dists_i[block_size];
int bs_idx = blockIdx.x;
int xyz_batch_start_idx = 0, idxs_start_idx = 0;
for (int k = 0; k < bs_idx; k++){
xyz_batch_start_idx += xyz_batch_cnt[k];
idxs_start_idx += num_sampled_points[k];
}
dataset += xyz_batch_start_idx * 3;
temp += xyz_batch_start_idx;
idxs += idxs_start_idx;
int n = xyz_batch_cnt[bs_idx];
int m = num_sampled_points[bs_idx];
int tid = threadIdx.x;
const int stride = block_size;
int old = 0;
if (threadIdx.x == 0) idxs[0] = xyz_batch_start_idx;
__syncthreads();
for (int j = 1; j < m; j++) {
int besti = 0;
float best = -1;
float x1 = dataset[old * 3 + 0];
float y1 = dataset[old * 3 + 1];
float z1 = dataset[old * 3 + 2];
for (int k = tid; k < n; k += stride) {
float x2, y2, z2;
x2 = dataset[k * 3 + 0];
y2 = dataset[k * 3 + 1];
z2 = dataset[k * 3 + 2];
// float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
// if (mag <= 1e-3)
// continue;
float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
float d2 = min(d, temp[k]);
temp[k] = d2;
besti = d2 > best ? k : besti;
best = d2 > best ? d2 : best;
}
dists[tid] = best;
dists_i[tid] = besti;
__syncthreads();
if (block_size >= 1024) {
if (tid < 512) {
__update(dists, dists_i, tid, tid + 512);
}
__syncthreads();
}
if (block_size >= 512) {
if (tid < 256) {
__update(dists, dists_i, tid, tid + 256);
}
__syncthreads();
}
if (block_size >= 256) {
if (tid < 128) {
__update(dists, dists_i, tid, tid + 128);
}
__syncthreads();
}
if (block_size >= 128) {
if (tid < 64) {
__update(dists, dists_i, tid, tid + 64);
}
__syncthreads();
}
if (block_size >= 64) {
if (tid < 32) {
__update(dists, dists_i, tid, tid + 32);
}
__syncthreads();
}
if (block_size >= 32) {
if (tid < 16) {
__update(dists, dists_i, tid, tid + 16);
}
__syncthreads();
}
if (block_size >= 16) {
if (tid < 8) {
__update(dists, dists_i, tid, tid + 8);
}
__syncthreads();
}
if (block_size >= 8) {
if (tid < 4) {
__update(dists, dists_i, tid, tid + 4);
}
__syncthreads();
}
if (block_size >= 4) {
if (tid < 2) {
__update(dists, dists_i, tid, tid + 2);
}
__syncthreads();
}
if (block_size >= 2) {
if (tid < 1) {
__update(dists, dists_i, tid, tid + 1);
}
__syncthreads();
}
old = dists_i[0];
if (tid == 0)
idxs[j] = old + xyz_batch_start_idx;
}
}
void stack_farthest_point_sampling_kernel_launcher(int N, int batch_size,
const float *dataset, float *temp, int *xyz_batch_cnt, int *idxs, int *num_sampled_points) {
// """
// Args:
// ctx:
// dataset: (N1 + N2 + ..., 3) where N > npoint
// temp: (N1 + N2 + ...) where N > npoint
// xyz_batch_cnt: [N1, N2, ...]
// npoint: int, number of features in the sampled set
// Returns:
// idxs: (npoint.sum()) tensor containing the set,
// npoint: (M1, M2, ...)
// """
cudaError_t err;
unsigned int n_threads = opt_n_threads(N);
stack_farthest_point_sampling_kernel<1024><<<batch_size, 1024>>>(
batch_size, N, dataset, temp, xyz_batch_cnt, idxs, num_sampled_points
);
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
\ No newline at end of file
...@@ -6,10 +6,18 @@ ...@@ -6,10 +6,18 @@
#include<vector> #include<vector>
int furthest_point_sampling_wrapper(int b, int n, int m, int farthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor); at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);
void furthest_point_sampling_kernel_launcher(int b, int n, int m, void farthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs); const float *dataset, float *temp, int *idxs);
int stack_farthest_point_sampling_wrapper(
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor xyz_batch_cnt_tensor,
at::Tensor idx_tensor, at::Tensor num_sampled_points_tensor);
void stack_farthest_point_sampling_kernel_launcher(int N, int batch_size,
const float *dataset, float *temp, int *xyz_batch_cnt, int *idxs, int *num_sampled_points);
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment