Unverified commit `a991105c` authored by Shaoshuai Shi, committed by GitHub

Release the codes of PV-RCNN++, update OpenPCDet to v0.5.2

parents 1483517a b6fbf07f
@@ -4,9 +4,11 @@
`OpenPCDet` is a clear, simple, self-contained open source project for LiDAR-based 3D object detection.
It is also the official code release of [`[PointRCNN]`](https://arxiv.org/abs/1812.04244), [`[Part-A2-Net]`](https://arxiv.org/abs/1907.03670), [`[PV-RCNN]`](https://arxiv.org/abs/1912.13192), [`[Voxel R-CNN]`](https://arxiv.org/abs/2012.15712) and [`[PV-RCNN++]`](https://arxiv.org/abs/2102.00463).

**Highlights**:
* `OpenPCDet` has been updated to `v0.5.2` (Jan. 2022).
* The code of PV-RCNN++ has been released.
## Overview
- [Changelog](#changelog)

@@ -19,6 +21,12 @@ It is also the official code release of [`[PointRCNN]`](https://arxiv.org/abs/18

## Changelog
[2022-01-05] **NEW:** Update `OpenPCDet` to v0.5.2:
* The code of [PV-RCNN++](https://arxiv.org/abs/2102.00463) has been released in this repo, with higher performance, faster training/inference speed and lower memory consumption than PV-RCNN.
* Add performance of several models trained with full training set of [Waymo Open Dataset](#waymo-open-dataset-baselines).
* Support Lyft dataset, see the pull request [here](https://github.com/open-mmlab/OpenPCDet/pull/720).
[2021-12-09] **NEW:** Update `OpenPCDet` to v0.5.1:
* Add PointPillar related baseline configs/results on [Waymo Open Dataset](#waymo-open-dataset-baselines).
* Support Pandaset dataloader, see the pull request [here](https://github.com/open-mmlab/OpenPCDet/pull/396).

@@ -108,7 +116,7 @@ Contributions are also welcomed.

### KITTI 3D Object Detection Baselines
Selected supported methods are shown in the table below. The results are the 3D detection performance of moderate difficulty on the *val* set of the KITTI dataset.
* All LiDAR-based models are trained with 8 GTX 1080Ti GPUs and are available for download.
* The training time is measured with 8 TITAN XP GPUs and PyTorch 1.5.

| | training time | Car@R11 | Pedestrian@R11 | Cyclist@R11 | download |
@@ -129,7 +137,7 @@ Selected supported methods are shown in the below table. The results are the 3D

We provide the setting of [`DATA_CONFIG.SAMPLED_INTERVAL`](tools/cfgs/dataset_configs/waymo_dataset.yaml) on the Waymo Open Dataset (WOD) to subsample partial samples for training and evaluation,
so you could also play with WOD by setting a smaller `DATA_CONFIG.SAMPLED_INTERVAL` even if you only have limited GPU resources.
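The subsampling knob lives in the dataset config linked above; a minimal sketch of what it looks like (the `train`/`test` sub-keys are assumed from memory, so verify against your checkout of `waymo_dataset.yaml`):

```yaml
# In tools/cfgs/dataset_configs/waymo_dataset.yaml (sub-keys assumed):
SAMPLED_INTERVAL: {
    'train': 5,   # keep every 5th frame -> roughly 20% of the training split
    'test': 5
}
```

Raising the interval (e.g. to 10) halves the data again, which is handy for quick sanity runs on a single GPU.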
By default, all models are trained with **a single frame** of **20% data (~32k frames)** of all the training samples on 8 GTX 1080Ti GPUs, and the results of each cell here are mAP/mAPH calculated by the official Waymo evaluation metrics on the **whole** validation set (version 1.2).

| Performance@(train with 20\% Data) | Vec_L1 | Vec_L2 | Ped_L1 | Ped_L2 | Cyc_L1 | Cyc_L2 |
|---------------------------------------------|----------:|:-------:|:-------:|:-------:|:-------:|:-------:|
@@ -141,6 +149,22 @@ By default, all models are trained with **20% data (~32k frames)** of all the tr
| [Part-A2-Anchor](tools/cfgs/waymo_models/PartA2.yaml) | 74.66/74.12 |65.82/65.32 |71.71/62.24 |62.46/54.06 |66.53/65.18 |64.05/62.75 |
| [PV-RCNN (AnchorHead)](tools/cfgs/waymo_models/pv_rcnn.yaml) | 75.41/74.74 |67.44/66.80 |71.98/61.24 |63.70/53.95 |65.88/64.25 |63.39/61.82 |
| [PV-RCNN (CenterHead)](tools/cfgs/waymo_models/pv_rcnn_with_centerhead_rpn.yaml) | 75.95/75.43 |68.02/67.54 |75.94/69.40 |67.66/61.62 |70.18/68.98 |67.73/66.57|
| [PV-RCNN++](tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml) | 77.82/77.32| 69.07/68.62| 77.99/71.36| 69.92/63.74| 71.80/70.71| 69.31/68.26|
| [PV-RCNN++ (ResNet)](tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet.yaml) |77.61/77.14| 69.18/68.75| 79.42/73.31| 70.88/65.21| 72.50/71.39| 69.84/68.77|
Here we also provide the performance of several models trained on the full training set (refer to the paper of [PV-RCNN++](https://arxiv.org/abs/2102.00463)):
| Performance@(train with 100\% Data) | Vec_L1 | Vec_L2 | Ped_L1 | Ped_L2 | Cyc_L1 | Cyc_L2 |
|---------------------------------------------|----------:|:-------:|:-------:|:-------:|:-------:|:-------:|
| [SECOND](tools/cfgs/waymo_models/second.yaml) | 72.27/71.69 | 63.85/63.33 | 68.70/58.18 | 60.72/51.31 | 60.62/59.28 | 58.34/57.05|
| [Part-A2-Anchor](tools/cfgs/waymo_models/PartA2.yaml) | 77.05/76.51 | 68.47/67.97 | 75.24/66.87 | 66.18/58.62 | 68.60/67.36 | 66.13/64.93 |
| [PV-RCNN (CenterHead)](tools/cfgs/waymo_models/pv_rcnn_with_centerhead_rpn.yaml) | 78.00/77.50 | 69.43/68.98 | 79.21/73.03 | 70.42/64.72 | 71.46/70.27 | 68.95/67.79|
| [PV-RCNN++](tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml) | 79.10/78.63 | 70.34/69.91 | 80.62/74.62 | 71.86/66.30 | 73.49/72.38 | 70.70/69.62 |
| [PV-RCNN++ (ResNet)](tools/cfgs/waymo_models/pv_rcnn_plusplus_resnet.yaml) |79.25/78.78 | 70.61/70.18 | 81.83/76.28 | 73.17/68.00 | 73.72/72.66 | 71.21/70.19|
We could not provide the above pretrained models due to [Waymo Dataset License Agreement](https://waymo.com/open/terms/),
but you could easily achieve similar performance by training with the default configs.
...
@@ -86,7 +86,7 @@ OpenPCDet

```shell script
pip3 install --upgrade pip
# tf 2.5.0
pip3 install waymo-open-dataset-tf-2-5-0 --user
```
* Extract point cloud data from tfrecord and generate data infos by running the following command (it takes several hours,
...
@@ -30,7 +30,7 @@ class PointFeatureEncoder(object):
            data_dict['points']
        )
        data_dict['use_lead_xyz'] = use_lead_xyz

        if self.point_encoding_config.get('filter_sweeps', False) and 'timestamp' in self.src_feature_list:
            max_sweeps = self.point_encoding_config.max_sweeps
            idx = self.src_feature_list.index('timestamp')
...
import math

import numpy as np
import torch
import torch.nn as nn

@@ -40,6 +42,85 @@ def bilinear_interpolate_torch(im, x, y):
    return ans
def sample_points_with_roi(rois, points, sample_radius_with_roi, num_max_points_of_part=200000):
    """
    Args:
        rois: (M, 7 + C)
        points: (N, 3)
        sample_radius_with_roi:
        num_max_points_of_part:

    Returns:
        sampled_points: (N_out, 3)
        point_mask: (N)
    """
    if points.shape[0] < num_max_points_of_part:
        distance = (points[:, None, :] - rois[None, :, 0:3]).norm(dim=-1)
        min_dis, min_dis_roi_idx = distance.min(dim=-1)
        roi_max_dim = (rois[min_dis_roi_idx, 3:6] / 2).norm(dim=-1)
        point_mask = min_dis < roi_max_dim + sample_radius_with_roi
    else:
        start_idx = 0
        point_mask_list = []
        while start_idx < points.shape[0]:
            distance = (points[start_idx:start_idx + num_max_points_of_part, None, :] - rois[None, :, 0:3]).norm(dim=-1)
            min_dis, min_dis_roi_idx = distance.min(dim=-1)
            roi_max_dim = (rois[min_dis_roi_idx, 3:6] / 2).norm(dim=-1)
            cur_point_mask = min_dis < roi_max_dim + sample_radius_with_roi
            point_mask_list.append(cur_point_mask)
            start_idx += num_max_points_of_part
        point_mask = torch.cat(point_mask_list, dim=0)

    sampled_points = points[:1] if point_mask.sum() == 0 else points[point_mask, :]

    return sampled_points, point_mask
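The masking step above is easy to check on toy data. Below is a numpy sketch of the same logic; `sample_points_with_roi_np` is a name invented here, and the box layout `[x, y, z, dx, dy, dz, heading]` is inferred from the slicing `rois[:, 0:3]` / `rois[:, 3:6]`:

```python
import numpy as np

def sample_points_with_roi_np(rois, points, sample_radius_with_roi):
    """Keep points whose distance to the nearest ROI center is below that
    ROI's half-diagonal plus a margin (numpy sketch of the mask above)."""
    # (N, M) distances from every point to every ROI center
    dist = np.linalg.norm(points[:, None, :] - rois[None, :, 0:3], axis=-1)
    min_dis = dist.min(axis=1)
    min_dis_roi_idx = dist.argmin(axis=1)
    # half-diagonal of the matched box, whose sizes sit in columns 3:6
    roi_max_dim = np.linalg.norm(rois[min_dis_roi_idx, 3:6] / 2, axis=-1)
    point_mask = min_dis < roi_max_dim + sample_radius_with_roi
    return points[point_mask], point_mask

rois = np.array([[0., 0., 0., 4., 2., 2., 0.]])    # one box centered at the origin
points = np.array([[0.5, 0., 0.],                  # close to the box center
                   [50., 0., 0.]])                 # far away
sampled, mask = sample_points_with_roi_np(rois, points, sample_radius_with_roi=1.6)
```

Only the near point survives; the chunked `while` branch in the real function exists purely to bound the size of the (N, M) distance matrix on large scenes.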
def sector_fps(points, num_sampled_points, num_sectors):
    """
    Args:
        points: (N, 3)
        num_sampled_points: int
        num_sectors: int

    Returns:
        sampled_points: (N_out, 3)
    """
    sector_size = np.pi * 2 / num_sectors
    point_angles = torch.atan2(points[:, 1], points[:, 0]) + np.pi
    # clamp to num_sectors - 1 so boundary angles (exactly 2 * pi) are not dropped
    sector_idx = (point_angles / sector_size).floor().clamp(min=0, max=num_sectors - 1)
    xyz_points_list = []
    xyz_batch_cnt = []
    num_sampled_points_list = []
    for k in range(num_sectors):
        mask = (sector_idx == k)
        cur_num_points = mask.sum().item()
        if cur_num_points > 0:
            xyz_points_list.append(points[mask])
            xyz_batch_cnt.append(cur_num_points)
            ratio = cur_num_points / points.shape[0]
            num_sampled_points_list.append(
                min(cur_num_points, math.ceil(ratio * num_sampled_points))
            )

    if len(xyz_batch_cnt) == 0:
        xyz_points_list.append(points)
        xyz_batch_cnt.append(len(points))
        num_sampled_points_list.append(num_sampled_points)
        print(f'Warning: empty sector points detected in SectorFPS: points.shape={points.shape}')

    xyz = torch.cat(xyz_points_list, dim=0)
    xyz_batch_cnt = torch.tensor(xyz_batch_cnt, device=points.device).int()
    sampled_points_batch_cnt = torch.tensor(num_sampled_points_list, device=points.device).int()
    sampled_pt_idxs = pointnet2_stack_utils.stack_farthest_point_sample(
        xyz.contiguous(), xyz_batch_cnt, sampled_points_batch_cnt
    ).long()
    sampled_points = xyz[sampled_pt_idxs]

    return sampled_points
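The per-sector bookkeeping above (angles shifted into [0, 2π], proportional sampling quotas) can be sketched in numpy. `sector_quotas` is a hypothetical helper that mirrors only the counting logic, not the CUDA farthest-point sampling:

```python
import math
import numpy as np

def sector_quotas(points, num_sampled_points, num_sectors):
    """Assign each point to an angular sector around the z-axis and compute
    how many keypoints each non-empty sector should contribute."""
    sector_size = 2 * np.pi / num_sectors
    angles = np.arctan2(points[:, 1], points[:, 0]) + np.pi      # in [0, 2*pi]
    sector_idx = np.clip(np.floor(angles / sector_size), 0, num_sectors - 1).astype(int)
    quotas = {}
    for k in range(num_sectors):
        cur_num_points = int((sector_idx == k).sum())
        if cur_num_points > 0:
            ratio = cur_num_points / points.shape[0]
            quotas[k] = min(cur_num_points, math.ceil(ratio * num_sampled_points))
    return quotas

# two points ahead of the sensor, two to its right: two non-empty sectors
points = np.array([[1., 0., 0.], [1., 0., 0.], [0., -1., 0.], [0., -1., 0.]])
quotas = sector_quotas(points, num_sampled_points=2, num_sectors=4)
```

Because the quotas use `ceil`, their total can slightly exceed `num_sampled_points`; `sector_fps` accepts that and simply runs stacked FPS once per sector.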
class VoxelSetAbstraction(nn.Module):
    def __init__(self, model_cfg, voxel_size, point_cloud_range, num_bev_features=None,
                 num_rawpoint_features=None, **kwargs):
@@ -58,38 +139,31 @@ class VoxelSetAbstraction(nn.Module):
            if src_name in ['bev', 'raw_points']:
                continue
            self.downsample_times_map[src_name] = SA_cfg[src_name].DOWNSAMPLE_FACTOR

            if SA_cfg[src_name].get('INPUT_CHANNELS', None) is None:
                input_channels = SA_cfg[src_name].MLPS[0][0] \
                    if isinstance(SA_cfg[src_name].MLPS[0], list) else SA_cfg[src_name].MLPS[0]
            else:
                input_channels = SA_cfg[src_name]['INPUT_CHANNELS']

            cur_layer, cur_num_c_out = pointnet2_stack_modules.build_local_aggregation_module(
                input_channels=input_channels, config=SA_cfg[src_name]
            )
            self.SA_layers.append(cur_layer)
            self.SA_layer_names.append(src_name)

            c_in += cur_num_c_out
        if 'bev' in self.model_cfg.FEATURES_SOURCE:
            c_bev = num_bev_features
            c_in += c_bev

        if 'raw_points' in self.model_cfg.FEATURES_SOURCE:
            self.SA_rawpoints, cur_num_c_out = pointnet2_stack_modules.build_local_aggregation_module(
                input_channels=num_rawpoint_features - 3, config=SA_cfg['raw_points']
            )
            c_in += cur_num_c_out

        self.vsa_point_feature_fusion = nn.Sequential(
            nn.Linear(c_in, self.model_cfg.NUM_OUTPUT_FEATURES, bias=False),
@@ -100,23 +174,64 @@ class VoxelSetAbstraction(nn.Module):
        self.num_point_features_before_fusion = c_in
    def interpolate_from_bev_features(self, keypoints, bev_features, batch_size, bev_stride):
        """
        Args:
            keypoints: (N1 + N2 + ..., 4)
            bev_features: (B, C, H, W)
            batch_size:
            bev_stride:

        Returns:
            point_bev_features: (N1 + N2 + ..., C)
        """
        x_idxs = (keypoints[:, 1] - self.point_cloud_range[0]) / self.voxel_size[0]
        y_idxs = (keypoints[:, 2] - self.point_cloud_range[1]) / self.voxel_size[1]

        x_idxs = x_idxs / bev_stride
        y_idxs = y_idxs / bev_stride

        point_bev_features_list = []
        for k in range(batch_size):
            bs_mask = (keypoints[:, 0] == k)

            cur_x_idxs = x_idxs[bs_mask]
            cur_y_idxs = y_idxs[bs_mask]
            cur_bev_features = bev_features[k].permute(1, 2, 0)  # (H, W, C)
            point_bev_features = bilinear_interpolate_torch(cur_bev_features, cur_x_idxs, cur_y_idxs)
            point_bev_features_list.append(point_bev_features)

        point_bev_features = torch.cat(point_bev_features_list, dim=0)  # (N1 + N2 + ..., C)
        return point_bev_features
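The per-keypoint lookup relies on `bilinear_interpolate_torch`, defined earlier in this file. A numpy sketch of standard bilinear sampling on an (H, W, C) map follows; the function name and the border clamping here are illustrative, not the exact upstream code:

```python
import numpy as np

def bilinear_interpolate_np(im, x, y):
    """Bilinearly sample im (H, W, C) at fractional locations: x indexes W, y indexes H."""
    x0 = np.floor(x).astype(int); x1 = x0 + 1
    y0 = np.floor(y).astype(int); y1 = y0 + 1
    # clamp lookups to the map borders
    x0c = np.clip(x0, 0, im.shape[1] - 1); x1c = np.clip(x1, 0, im.shape[1] - 1)
    y0c = np.clip(y0, 0, im.shape[0] - 1); y1c = np.clip(y1, 0, im.shape[0] - 1)
    Ia, Ib = im[y0c, x0c], im[y1c, x0c]     # corner features, each (N, C)
    Ic, Id = im[y0c, x1c], im[y1c, x1c]
    wa = (x1 - x) * (y1 - y); wb = (x1 - x) * (y - y0)
    wc = (x - x0) * (y1 - y); wd = (x - x0) * (y - y0)
    return Ia * wa[:, None] + Ib * wb[:, None] + Ic * wc[:, None] + Id * wd[:, None]

bev = np.arange(4, dtype=float).reshape(2, 2, 1)   # 2x2 map, one channel
out = bilinear_interpolate_np(bev, np.array([0.5]), np.array([0.5]))
```

Sampling at the exact center of the 2x2 map returns the mean of the four corner values.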
    def sectorized_proposal_centric_sampling(self, roi_boxes, points):
        """
        Args:
            roi_boxes: (M, 7 + C)
            points: (N, 3)

        Returns:
            sampled_points: (N_out, 3)
        """
        sampled_points, _ = sample_points_with_roi(
            rois=roi_boxes, points=points,
            sample_radius_with_roi=self.model_cfg.SPC_SAMPLING.SAMPLE_RADIUS_WITH_ROI,
            num_max_points_of_part=self.model_cfg.SPC_SAMPLING.get('NUM_POINTS_OF_EACH_SAMPLE_PART', 200000)
        )
        sampled_points = sector_fps(
            points=sampled_points, num_sampled_points=self.model_cfg.NUM_KEYPOINTS,
            num_sectors=self.model_cfg.SPC_SAMPLING.NUM_SECTORS
        )
        return sampled_points
    def get_sampled_points(self, batch_dict):
        """
        Args:
            batch_dict:

        Returns:
            keypoints: (N1 + N2 + ..., 4), where 4 indicates [bs_idx, x, y, z]
        """
        batch_size = batch_dict['batch_size']
        if self.model_cfg.POINT_SOURCE == 'raw_points':
            src_points = batch_dict['points'][:, 1:4]
@@ -136,7 +251,7 @@ class VoxelSetAbstraction(nn.Module):
            bs_mask = (batch_indices == bs_idx)
            sampled_points = src_points[bs_mask].unsqueeze(dim=0)  # (1, N, 3)
            if self.model_cfg.SAMPLE_METHOD == 'FPS':
                cur_pt_idxs = pointnet2_stack_utils.farthest_point_sample(
                    sampled_points[:, :, 0:3].contiguous(), self.model_cfg.NUM_KEYPOINTS
                ).long()
@@ -147,16 +262,75 @@ class VoxelSetAbstraction(nn.Module):
                keypoints = sampled_points[0][cur_pt_idxs[0]].unsqueeze(dim=0)

            elif self.model_cfg.SAMPLE_METHOD == 'SPC':
                cur_keypoints = self.sectorized_proposal_centric_sampling(
                    roi_boxes=batch_dict['rois'][bs_idx], points=sampled_points[0]
                )
                bs_idxs = cur_keypoints.new_ones(cur_keypoints.shape[0]) * bs_idx
                keypoints = torch.cat((bs_idxs[:, None], cur_keypoints), dim=1)
            else:
                raise NotImplementedError

            keypoints_list.append(keypoints)

        keypoints = torch.cat(keypoints_list, dim=0)  # (B, M, 3) or (N1 + N2 + ..., 4)
        if len(keypoints.shape) == 3:
            batch_idx = torch.arange(batch_size, device=keypoints.device).view(-1, 1).repeat(1, keypoints.shape[1]).view(-1, 1)
            keypoints = torch.cat((batch_idx.float(), keypoints.view(-1, 3)), dim=1)

        return keypoints
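The `(B, M, 3)` to `(N, 4)` flattening at the end, which prepends a batch index so FPS output matches the flat format SPC sampling already produces, looks like this in numpy:

```python
import numpy as np

# B=2 scenes, M=2 keypoints each, xyz coordinates
keypoints = np.arange(12, dtype=float).reshape(2, 2, 3)
# one batch index per keypoint: [0, 0, 1, 1]
batch_idx = np.repeat(np.arange(2, dtype=float), 2).reshape(-1, 1)
# flat (B*M, 4) rows of [bs_idx, x, y, z]
flat = np.concatenate([batch_idx, keypoints.reshape(-1, 3)], axis=1)
```

Downstream modules can then recover each scene's keypoints by masking on column 0, regardless of which sampler produced them.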
    @staticmethod
    def aggregate_keypoint_features_from_one_source(
            batch_size, aggregate_func, xyz, xyz_features, xyz_bs_idxs, new_xyz, new_xyz_batch_cnt,
            filter_neighbors_with_roi=False, radius_of_neighbor=None, num_max_points_of_part=200000, rois=None
    ):
        """
        Args:
            aggregate_func:
            xyz: (N, 3)
            xyz_features: (N, C)
            xyz_bs_idxs: (N)
            new_xyz: (M, 3)
            new_xyz_batch_cnt: (batch_size), [N1, N2, ...]
            filter_neighbors_with_roi: True/False
            radius_of_neighbor: float
            num_max_points_of_part: int
            rois: (batch_size, num_rois, 7 + C)

        Returns:
            pooled_features: (M, C_out)
        """
        xyz_batch_cnt = xyz.new_zeros(batch_size).int()
        if filter_neighbors_with_roi:
            point_features = torch.cat((xyz, xyz_features), dim=-1) if xyz_features is not None else xyz
            point_features_list = []
            for bs_idx in range(batch_size):
                bs_mask = (xyz_bs_idxs == bs_idx)
                _, valid_mask = sample_points_with_roi(
                    rois=rois[bs_idx], points=xyz[bs_mask],
                    sample_radius_with_roi=radius_of_neighbor, num_max_points_of_part=num_max_points_of_part,
                )
                point_features_list.append(point_features[bs_mask][valid_mask])
                xyz_batch_cnt[bs_idx] = valid_mask.sum()

            valid_point_features = torch.cat(point_features_list, dim=0)
            xyz = valid_point_features[:, 0:3]
            xyz_features = valid_point_features[:, 3:] if xyz_features is not None else None
        else:
            for bs_idx in range(batch_size):
                xyz_batch_cnt[bs_idx] = (xyz_bs_idxs == bs_idx).sum()

        pooled_points, pooled_features = aggregate_func(
            xyz=xyz.contiguous(),
            xyz_batch_cnt=xyz_batch_cnt,
            new_xyz=new_xyz,
            new_xyz_batch_cnt=new_xyz_batch_cnt,
            features=xyz_features.contiguous() if xyz_features is not None else None,  # raw points may carry no extra features
        )
        return pooled_features
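The stacked-batch convention used throughout this helper (one flat tensor plus per-scene counts instead of a padded batch dimension) reduces to simple bookkeeping, shown here as a small numpy sketch:

```python
import numpy as np

# batch indices of 6 stacked points coming from 3 scenes
xyz_bs_idxs = np.array([0, 0, 1, 1, 1, 2])
batch_size = 3
# per-scene counts, the role xyz_batch_cnt / new_xyz_batch_cnt play above
xyz_batch_cnt = np.array([(xyz_bs_idxs == b).sum() for b in range(batch_size)])
```

Keeping counts instead of padding lets scenes with very different point counts share one contiguous buffer, which is what the stacked PointNet++ CUDA ops consume.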
    def forward(self, batch_dict):
        """
        Args:
@@ -185,56 +359,53 @@ class VoxelSetAbstraction(nn.Module):
            )
            point_features_list.append(point_bev_features)
        batch_size = batch_dict['batch_size']

        new_xyz = keypoints[:, 1:4].contiguous()
        new_xyz_batch_cnt = new_xyz.new_zeros(batch_size).int()
        for k in range(batch_size):
            new_xyz_batch_cnt[k] = (keypoints[:, 0] == k).sum()

        if 'raw_points' in self.model_cfg.FEATURES_SOURCE:
            raw_points = batch_dict['points']

            pooled_features = self.aggregate_keypoint_features_from_one_source(
                batch_size=batch_size, aggregate_func=self.SA_rawpoints,
                xyz=raw_points[:, 1:4],
                xyz_features=raw_points[:, 4:].contiguous() if raw_points.shape[1] > 4 else None,
                xyz_bs_idxs=raw_points[:, 0],
                new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt,
                filter_neighbors_with_roi=self.model_cfg.SA_LAYER['raw_points'].get('FILTER_NEIGHBOR_WITH_ROI', False),
                radius_of_neighbor=self.model_cfg.SA_LAYER['raw_points'].get('RADIUS_OF_NEIGHBOR_WITH_ROI', None),
                rois=batch_dict.get('rois', None)
            )
            point_features_list.append(pooled_features)
        for k, src_name in enumerate(self.SA_layer_names):
            cur_coords = batch_dict['multi_scale_3d_features'][src_name].indices
            cur_features = batch_dict['multi_scale_3d_features'][src_name].features.contiguous()

            xyz = common_utils.get_voxel_centers(
                cur_coords[:, 1:4], downsample_times=self.downsample_times_map[src_name],
                voxel_size=self.voxel_size, point_cloud_range=self.point_cloud_range
            )

            pooled_features = self.aggregate_keypoint_features_from_one_source(
                batch_size=batch_size, aggregate_func=self.SA_layers[k],
                xyz=xyz.contiguous(), xyz_features=cur_features, xyz_bs_idxs=cur_coords[:, 0],
                new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt,
                filter_neighbors_with_roi=self.model_cfg.SA_LAYER[src_name].get('FILTER_NEIGHBOR_WITH_ROI', False),
                radius_of_neighbor=self.model_cfg.SA_LAYER[src_name].get('RADIUS_OF_NEIGHBOR_WITH_ROI', None),
                rois=batch_dict.get('rois', None)
            )

            point_features_list.append(pooled_features)

        point_features = torch.cat(point_features_list, dim=-1)

        batch_dict['point_features_before_fusion'] = point_features.view(-1, point_features.shape[-1])
        point_features = self.vsa_point_feature_fusion(point_features.view(-1, point_features.shape[-1]))

        batch_dict['point_features'] = point_features  # (BxN, C)
        batch_dict['point_coords'] = keypoints  # (BxN, 4)
        return batch_dict
@@ -174,7 +174,7 @@ class PointNet2Backbone(nn.Module):
                else:
                    last_num_points = self.num_points_each_layer[i - 1]
                    cur_xyz = l_xyz[-1][k * last_num_points: (k + 1) * last_num_points]
                    cur_pt_idxs = pointnet2_utils_stack.farthest_point_sample(
                        cur_xyz[None, :, :].contiguous(), self.num_points_each_layer[i]
                    ).long()[0]
                    if cur_xyz.shape[0] < self.num_points_each_layer[i]:
...
@@ -8,6 +8,7 @@ from .second_net_iou import SECONDNetIoU
from .caddn import CaDDN
from .voxel_rcnn import VoxelRCNN
from .centerpoint import CenterPoint
from .pv_rcnn_plusplus import PVRCNNPlusPlus

__all__ = {
    'Detector3DTemplate': Detector3DTemplate,
@@ -19,7 +20,8 @@ __all__ = {
    'SECONDNetIoU': SECONDNetIoU,
    'CaDDN': CaDDN,
    'VoxelRCNN': VoxelRCNN,
    'CenterPoint': CenterPoint,
    'PVRCNNPlusPlus': PVRCNNPlusPlus
}
...
from .detector3d_template import Detector3DTemplate


class PVRCNNPlusPlus(Detector3DTemplate):
    def __init__(self, model_cfg, num_class, dataset):
        super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
        self.module_list = self.build_networks()

    def forward(self, batch_dict):
        batch_dict = self.vfe(batch_dict)
        batch_dict = self.backbone_3d(batch_dict)
        batch_dict = self.map_to_bev_module(batch_dict)
        batch_dict = self.backbone_2d(batch_dict)
        batch_dict = self.dense_head(batch_dict)

        batch_dict = self.roi_head.proposal_layer(
            batch_dict, nms_config=self.roi_head.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST']
        )
        if self.training:
            targets_dict = self.roi_head.assign_targets(batch_dict)
            batch_dict['rois'] = targets_dict['rois']
            batch_dict['roi_labels'] = targets_dict['roi_labels']
            batch_dict['roi_targets_dict'] = targets_dict
            num_rois_per_scene = targets_dict['rois'].shape[1]
            if 'roi_valid_num' in batch_dict:
                batch_dict['roi_valid_num'] = [num_rois_per_scene for _ in range(batch_dict['batch_size'])]

        batch_dict = self.pfe(batch_dict)
        batch_dict = self.point_head(batch_dict)
        batch_dict = self.roi_head(batch_dict)

        if self.training:
            loss, tb_dict, disp_dict = self.get_training_loss()
            ret_dict = {
                'loss': loss
            }
            return ret_dict, tb_dict, disp_dict
        else:
            pred_dicts, recall_dicts = self.post_processing(batch_dict)
            return pred_dicts, recall_dicts

    def get_training_loss(self):
        disp_dict = {}
        loss_rpn, tb_dict = self.dense_head.get_loss()
        if self.point_head is not None:
            loss_point, tb_dict = self.point_head.get_loss(tb_dict)
        else:
            loss_point = 0
        loss_rcnn, tb_dict = self.roi_head.get_loss(tb_dict)

        loss = loss_rpn + loss_point + loss_rcnn
        return loss, tb_dict, disp_dict
@@ -10,21 +10,12 @@ class PVRCNNHead(RoIHeadTemplate):
        super().__init__(num_class=num_class, model_cfg=model_cfg)
        self.model_cfg = model_cfg

        self.roi_grid_pool_layer, num_c_out = pointnet2_stack_modules.build_local_aggregation_module(
            input_channels=input_channels, config=self.model_cfg.ROI_GRID_POOL
        )

        GRID_SIZE = self.model_cfg.ROI_GRID_POOL.GRID_SIZE
        pre_channel = GRID_SIZE * GRID_SIZE * GRID_SIZE * num_c_out

        shared_fc_list = []
        for k in range(0, self.model_cfg.SHARED_FC.__len__()):
@@ -150,9 +141,11 @@ class PVRCNNHead(RoIHeadTemplate):
            batch_dict, nms_config=self.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST']
        )
        if self.training:
            targets_dict = batch_dict.get('roi_targets_dict', None)
            if targets_dict is None:
                targets_dict = self.assign_targets(batch_dict)
                batch_dict['rois'] = targets_dict['rois']
                batch_dict['roi_labels'] = targets_dict['roi_labels']

        # RoI aware pooling
        pooled_features = self.roi_grid_pool(batch_dict)  # (BxN, 6x6x6, C)
...
@@ -31,7 +31,7 @@ class _PointnetSAModuleBase(nn.Module):
        if new_xyz is None:
            new_xyz = pointnet2_utils.gather_operation(
                xyz_flipped,
                pointnet2_utils.farthest_point_sample(xyz, self.npoint)
            ).transpose(1, 2).contiguous() if self.npoint is not None else None

        for i in range(len(self.groupers)):
...
@@ -7,11 +7,11 @@ from torch.autograd import Function, Variable
from . import pointnet2_batch_cuda as pointnet2


class FarthestPointSampling(Function):
    @staticmethod
    def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
        """
        Uses iterative farthest point sampling to select a set of npoint features that have the largest
        minimum distance
        :param ctx:
        :param xyz: (B, N, 3) where N > npoint
@@ -25,7 +25,7 @@ class FarthestPointSampling(Function):
        output = torch.cuda.IntTensor(B, npoint)
        temp = torch.cuda.FloatTensor(B, N).fill_(1e10)

        pointnet2.farthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
        return output

    @staticmethod
@@ -33,7 +33,7 @@ class FarthestPointSampling(Function):
        return None, None


farthest_point_sample = furthest_point_sample = FarthestPointSampling.apply


class GatherOperation(Function):
...
...@@ -16,7 +16,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -16,7 +16,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast"); m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast");
m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast"); m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast");
m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper"); m.def("farthest_point_sampling_wrapper", &farthest_point_sampling_wrapper, "farthest_point_sampling_wrapper");
m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast"); m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast");
m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast"); m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast");
......
...@@ -38,13 +38,13 @@ int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, ...@@ -38,13 +38,13 @@ int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
} }
int furthest_point_sampling_wrapper(int b, int n, int m, int farthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) { at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {
const float *points = points_tensor.data<float>(); const float *points = points_tensor.data<float>();
float *temp = temp_tensor.data<float>(); float *temp = temp_tensor.data<float>();
int *idx = idx_tensor.data<int>(); int *idx = idx_tensor.data<int>();
furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx); farthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx);
return 1; return 1;
} }
...@@ -98,7 +98,7 @@ __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, i ...@@ -98,7 +98,7 @@ __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, i
} }
template <unsigned int block_size> template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(int b, int n, int m, __global__ void farthest_point_sampling_kernel(int b, int n, int m,
const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) { const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
// dataset: (B, N, 3) // dataset: (B, N, 3)
// tmp: (B, N) // tmp: (B, N)
...@@ -215,7 +215,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m, ...@@ -215,7 +215,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m,
} }
} }
void furthest_point_sampling_kernel_launcher(int b, int n, int m, void farthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs) { const float *dataset, float *temp, int *idxs) {
// dataset: (B, N, 3) // dataset: (B, N, 3)
// tmp: (B, N) // tmp: (B, N)
...@@ -227,29 +227,29 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m, ...@@ -227,29 +227,29 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
switch (n_threads) { switch (n_threads) {
case 1024: case 1024:
furthest_point_sampling_kernel<1024><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<1024><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 512: case 512:
furthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 256: case 256:
furthest_point_sampling_kernel<256><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<256><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 128: case 128:
furthest_point_sampling_kernel<128><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<128><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 64: case 64:
furthest_point_sampling_kernel<64><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<64><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 32: case 32:
furthest_point_sampling_kernel<32><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<32><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 16: case 16:
furthest_point_sampling_kernel<16><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<16><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 8: case 8:
furthest_point_sampling_kernel<8><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<8><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 4: case 4:
furthest_point_sampling_kernel<4><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<4><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 2: case 2:
furthest_point_sampling_kernel<2><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<2><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 1: case 1:
furthest_point_sampling_kernel<1><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break; farthest_point_sampling_kernel<1><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
default: default:
furthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); farthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs);
} }
err = cudaGetLastError(); err = cudaGetLastError();
......
...@@ -20,10 +20,10 @@ void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, ...@@ -20,10 +20,10 @@ void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
const float *grad_out, const int *idx, float *grad_points); const float *grad_out, const int *idx, float *grad_points);
int furthest_point_sampling_wrapper(int b, int n, int m, int farthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor); at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);
void furthest_point_sampling_kernel_launcher(int b, int n, int m, void farthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs); const float *dataset, float *temp, int *idxs);
#endif #endif
...@@ -7,6 +7,26 @@ import torch.nn.functional as F ...@@ -7,6 +7,26 @@ import torch.nn.functional as F
from . import pointnet2_utils from . import pointnet2_utils
def build_local_aggregation_module(input_channels, config):
local_aggregation_name = config.get('NAME', 'StackSAModuleMSG')
if local_aggregation_name == 'StackSAModuleMSG':
mlps = config.MLPS
for k in range(len(mlps)):
mlps[k] = [input_channels] + mlps[k]
cur_layer = StackSAModuleMSG(
radii=config.POOL_RADIUS, nsamples=config.NSAMPLE, mlps=mlps, use_xyz=True, pool_method='max_pool',
)
num_c_out = sum([x[-1] for x in mlps])
elif local_aggregation_name == 'VectorPoolAggregationModuleMSG':
cur_layer = VectorPoolAggregationModuleMSG(input_channels=input_channels, config=config)
num_c_out = config.MSG_POST_MLPS[-1]
else:
raise NotImplementedError
return cur_layer, num_c_out
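The dispatch above can be exercised with a minimal attribute-style config; `Cfg` below is a hypothetical stand-in for the EasyDict-based cfg objects OpenPCDet actually passes in, and only the channel bookkeeping is replicated:

```python
# Cfg is a hypothetical stand-in for OpenPCDet's EasyDict-style config
class Cfg(dict):
    def __getattr__(self, key):
        return self[key]

sa_cfg = Cfg(NAME='StackSAModuleMSG', MLPS=[[16, 16], [32, 32]],
             POOL_RADIUS=[0.4, 0.8], NSAMPLE=[16, 32])

# replicate the channel bookkeeping of build_local_aggregation_module:
input_channels = 64
mlps = [[input_channels] + m for m in sa_cfg.MLPS]  # prepend the input width to each branch
num_c_out = sum(m[-1] for m in mlps)                # concatenated output width of the MSG module
```

With two branches ending in 16 and 32 channels, `num_c_out` is 48, matching `sum([x[-1] for x in mlps])` in the builder.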
class StackSAModuleMSG(nn.Module): class StackSAModuleMSG(nn.Module):
def __init__(self, *, radii: List[float], nsamples: List[int], mlps: List[List[int]], def __init__(self, *, radii: List[float], nsamples: List[int], mlps: List[List[int]],
...@@ -135,3 +155,316 @@ class StackPointnetFPModule(nn.Module): ...@@ -135,3 +155,316 @@ class StackPointnetFPModule(nn.Module):
new_features = new_features.squeeze(dim=0).squeeze(dim=-1).permute(1, 0) # (N1 + N2 ..., C) new_features = new_features.squeeze(dim=0).squeeze(dim=-1).permute(1, 0) # (N1 + N2 ..., C)
return new_features return new_features
class VectorPoolLocalInterpolateModule(nn.Module):
def __init__(self, mlp, num_voxels, max_neighbour_distance, nsample, neighbor_type, use_xyz=True,
neighbour_distance_multiplier=1.0, xyz_encoding_type='concat'):
"""
Args:
mlp:
num_voxels:
max_neighbour_distance:
            nsample: find all (-1), find a limited number (>0)
            neighbor_type: 1: ball, others: cube
use_xyz:
neighbour_distance_multiplier:
xyz_encoding_type:
"""
super().__init__()
self.num_voxels = num_voxels # [num_grid_x, num_grid_y, num_grid_z]: number of grids in each local area centered at new_xyz
self.num_total_grids = self.num_voxels[0] * self.num_voxels[1] * self.num_voxels[2]
self.max_neighbour_distance = max_neighbour_distance
self.neighbor_distance_multiplier = neighbour_distance_multiplier
self.nsample = nsample
self.neighbor_type = neighbor_type
self.use_xyz = use_xyz
self.xyz_encoding_type = xyz_encoding_type
if mlp is not None:
if self.use_xyz:
mlp[0] += 9 if self.xyz_encoding_type == 'concat' else 0
shared_mlps = []
for k in range(len(mlp) - 1):
shared_mlps.extend([
nn.Conv2d(mlp[k], mlp[k + 1], kernel_size=1, bias=False),
nn.BatchNorm2d(mlp[k + 1]),
nn.ReLU()
])
self.mlp = nn.Sequential(*shared_mlps)
else:
self.mlp = None
self.num_avg_length_of_neighbor_idxs = 1000
def forward(self, support_xyz, support_features, xyz_batch_cnt, new_xyz, new_xyz_grid_centers, new_xyz_batch_cnt):
"""
Args:
support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
support_features: (N1 + N2 ..., C) point-wise features
xyz_batch_cnt: (batch_size), [N1, N2, ...]
new_xyz: (M1 + M2 ..., 3) centers of the ball query
            new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grid centers around each new_xyz
new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
Returns:
            new_features: ((M1 + M2 ...) * num_total_grids, C_out)
"""
with torch.no_grad():
dist, idx, num_avg_length_of_neighbor_idxs = pointnet2_utils.three_nn_for_vector_pool_by_two_step(
support_xyz, xyz_batch_cnt, new_xyz, new_xyz_grid_centers, new_xyz_batch_cnt,
self.max_neighbour_distance, self.nsample, self.neighbor_type,
self.num_avg_length_of_neighbor_idxs, self.num_total_grids, self.neighbor_distance_multiplier
)
self.num_avg_length_of_neighbor_idxs = max(self.num_avg_length_of_neighbor_idxs, num_avg_length_of_neighbor_idxs.item())
dist_recip = 1.0 / (dist + 1e-8)
norm = torch.sum(dist_recip, dim=-1, keepdim=True)
weight = dist_recip / torch.clamp_min(norm, min=1e-8)
empty_mask = (idx.view(-1, 3)[:, 0] == -1)
idx.view(-1, 3)[empty_mask] = 0
interpolated_feats = pointnet2_utils.three_interpolate(support_features, idx.view(-1, 3), weight.view(-1, 3))
interpolated_feats = interpolated_feats.view(idx.shape[0], idx.shape[1], -1) # (M1 + M2 ..., num_total_grids, C)
if self.use_xyz:
            near_known_xyz = support_xyz[idx.view(-1, 3).long()].view(-1, 3, 3)  # ((M1 + M2 ...) * num_total_grids, 3, 3)
local_xyz = (new_xyz_grid_centers.view(-1, 1, 3) - near_known_xyz).view(-1, idx.shape[1], 9)
if self.xyz_encoding_type == 'concat':
interpolated_feats = torch.cat((interpolated_feats, local_xyz), dim=-1) # ( M1 + M2 ..., num_total_grids, 9+C)
else:
raise NotImplementedError
new_features = interpolated_feats.view(-1, interpolated_feats.shape[-1]) # ((M1 + M2 ...) * num_total_grids, C)
new_features[empty_mask, :] = 0
if self.mlp is not None:
new_features = new_features.permute(1, 0)[None, :, :, None] # (1, C, N1 + N2 ..., 1)
new_features = self.mlp(new_features)
new_features = new_features.squeeze(dim=0).squeeze(dim=-1).permute(1, 0) # (N1 + N2 ..., C)
return new_features
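The interpolation above weights each grid center's three nearest support points by inverse distance, which is exactly the `dist_recip / norm` computation. A minimal pure-Python sketch of that weighting (`three_nn_weights` is a hypothetical helper, not part of the repo):

```python
def three_nn_weights(dists, eps=1e-8):
    # inverse-distance weights for three-NN interpolation; dists: three neighbor distances
    recip = [1.0 / (d + eps) for d in dists]
    norm = sum(recip)
    return [r / max(norm, eps) for r in recip]
```

The weights sum to one, and the nearest neighbor receives the largest weight, so an exact hit (distance near zero) dominates the interpolated feature.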
class VectorPoolAggregationModule(nn.Module):
def __init__(
self, input_channels, num_local_voxel=(3, 3, 3), local_aggregation_type='local_interpolation',
num_reduced_channels=30, num_channels_of_local_aggregation=32, post_mlps=(128,),
max_neighbor_distance=None, neighbor_nsample=-1, neighbor_type=0, neighbor_distance_multiplier=2.0):
super().__init__()
self.num_local_voxel = num_local_voxel
self.total_voxels = self.num_local_voxel[0] * self.num_local_voxel[1] * self.num_local_voxel[2]
self.local_aggregation_type = local_aggregation_type
assert self.local_aggregation_type in ['local_interpolation', 'voxel_avg_pool', 'voxel_random_choice']
self.input_channels = input_channels
self.num_reduced_channels = input_channels if num_reduced_channels is None else num_reduced_channels
self.num_channels_of_local_aggregation = num_channels_of_local_aggregation
self.max_neighbour_distance = max_neighbor_distance
self.neighbor_nsample = neighbor_nsample
self.neighbor_type = neighbor_type # 1: ball, others: cube
if self.local_aggregation_type == 'local_interpolation':
self.local_interpolate_module = VectorPoolLocalInterpolateModule(
mlp=None, num_voxels=self.num_local_voxel,
max_neighbour_distance=self.max_neighbour_distance,
nsample=self.neighbor_nsample,
neighbor_type=self.neighbor_type,
neighbour_distance_multiplier=neighbor_distance_multiplier,
)
num_c_in = (self.num_reduced_channels + 9) * self.total_voxels
else:
self.local_interpolate_module = None
num_c_in = (self.num_reduced_channels + 3) * self.total_voxels
num_c_out = self.total_voxels * self.num_channels_of_local_aggregation
self.separate_local_aggregation_layer = nn.Sequential(
nn.Conv1d(num_c_in, num_c_out, kernel_size=1, groups=self.total_voxels, bias=False),
nn.BatchNorm1d(num_c_out),
nn.ReLU()
)
post_mlp_list = []
c_in = num_c_out
for cur_num_c in post_mlps:
post_mlp_list.extend([
nn.Conv1d(c_in, cur_num_c, kernel_size=1, bias=False),
nn.BatchNorm1d(cur_num_c),
nn.ReLU()
])
c_in = cur_num_c
self.post_mlps = nn.Sequential(*post_mlp_list)
self.num_mean_points_per_grid = 20
self.init_weights()
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0)
def extra_repr(self) -> str:
        ret = f'radius={self.max_neighbour_distance}, local_voxels={self.num_local_voxel}, ' \
f'local_aggregation_type={self.local_aggregation_type}, ' \
f'num_c_reduction={self.input_channels}->{self.num_reduced_channels}, ' \
f'num_c_local_aggregation={self.num_channels_of_local_aggregation}'
return ret
def vector_pool_with_voxel_query(self, xyz, xyz_batch_cnt, features, new_xyz, new_xyz_batch_cnt):
use_xyz = 1
pooling_type = 0 if self.local_aggregation_type == 'voxel_avg_pool' else 1
new_features, new_local_xyz, num_mean_points_per_grid, point_cnt_of_grid = pointnet2_utils.vector_pool_with_voxel_query_op(
xyz, xyz_batch_cnt, features, new_xyz, new_xyz_batch_cnt,
self.num_local_voxel[0], self.num_local_voxel[1], self.num_local_voxel[2],
self.max_neighbour_distance, self.num_reduced_channels, use_xyz,
self.num_mean_points_per_grid, self.neighbor_nsample, self.neighbor_type,
pooling_type
)
self.num_mean_points_per_grid = max(self.num_mean_points_per_grid, num_mean_points_per_grid.item())
num_new_pts = new_features.shape[0]
new_local_xyz = new_local_xyz.view(num_new_pts, -1, 3) # (N, num_voxel, 3)
new_features = new_features.view(num_new_pts, -1, self.num_reduced_channels) # (N, num_voxel, C)
new_features = torch.cat((new_local_xyz, new_features), dim=-1).view(num_new_pts, -1)
return new_features, point_cnt_of_grid
@staticmethod
def get_dense_voxels_by_center(point_centers, max_neighbour_distance, num_voxels):
"""
Args:
point_centers: (N, 3)
max_neighbour_distance: float
num_voxels: [num_x, num_y, num_z]
Returns:
voxel_centers: (N, total_voxels, 3)
"""
R = max_neighbour_distance
device = point_centers.device
x_grids = torch.arange(-R + R / num_voxels[0], R - R / num_voxels[0] + 1e-5, 2 * R / num_voxels[0], device=device)
y_grids = torch.arange(-R + R / num_voxels[1], R - R / num_voxels[1] + 1e-5, 2 * R / num_voxels[1], device=device)
z_grids = torch.arange(-R + R / num_voxels[2], R - R / num_voxels[2] + 1e-5, 2 * R / num_voxels[2], device=device)
x_offset, y_offset, z_offset = torch.meshgrid(x_grids, y_grids, z_grids) # shape: [num_x, num_y, num_z]
xyz_offset = torch.cat((
x_offset.contiguous().view(-1, 1),
y_offset.contiguous().view(-1, 1),
z_offset.contiguous().view(-1, 1)), dim=-1
)
voxel_centers = point_centers[:, None, :] + xyz_offset[None, :, :]
return voxel_centers
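`get_dense_voxels_by_center` tiles a regular (num_x, num_y, num_z) grid of voxel centers over the cube `[-R, R]^3` around each point. The offset construction can be sketched without torch (`dense_voxel_offsets` is a hypothetical helper mirroring the `torch.arange` calls above):

```python
def dense_voxel_offsets(R, num_voxels):
    # centers of a num_x * num_y * num_z grid covering [-R, R]^3,
    # expressed as offsets from the point center
    def axis(n):
        step = 2.0 * R / n
        return [-R + step / 2.0 + i * step for i in range(n)]
    nx, ny, nz = num_voxels
    return [(x, y, z) for x in axis(nx) for y in axis(ny) for z in axis(nz)]
```

Each axis places `n` evenly spaced centers a half-step in from the boundaries, so a (2, 2, 2) grid with R = 1 yields the eight offsets at (±0.5, ±0.5, ±0.5).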
def vector_pool_with_local_interpolate(self, xyz, xyz_batch_cnt, features, new_xyz, new_xyz_batch_cnt):
"""
Args:
xyz: (N, 3)
xyz_batch_cnt: (batch_size)
features: (N, C)
new_xyz: (M, 3)
new_xyz_batch_cnt: (batch_size)
Returns:
new_features: (M, total_voxels * C)
"""
voxel_centers = self.get_dense_voxels_by_center(
point_centers=new_xyz, max_neighbour_distance=self.max_neighbour_distance, num_voxels=self.num_local_voxel
) # (M1 + M2 + ..., total_voxels, 3)
voxel_features = self.local_interpolate_module.forward(
support_xyz=xyz, support_features=features, xyz_batch_cnt=xyz_batch_cnt,
new_xyz=new_xyz, new_xyz_grid_centers=voxel_centers, new_xyz_batch_cnt=new_xyz_batch_cnt
) # ((M1 + M2 ...) * total_voxels, C)
voxel_features = voxel_features.contiguous().view(-1, self.total_voxels * voxel_features.shape[-1])
return voxel_features
def forward(self, xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, features, **kwargs):
"""
:param xyz: (N1 + N2 ..., 3) tensor of the xyz coordinates of the features
:param xyz_batch_cnt: (batch_size), [N1, N2, ...]
:param new_xyz: (M1 + M2 ..., 3)
:param new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
        :param features: (N1 + N2 ..., C) tensor of the descriptors of the features
:return:
new_xyz: (M1 + M2 ..., 3) tensor of the new features' xyz
new_features: (M1 + M2 ..., \sum_k(mlps[k][-1])) tensor of the new_features descriptors
"""
N, C = features.shape
assert C % self.num_reduced_channels == 0, \
            f'the input channels ({C}) should be an integer multiple of num_reduced_channels ({self.num_reduced_channels})'
features = features.view(N, -1, self.num_reduced_channels).sum(dim=1)
if self.local_aggregation_type in ['voxel_avg_pool', 'voxel_random_choice']:
vector_features, point_cnt_of_grid = self.vector_pool_with_voxel_query(
xyz=xyz, xyz_batch_cnt=xyz_batch_cnt, features=features,
new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt
)
elif self.local_aggregation_type == 'local_interpolation':
vector_features = self.vector_pool_with_local_interpolate(
xyz=xyz, xyz_batch_cnt=xyz_batch_cnt, features=features,
new_xyz=new_xyz, new_xyz_batch_cnt=new_xyz_batch_cnt
) # (M1 + M2 + ..., total_voxels * C)
else:
raise NotImplementedError
vector_features = vector_features.permute(1, 0)[None, :, :] # (1, num_voxels * C, M1 + M2 ...)
new_features = self.separate_local_aggregation_layer(vector_features)
new_features = self.post_mlps(new_features)
new_features = new_features.squeeze(dim=0).permute(1, 0)
return new_xyz, new_features
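The first step of `forward` folds the C input channels down to `num_reduced_channels` by summing groups of channels (`features.view(N, -1, self.num_reduced_channels).sum(dim=1)`). A plain-Python sketch of that reduction (`reduce_channels` is a hypothetical helper):

```python
def reduce_channels(features, num_reduced):
    # sum-fold each row of length C into num_reduced channels;
    # C must be a multiple of num_reduced, as the assert above enforces
    reduced = []
    for row in features:
        assert len(row) % num_reduced == 0
        groups = [row[i:i + num_reduced] for i in range(0, len(row), num_reduced)]
        reduced.append([sum(col) for col in zip(*groups)])
    return reduced
```

For example, a 4-channel row reduced to 2 channels sums channels (0, 2) and (1, 3), which is what the view-then-sum does without allocating new memory layouts.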
class VectorPoolAggregationModuleMSG(nn.Module):
def __init__(self, input_channels, config):
super().__init__()
self.model_cfg = config
self.num_groups = self.model_cfg.NUM_GROUPS
self.layers = []
c_in = 0
for k in range(self.num_groups):
cur_config = self.model_cfg[f'GROUP_CFG_{k}']
cur_vector_pool_module = VectorPoolAggregationModule(
input_channels=input_channels, num_local_voxel=cur_config.NUM_LOCAL_VOXEL,
post_mlps=cur_config.POST_MLPS,
max_neighbor_distance=cur_config.MAX_NEIGHBOR_DISTANCE,
neighbor_nsample=cur_config.NEIGHBOR_NSAMPLE,
local_aggregation_type=self.model_cfg.LOCAL_AGGREGATION_TYPE,
num_reduced_channels=self.model_cfg.get('NUM_REDUCED_CHANNELS', None),
num_channels_of_local_aggregation=self.model_cfg.NUM_CHANNELS_OF_LOCAL_AGGREGATION,
neighbor_distance_multiplier=2.0
)
self.__setattr__(f'layer_{k}', cur_vector_pool_module)
c_in += cur_config.POST_MLPS[-1]
c_in += 3 # use_xyz
shared_mlps = []
for cur_num_c in self.model_cfg.MSG_POST_MLPS:
shared_mlps.extend([
nn.Conv1d(c_in, cur_num_c, kernel_size=1, bias=False),
nn.BatchNorm1d(cur_num_c),
nn.ReLU()
])
c_in = cur_num_c
self.msg_post_mlps = nn.Sequential(*shared_mlps)
def forward(self, **kwargs):
features_list = []
for k in range(self.num_groups):
cur_xyz, cur_features = self.__getattr__(f'layer_{k}')(**kwargs)
features_list.append(cur_features)
features = torch.cat(features_list, dim=-1)
features = torch.cat((cur_xyz, features), dim=-1)
features = features.permute(1, 0)[None, :, :] # (1, C, N)
new_features = self.msg_post_mlps(features)
new_features = new_features.squeeze(dim=0).permute(1, 0) # (N, C)
return cur_xyz, new_features
...@@ -155,7 +155,7 @@ class QueryAndGroup(nn.Module): ...@@ -155,7 +155,7 @@ class QueryAndGroup(nn.Module):
return new_features, idx return new_features, idx
class FurthestPointSampling(Function): class FarthestPointSampling(Function):
@staticmethod @staticmethod
def forward(ctx, xyz: torch.Tensor, npoint: int): def forward(ctx, xyz: torch.Tensor, npoint: int):
""" """
...@@ -173,7 +173,7 @@ class FurthestPointSampling(Function): ...@@ -173,7 +173,7 @@ class FurthestPointSampling(Function):
output = torch.cuda.IntTensor(B, npoint) output = torch.cuda.IntTensor(B, npoint)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10) temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output) pointnet2.farthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
return output return output
@staticmethod @staticmethod
...@@ -181,7 +181,44 @@ class FurthestPointSampling(Function): ...@@ -181,7 +181,44 @@ class FurthestPointSampling(Function):
return None, None return None, None
furthest_point_sample = FurthestPointSampling.apply farthest_point_sample = furthest_point_sample = FarthestPointSampling.apply
class StackFarthestPointSampling(Function):
@staticmethod
def forward(ctx, xyz, xyz_batch_cnt, npoint):
"""
Args:
ctx:
xyz: (N1 + N2 + ..., 3) where N > npoint
xyz_batch_cnt: [N1, N2, ...]
npoint: int, number of features in the sampled set
        Returns:
            output: (M1 + M2 + ...,) tensor of sampled point indices,
                where (M1, M2, ...) is given by npoint (an int, list, or per-batch tensor)
"""
assert xyz.is_contiguous() and xyz.shape[1] == 3
        batch_size = len(xyz_batch_cnt)
if not isinstance(npoint, torch.Tensor):
if not isinstance(npoint, list):
npoint = [npoint for i in range(batch_size)]
npoint = torch.tensor(npoint, device=xyz.device).int()
N, _ = xyz.size()
temp = torch.cuda.FloatTensor(N).fill_(1e10)
output = torch.cuda.IntTensor(npoint.sum().item())
pointnet2.stack_farthest_point_sampling_wrapper(xyz, temp, xyz_batch_cnt, output, npoint)
return output
@staticmethod
def backward(xyz, a=None):
return None, None
stack_farthest_point_sample = StackFarthestPointSampling.apply
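`StackFarthestPointSampling` runs greedy farthest point sampling independently on each sub-cloud of the stacked batch defined by `xyz_batch_cnt`. A minimal pure-Python sketch of the greedy selection on a single sub-cloud (`farthest_point_sample_py` is a hypothetical helper; like the CUDA kernel it starts from index 0 and is O(npoint * N)):

```python
def farthest_point_sample_py(points, npoint):
    # points: list of (x, y, z); returns indices of npoint greedily farthest samples
    min_dist = [float('inf')] * len(points)
    idxs = [0]  # start from the first point, as the CUDA kernel does
    for _ in range(npoint - 1):
        last = points[idxs[-1]]
        best_i = 0
        for i, p in enumerate(points):
            d = sum((a - b) ** 2 for a, b in zip(p, last))
            min_dist[i] = min(min_dist[i], d)  # distance to the nearest selected point
            if min_dist[i] > min_dist[best_i]:
                best_i = i
        idxs.append(best_i)  # pick the point farthest from all selected so far
    return idxs
```

The stacked wrapper applies this same selection per segment of the flattened (N1 + N2 + ..., 3) cloud, with a per-batch `npoint`.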
class ThreeNN(Function): class ThreeNN(Function):
...@@ -262,5 +299,154 @@ class ThreeInterpolate(Function): ...@@ -262,5 +299,154 @@ class ThreeInterpolate(Function):
three_interpolate = ThreeInterpolate.apply three_interpolate = ThreeInterpolate.apply
class ThreeNNForVectorPoolByTwoStep(Function):
@staticmethod
def forward(ctx, support_xyz, xyz_batch_cnt, new_xyz, new_xyz_grid_centers, new_xyz_batch_cnt,
max_neighbour_distance, nsample, neighbor_type, avg_length_of_neighbor_idxs, num_total_grids,
neighbor_distance_multiplier):
"""
Args:
ctx:
            support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
            xyz_batch_cnt: (batch_size), [N1, N2, ...]
            new_xyz: (M1 + M2 ..., 3) centers of the ball query
            new_xyz_grid_centers: (M1 + M2 ..., num_total_grids, 3) grid centers around each new_xyz
            new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
            nsample: find all (-1), find a limited number (>0)
            neighbor_type: 1: ball, others: cube
            neighbor_distance_multiplier: query_distance = neighbor_distance_multiplier * max_neighbour_distance
        Returns:
            new_xyz_grid_dist: (M1 + M2 ..., num_total_grids, 3) distance to the three nearest neighbors
            new_xyz_grid_idxs: (M1 + M2 ..., num_total_grids, 3) indices of the three nearest neighbors
"""
num_new_xyz = new_xyz.shape[0]
new_xyz_grid_dist2 = new_xyz_grid_centers.new_zeros(new_xyz_grid_centers.shape)
new_xyz_grid_idxs = new_xyz_grid_centers.new_zeros(new_xyz_grid_centers.shape).int().fill_(-1)
while True:
num_max_sum_points = avg_length_of_neighbor_idxs * num_new_xyz
stack_neighbor_idxs = new_xyz_grid_idxs.new_zeros(num_max_sum_points)
start_len = new_xyz_grid_idxs.new_zeros(num_new_xyz, 2).int()
cumsum = new_xyz_grid_idxs.new_zeros(1)
pointnet2.query_stacked_local_neighbor_idxs_wrapper_stack(
support_xyz.contiguous(), xyz_batch_cnt.contiguous(),
new_xyz.contiguous(), new_xyz_batch_cnt.contiguous(),
stack_neighbor_idxs.contiguous(), start_len.contiguous(), cumsum,
avg_length_of_neighbor_idxs, max_neighbour_distance * neighbor_distance_multiplier,
nsample, neighbor_type
)
avg_length_of_neighbor_idxs = cumsum[0].item() // num_new_xyz + int(cumsum[0].item() % num_new_xyz > 0)
if cumsum[0] <= num_max_sum_points:
break
stack_neighbor_idxs = stack_neighbor_idxs[:cumsum[0]]
pointnet2.query_three_nn_by_stacked_local_idxs_wrapper_stack(
support_xyz, new_xyz, new_xyz_grid_centers, new_xyz_grid_idxs, new_xyz_grid_dist2,
stack_neighbor_idxs, start_len, num_new_xyz, num_total_grids
)
return torch.sqrt(new_xyz_grid_dist2), new_xyz_grid_idxs, torch.tensor(avg_length_of_neighbor_idxs)
three_nn_for_vector_pool_by_two_step = ThreeNNForVectorPoolByTwoStep.apply
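The `while True` loop above implements a retry pattern that also appears in `VectorPoolWithVoxelQuery`: allocate a neighbor buffer from an estimated average length, let the kernel report the true total, and retry with the refined estimate if the buffer overflowed. A pure-Python sketch of the pattern, simplified to a total count rather than a per-point average (`run_with_growing_buffer` and `fill` are hypothetical names):

```python
def run_with_growing_buffer(estimate, fill):
    # fill(buf) writes up to len(buf) entries and returns the total it wanted to write;
    # retry with the reported size until everything fits (mirrors the while-loops above)
    while True:
        buf = [None] * estimate
        needed = fill(buf)
        if needed <= len(buf):
            return buf[:needed]
        estimate = needed
```

This keeps the common case to a single kernel launch while guaranteeing correctness when the initial estimate is too small.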
class VectorPoolWithVoxelQuery(Function):
@staticmethod
def forward(ctx, support_xyz: torch.Tensor, xyz_batch_cnt: torch.Tensor, support_features: torch.Tensor,
new_xyz: torch.Tensor, new_xyz_batch_cnt: torch.Tensor, num_grid_x, num_grid_y, num_grid_z,
max_neighbour_distance, num_c_out_each_grid, use_xyz,
num_mean_points_per_grid=100, nsample=-1, neighbor_type=0, pooling_type=0):
"""
Args:
ctx:
support_xyz: (N1 + N2 ..., 3) xyz coordinates of the features
xyz_batch_cnt: (batch_size), [N1, N2, ...]
support_features: (N1 + N2 ..., C)
new_xyz: (M1 + M2 ..., 3) centers of new positions
new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
num_grid_x: number of grids in each local area centered at new_xyz
num_grid_y:
num_grid_z:
max_neighbour_distance:
num_c_out_each_grid:
use_xyz:
            neighbor_type: 1: ball, others: cube
pooling_type: 0: avg_pool, 1: random choice
Returns:
new_features: (M1 + M2 ..., num_c_out)
"""
assert support_xyz.is_contiguous()
assert support_features.is_contiguous()
assert xyz_batch_cnt.is_contiguous()
assert new_xyz.is_contiguous()
assert new_xyz_batch_cnt.is_contiguous()
num_total_grids = num_grid_x * num_grid_y * num_grid_z
num_c_out = num_c_out_each_grid * num_total_grids
N, num_c_in = support_features.shape
M = new_xyz.shape[0]
assert num_c_in % num_c_out_each_grid == 0, \
            f'the input channels ({num_c_in}) should be an integer multiple of num_c_out_each_grid ({num_c_out_each_grid})'
while True:
new_features = support_features.new_zeros((M, num_c_out))
new_local_xyz = support_features.new_zeros((M, 3 * num_total_grids))
point_cnt_of_grid = xyz_batch_cnt.new_zeros((M, num_total_grids))
num_max_sum_points = num_mean_points_per_grid * M
grouped_idxs = xyz_batch_cnt.new_zeros((num_max_sum_points, 3))
num_cum_sum = pointnet2.vector_pool_wrapper(
support_xyz, xyz_batch_cnt, support_features, new_xyz, new_xyz_batch_cnt,
new_features, new_local_xyz, point_cnt_of_grid, grouped_idxs,
num_grid_x, num_grid_y, num_grid_z, max_neighbour_distance, use_xyz,
num_max_sum_points, nsample, neighbor_type, pooling_type
)
num_mean_points_per_grid = num_cum_sum // M + int(num_cum_sum % M > 0)
if num_cum_sum <= num_max_sum_points:
break
grouped_idxs = grouped_idxs[:num_cum_sum]
normalizer = torch.clamp_min(point_cnt_of_grid[:, :, None].float(), min=1e-6)
new_features = (new_features.view(-1, num_total_grids, num_c_out_each_grid) / normalizer).view(-1, num_c_out)
if use_xyz:
new_local_xyz = (new_local_xyz.view(-1, num_total_grids, 3) / normalizer).view(-1, num_total_grids * 3)
num_mean_points_per_grid = torch.Tensor([num_mean_points_per_grid]).int()
nsample = torch.Tensor([nsample]).int()
ctx.vector_pool_for_backward = (point_cnt_of_grid, grouped_idxs, N, num_c_in)
ctx.mark_non_differentiable(new_local_xyz, num_mean_points_per_grid, nsample, point_cnt_of_grid)
return new_features, new_local_xyz, num_mean_points_per_grid, point_cnt_of_grid
@staticmethod
def backward(ctx, grad_new_features: torch.Tensor, grad_local_xyz: torch.Tensor, grad_num_cum_sum, grad_point_cnt_of_grid):
"""
Args:
ctx:
grad_new_features: (M1 + M2 ..., num_c_out), num_c_out = num_c_out_each_grid * num_total_grids
Returns:
grad_support_features: (N1 + N2 ..., C_in)
"""
point_cnt_of_grid, grouped_idxs, N, num_c_in = ctx.vector_pool_for_backward
grad_support_features = grad_new_features.new_zeros((N, num_c_in))
pointnet2.vector_pool_grad_wrapper(
grad_new_features.contiguous(), point_cnt_of_grid, grouped_idxs,
grad_support_features
)
return None, None, grad_support_features, None, None, None, None, None, None, None, None, None, None, None, None
vector_pool_with_voxel_query_op = VectorPoolWithVoxelQuery.apply
if __name__ == '__main__': if __name__ == '__main__':
pass pass
...@@ -6,13 +6,15 @@ ...@@ -6,13 +6,15 @@
#include "sampling_gpu.h" #include "sampling_gpu.h"
#include "interpolate_gpu.h" #include "interpolate_gpu.h"
#include "voxel_query_gpu.h" #include "voxel_query_gpu.h"
#include "vector_pool_gpu.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ball_query_wrapper", &ball_query_wrapper_stack, "ball_query_wrapper_stack"); m.def("ball_query_wrapper", &ball_query_wrapper_stack, "ball_query_wrapper_stack");
m.def("voxel_query_wrapper", &voxel_query_wrapper_stack, "voxel_query_wrapper_stack"); m.def("voxel_query_wrapper", &voxel_query_wrapper_stack, "voxel_query_wrapper_stack");
m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper"); m.def("farthest_point_sampling_wrapper", &farthest_point_sampling_wrapper, "farthest_point_sampling_wrapper");
m.def("stack_farthest_point_sampling_wrapper", &stack_farthest_point_sampling_wrapper, "stack_farthest_point_sampling_wrapper");
m.def("group_points_wrapper", &group_points_wrapper_stack, "group_points_wrapper_stack"); m.def("group_points_wrapper", &group_points_wrapper_stack, "group_points_wrapper_stack");
m.def("group_points_grad_wrapper", &group_points_grad_wrapper_stack, "group_points_grad_wrapper_stack"); m.def("group_points_grad_wrapper", &group_points_grad_wrapper_stack, "group_points_grad_wrapper_stack");
...@@ -20,4 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -20,4 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("three_nn_wrapper", &three_nn_wrapper_stack, "three_nn_wrapper_stack"); m.def("three_nn_wrapper", &three_nn_wrapper_stack, "three_nn_wrapper_stack");
m.def("three_interpolate_wrapper", &three_interpolate_wrapper_stack, "three_interpolate_wrapper_stack"); m.def("three_interpolate_wrapper", &three_interpolate_wrapper_stack, "three_interpolate_wrapper_stack");
m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_stack, "three_interpolate_grad_wrapper_stack"); m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_stack, "three_interpolate_grad_wrapper_stack");
m.def("query_stacked_local_neighbor_idxs_wrapper_stack", &query_stacked_local_neighbor_idxs_wrapper_stack, "query_stacked_local_neighbor_idxs_wrapper_stack");
m.def("query_three_nn_by_stacked_local_idxs_wrapper_stack", &query_three_nn_by_stacked_local_idxs_wrapper_stack, "query_three_nn_by_stacked_local_idxs_wrapper_stack");
m.def("vector_pool_wrapper", &vector_pool_wrapper_stack, "vector_pool_wrapper_stack");
m.def("vector_pool_grad_wrapper", &vector_pool_grad_wrapper_stack, "vector_pool_grad_wrapper_stack");
}
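For context, the `farthest_point_sampling_wrapper` bound above implements greedy farthest point sampling on a dense `(B, N, 3)` point tensor. The selection rule can be sketched in pure Python; this is an illustrative reference only, not the CUDA implementation, and it assumes the kernel's convention of seeding the sample set with index 0:

```python
# Pure-Python reference for greedy farthest point sampling (FPS).
# Mirrors the CUDA kernel's semantics on one batch item: repeatedly pick
# the point whose minimum distance to the already-sampled set is largest.
def farthest_point_sampling(points, m):
    """points: list of (x, y, z) tuples; m: number of points to sample.
    Returns indices of the m sampled points, seeded with index 0."""
    n = len(points)
    temp = [float("inf")] * n  # running min squared distance to sampled set
    idxs = [0]                 # the kernel also seeds with index 0
    old = 0
    for _ in range(1, m):
        best, besti = -1.0, 0
        x1, y1, z1 = points[old]
        for k in range(n):
            x2, y2, z2 = points[k]
            d = (x2 - x1) ** 2 + (y2 - y1) ** 2 + (z2 - z1) ** 2
            temp[k] = min(d, temp[k])
            if temp[k] > best:
                best, besti = temp[k], k
        old = besti
        idxs.append(old)
    return idxs
```

The CUDA version parallelizes the inner `for k in range(n)` loop across the threads of one block per batch item and reduces the per-thread maxima in shared memory.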
@@ -21,7 +21,7 @@ extern THCState *state;
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
int farthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {
CHECK_INPUT(points_tensor);
@@ -32,6 +32,29 @@ int furthest_point_sampling_wrapper(int b, int n, int m,
float *temp = temp_tensor.data<float>();
int *idx = idx_tensor.data<int>();
farthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx);
return 1;
}
int stack_farthest_point_sampling_wrapper(at::Tensor points_tensor,
at::Tensor temp_tensor, at::Tensor xyz_batch_cnt_tensor, at::Tensor idx_tensor,
at::Tensor num_sampled_points_tensor) {
CHECK_INPUT(points_tensor);
CHECK_INPUT(temp_tensor);
CHECK_INPUT(idx_tensor);
CHECK_INPUT(xyz_batch_cnt_tensor);
CHECK_INPUT(num_sampled_points_tensor);
int batch_size = xyz_batch_cnt_tensor.size(0);
int N = points_tensor.size(0);
const float *points = points_tensor.data<float>();
float *temp = temp_tensor.data<float>();
int *xyz_batch_cnt = xyz_batch_cnt_tensor.data<int>();
int *idx = idx_tensor.data<int>();
int *num_sampled_points = num_sampled_points_tensor.data<int>();
stack_farthest_point_sampling_kernel_launcher(N, batch_size, points, temp, xyz_batch_cnt, idx, num_sampled_points);
return 1;
}
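The stacked wrapper above handles the batch-concatenated layout used throughout OpenPCDet's stack ops: points from all batch items live in one `(N1 + N2 + ..., 3)` tensor, `xyz_batch_cnt` gives per-item point counts, and `num_sampled_points` gives per-item sample counts. A minimal Python sketch of those semantics (illustrative only; returned indices are in the stacked frame, matching the `old + xyz_batch_start_idx` offset the kernel writes):

```python
# Illustrative sketch of stacked-batch farthest point sampling semantics.
def stack_farthest_point_sampling(points, xyz_batch_cnt, num_sampled_points):
    """points: concatenated (x, y, z) tuples for all batch items.
    Returns sampled indices into the stacked array, item by item."""
    out, start = [], 0
    for cnt, m in zip(xyz_batch_cnt, num_sampled_points):
        chunk = points[start:start + cnt]   # this item's slice
        temp = [float("inf")] * cnt         # min squared dist to sampled set
        old, picks = 0, [0]                 # seed each item with its first point
        for _ in range(1, m):
            best, besti = -1.0, 0
            x1, y1, z1 = chunk[old]
            for k in range(cnt):
                x2, y2, z2 = chunk[k]
                d = (x2 - x1) ** 2 + (y2 - y1) ** 2 + (z2 - z1) ** 2
                temp[k] = min(d, temp[k])
                if temp[k] > best:
                    best, besti = temp[k], k
            old = besti
            picks.append(old)
        out.extend(p + start for p in picks)  # shift to the stacked frame
        start += cnt
    return out
```

In the CUDA kernel, each batch item gets its own block (`blockIdx.x`), and the per-item offsets are accumulated in the `xyz_batch_start_idx` / `idxs_start_idx` loop.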
\ No newline at end of file
@@ -22,7 +22,7 @@ __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, i
template <unsigned int block_size>
__global__ void farthest_point_sampling_kernel(int b, int n, int m,
const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
// dataset: (B, N, 3)
// tmp: (B, N)
@@ -139,7 +139,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m,
}
}
void farthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs) {
// dataset: (B, N, 3)
// tmp: (B, N)
@@ -151,29 +151,29 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
switch (n_threads) {
case 1024:
farthest_point_sampling_kernel<1024><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 512:
farthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 256:
farthest_point_sampling_kernel<256><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 128:
farthest_point_sampling_kernel<128><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 64:
farthest_point_sampling_kernel<64><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 32:
farthest_point_sampling_kernel<32><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 16:
farthest_point_sampling_kernel<16><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 8:
farthest_point_sampling_kernel<8><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 4:
farthest_point_sampling_kernel<4><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 2:
farthest_point_sampling_kernel<2><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 1:
farthest_point_sampling_kernel<1><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
default:
farthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs);
}
err = cudaGetLastError();
@@ -182,3 +182,168 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
exit(-1);
}
}
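The switch above dispatches to a kernel instantiation whose template `block_size` equals the launch width `n_threads`, so the shared-memory reduction is fully unrolled at compile time. `opt_n_threads` is defined elsewhere (in `cuda_utils.h`); a hypothetical reconstruction of the usual pointnet2-style definition, returning the largest power of two no greater than the work size, clamped to [1, 1024]:

```python
# Hypothetical sketch of opt_n_threads (the real definition lives in
# cuda_utils.h): largest power of two <= n, clamped to [1, total_threads].
def opt_n_threads(n, total_threads=1024):
    pow2 = 1
    while pow2 * 2 <= n:  # grow to the largest power of two not exceeding n
        pow2 *= 2
    return max(min(pow2, total_threads), 1)
```

Powers of two are required because each reduction step halves the number of active threads, which is exactly what the `case` labels enumerate.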
template <unsigned int block_size>
__global__ void stack_farthest_point_sampling_kernel(int batch_size, int N,
const float *dataset, float *temp, int *xyz_batch_cnt, int *idxs, int *num_sampled_points) {
// """
// Args:
// ctx:
// dataset: (N1 + N2 + ..., 3) where N > npoint
// temp: (N1 + N2 + ...) where N > npoint
// xyz_batch_cnt: [N1, N2, ...]
// num_sampled_points: [M1, M2, ...] int, number of features in the sampled set
// Returns:
// idxs: (npoint.sum()) tensor containing the set,
// npoint: (M1, M2, ...)
// """
__shared__ float dists[block_size];
__shared__ int dists_i[block_size];
int bs_idx = blockIdx.x;
int xyz_batch_start_idx = 0, idxs_start_idx = 0;
for (int k = 0; k < bs_idx; k++){
xyz_batch_start_idx += xyz_batch_cnt[k];
idxs_start_idx += num_sampled_points[k];
}
dataset += xyz_batch_start_idx * 3;
temp += xyz_batch_start_idx;
idxs += idxs_start_idx;
int n = xyz_batch_cnt[bs_idx];
int m = num_sampled_points[bs_idx];
int tid = threadIdx.x;
const int stride = block_size;
int old = 0;
if (threadIdx.x == 0) idxs[0] = xyz_batch_start_idx;
__syncthreads();
for (int j = 1; j < m; j++) {
int besti = 0;
float best = -1;
float x1 = dataset[old * 3 + 0];
float y1 = dataset[old * 3 + 1];
float z1 = dataset[old * 3 + 2];
for (int k = tid; k < n; k += stride) {
float x2, y2, z2;
x2 = dataset[k * 3 + 0];
y2 = dataset[k * 3 + 1];
z2 = dataset[k * 3 + 2];
// float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
// if (mag <= 1e-3)
// continue;
float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
float d2 = min(d, temp[k]);
temp[k] = d2;
besti = d2 > best ? k : besti;
best = d2 > best ? d2 : best;
}
dists[tid] = best;
dists_i[tid] = besti;
__syncthreads();
if (block_size >= 1024) {
if (tid < 512) {
__update(dists, dists_i, tid, tid + 512);
}
__syncthreads();
}
if (block_size >= 512) {
if (tid < 256) {
__update(dists, dists_i, tid, tid + 256);
}
__syncthreads();
}
if (block_size >= 256) {
if (tid < 128) {
__update(dists, dists_i, tid, tid + 128);
}
__syncthreads();
}
if (block_size >= 128) {
if (tid < 64) {
__update(dists, dists_i, tid, tid + 64);
}
__syncthreads();
}
if (block_size >= 64) {
if (tid < 32) {
__update(dists, dists_i, tid, tid + 32);
}
__syncthreads();
}
if (block_size >= 32) {
if (tid < 16) {
__update(dists, dists_i, tid, tid + 16);
}
__syncthreads();
}
if (block_size >= 16) {
if (tid < 8) {
__update(dists, dists_i, tid, tid + 8);
}
__syncthreads();
}
if (block_size >= 8) {
if (tid < 4) {
__update(dists, dists_i, tid, tid + 4);
}
__syncthreads();
}
if (block_size >= 4) {
if (tid < 2) {
__update(dists, dists_i, tid, tid + 2);
}
__syncthreads();
}
if (block_size >= 2) {
if (tid < 1) {
__update(dists, dists_i, tid, tid + 1);
}
__syncthreads();
}
old = dists_i[0];
if (tid == 0)
idxs[j] = old + xyz_batch_start_idx;
}
}
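The cascade of `__update` calls in the kernel above is a standard shared-memory tree reduction: each step lets the first half of the active threads merge the second half's `(distance, index)` pairs, halving the active count until slot 0 holds the overall farthest point. A small Python model of that reduction (illustrative only; it assumes a power-of-two input length, like `block_size`):

```python
# Python model of the shared-memory argmax reduction done by __update:
# block_size partial (distance, index) pairs are halved repeatedly until
# slot 0 holds the maximum distance and its point index.
def block_argmax(dists, dists_i):
    dists, dists_i = list(dists), list(dists_i)  # model of shared memory
    stride = len(dists) // 2                     # len must be a power of two
    while stride >= 1:
        for tid in range(stride):                # threads tid < stride run __update
            if dists[tid + stride] > dists[tid]:
                dists[tid] = dists[tid + stride]
                dists_i[tid] = dists_i[tid + stride]
        stride //= 2                             # __syncthreads() between levels
    return dists[0], dists_i[0]
```

On the GPU the levels down to 32 threads still need `__syncthreads()` between them (as the kernel has), since even warp-sized steps read values written by other threads in the previous step.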
void stack_farthest_point_sampling_kernel_launcher(int N, int batch_size,
const float *dataset, float *temp, int *xyz_batch_cnt, int *idxs, int *num_sampled_points) {
// """
// Args:
// ctx:
// dataset: (N1 + N2 + ..., 3) where N > npoint
// temp: (N1 + N2 + ...) where N > npoint
// xyz_batch_cnt: [N1, N2, ...]
// npoint: int, number of features in the sampled set
// Returns:
// idxs: (npoint.sum()) tensor containing the set,
// npoint: (M1, M2, ...)
// """
cudaError_t err;
unsigned int n_threads = opt_n_threads(N);
stack_farthest_point_sampling_kernel<1024><<<batch_size, 1024>>>(
batch_size, N, dataset, temp, xyz_batch_cnt, idxs, num_sampled_points
);
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
\ No newline at end of file
@@ -6,10 +6,18 @@
#include<vector>
int farthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);
void farthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs);
int stack_farthest_point_sampling_wrapper(
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor xyz_batch_cnt_tensor,
at::Tensor idx_tensor, at::Tensor num_sampled_points_tensor);
void stack_farthest_point_sampling_kernel_launcher(int N, int batch_size,
const float *dataset, float *temp, int *xyz_batch_cnt, int *idxs, int *num_sampled_points);
#endif