Unverified Commit 4c8bfb48 authored by dingchang, committed by GitHub

[Feature] Add roiaware pool3d ops from mmdet3d (#1382)



* add ops (roiaware pool3d) in mmdet3d

* refactor code

* fix typo
Co-authored-by: zhouzaida <zhouzaida@163.com>
parent e3e1dba2
@@ -23,6 +23,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- RoIPointPool3d
- RoIPool
- RoIAlign
- RoIAwarePool3d
- SimpleRoIAlign
- SigmoidFocalLoss
- SoftmaxFocalLoss
......
@@ -23,6 +23,7 @@ MMCV provides the CUDA ops commonly used in detection, segmentation, and other tasks
- RoIPointPool3d
- RoIPool
- RoIAlign
- RoIAwarePool3d
- SimpleRoIAlign
- SigmoidFocalLoss
- SoftmaxFocalLoss
......
@@ -34,11 +34,14 @@ from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms
from .pixel_group import pixel_group
from .point_sample import (SimpleRoIAlign, point_sample,
                           rel_roi_point_to_rel_img_point)
from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
                              points_in_boxes_part)
from .points_sampler import PointsSampler
from .psa_mask import PSAMask
from .roi_align import RoIAlign, roi_align
from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
from .roi_pool import RoIPool, roi_pool
from .roiaware_pool3d import RoIAwarePool3d
from .roipoint_pool3d import RoIPointPool3d
from .saconv import SAConv2d
from .scatter_points import DynamicScatter, dynamic_scatter
@@ -50,24 +53,82 @@ from .upfirdn2d import upfirdn2d
from .voxelize import Voxelization, voxelization
__all__ = [
    'bbox_overlaps',
    'CARAFE',
    'CARAFENaive',
    'CARAFEPack',
    'carafe',
    'carafe_naive',
    'CornerPool',
    'DeformConv2d',
    'DeformConv2dPack',
    'deform_conv2d',
    'DeformRoIPool',
    'DeformRoIPoolPack',
    'ModulatedDeformRoIPoolPack',
    'deform_roi_pool',
    'SigmoidFocalLoss',
    'SoftmaxFocalLoss',
    'sigmoid_focal_loss',
    'softmax_focal_loss',
    'get_compiler_version',
    'get_compiling_cuda_version',
'get_onnxruntime_op_path',
'MaskedConv2d',
'masked_conv2d',
'ModulatedDeformConv2d',
'ModulatedDeformConv2dPack',
'modulated_deform_conv2d',
'batched_nms',
'nms',
'soft_nms',
'nms_match',
'RoIAlign',
'roi_align',
'RoIPool',
'roi_pool',
'SyncBatchNorm',
'Conv2d',
'ConvTranspose2d',
'Linear',
'MaxPool2d',
'CrissCrossAttention',
'PSAMask',
'point_sample',
'rel_roi_point_to_rel_img_point',
'SimpleRoIAlign',
'SAConv2d',
'TINShift',
'tin_shift',
'assign_score_withk',
'box_iou_rotated',
'RoIPointPool3d',
'nms_rotated',
'knn',
'ball_query',
'upfirdn2d',
'FusedBiasLeakyReLU',
'fused_bias_leakyrelu',
'RoIAlignRotated',
'roi_align_rotated',
'pixel_group',
'contour_expand',
'three_nn',
'three_interpolate',
'MultiScaleDeformableAttention',
'Voxelization',
'voxelization',
'dynamic_scatter',
'DynamicScatter',
'BorderAlign',
'border_align',
'gather_points',
'furthest_point_sample',
'furthest_point_sample_with_dist',
'PointsSampler',
'Correlation',
'RoIAwarePool3d',
'points_in_boxes_part',
'points_in_boxes_cpu',
'points_in_boxes_all',
]
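As a quick smoke check of the new exports (a sketch, assuming this MMCV build is installed with its extensions compiled), the added names should resolve directly from mmcv.ops:

# Hypothetical smoke check; mirrors the imports used by the unit tests below.
from mmcv.ops import (RoIAwarePool3d, points_in_boxes_all, points_in_boxes_cpu,
                      points_in_boxes_part)

print(RoIAwarePool3d, points_in_boxes_part, points_in_boxes_cpu, points_in_boxes_all)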
// Copyright (c) OpenMMLab. All rights reserved
#ifndef POINT_IN_BOXES_CUDA_KERNEL_CUH
#define POINT_IN_BOXES_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
template <typename T>
__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz,
T &local_x, T &local_y) {
T cosa = cos(-rz), sina = sin(-rz);
local_x = shift_x * cosa + shift_y * (-sina);
local_y = shift_x * sina + shift_y * cosa;
}
template <typename T>
__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x,
T &local_y) {
// param pt: (x, y, z)
// param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate,
// cz in the bottom center
T x = pt[0], y = pt[1], z = pt[2];
T cx = box3d[0], cy = box3d[1], cz = box3d[2];
T x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6];
cz += z_size /
2.0; // shift to the center since cz in box3d is the bottom center
if (fabsf(z - cz) > z_size / 2.0) return 0;
lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) &
(local_y > -y_size / 2.0) & (local_y < y_size / 2.0);
return in_flag;
}
template <typename T>
__global__ void points_in_boxes_part_forward_cuda_kernel(
int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts,
int *box_idx_of_points) {
// params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
// coordinate, z is the bottom center, each box DO NOT overlaps params pts:
// (B, npoints, 3) [x, y, z] in LiDAR coordinate params boxes_idx_of_points:
// (B, npoints), default -1
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= batch_size || pt_idx >= pts_num) return;
boxes += bs_idx * boxes_num * 7;
pts += bs_idx * pts_num * 3 + pt_idx * 3;
box_idx_of_points += bs_idx * pts_num + pt_idx;
T local_x = 0, local_y = 0;
int cur_in_flag = 0;
for (int k = 0; k < boxes_num; k++) {
cur_in_flag = check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y);
if (cur_in_flag) {
box_idx_of_points[0] = k;
break;
}
}
}
template <typename T>
__global__ void points_in_boxes_all_forward_cuda_kernel(
int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts,
int *box_idx_of_points) {
// params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
// coordinate, z is the bottom center, each box DO NOT overlaps params pts:
// (B, npoints, 3) [x, y, z] in LiDAR coordinate params boxes_idx_of_points:
// (B, npoints), default -1
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= batch_size || pt_idx >= pts_num) return;
boxes += bs_idx * boxes_num * 7;
pts += bs_idx * pts_num * 3 + pt_idx * 3;
box_idx_of_points += bs_idx * pts_num * boxes_num + pt_idx * boxes_num;
T local_x = 0, local_y = 0;
for (int k = 0; k < boxes_num; k++) {
const int cur_in_flag =
check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y);
if (cur_in_flag) {
box_idx_of_points[k] = 1;
}
}
}
#endif // POINT_IN_BOXES_CUDA_KERNEL_CUH
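For intuition, here is a small NumPy sketch (not part of the patch) that mirrors lidar_to_local_coords and check_pt_in_box3d above: shift cz from the bottom center to the box center, rotate the xy offset into the box frame by -rz, and compare against the half extents.

import numpy as np

def check_pt_in_box3d_ref(pt, box3d):
    # pt: (x, y, z); box3d: (cx, cy, cz, x_size, y_size, z_size, rz),
    # with cz at the bottom center, as in the CUDA kernel above.
    x, y, z = pt
    cx, cy, cz, x_size, y_size, z_size, rz = box3d
    cz += z_size / 2.0                      # shift to the geometric center
    if abs(z - cz) > z_size / 2.0:
        return False
    cosa, sina = np.cos(-rz), np.sin(-rz)   # rotate into the box-local frame
    local_x = (x - cx) * cosa - (y - cy) * sina
    local_y = (x - cx) * sina + (y - cy) * cosa
    return abs(local_x) < x_size / 2.0 and abs(local_y) < y_size / 2.0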
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ROIAWARE_POOL3D_CUDA_KERNEL_CUH
#define ROIAWARE_POOL3D_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
template <typename T>
__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz,
T &local_x, T &local_y) {
T cosa = cos(-rz), sina = sin(-rz);
local_x = shift_x * cosa + shift_y * (-sina);
local_y = shift_x * sina + shift_y * cosa;
}
template <typename T>
__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x,
T &local_y) {
// param pt: (x, y, z)
// param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate,
// cz in the bottom center
T x = pt[0], y = pt[1], z = pt[2];
T cx = box3d[0], cy = box3d[1], cz = box3d[2];
T x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6];
cz += z_size /
2.0; // shift to the center since cz in box3d is the bottom center
if (fabsf(z - cz) > z_size / 2.0) return 0;
lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) &
(local_y > -y_size / 2.0) & (local_y < y_size / 2.0);
return in_flag;
}
template <typename T>
__global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num,
int out_x, int out_y, int out_z,
const T *rois, const T *pts,
int *pts_mask) {
// params rois: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
// coordinate params pts: (npoints, 3) [x, y, z] params pts_mask: (N,
// npoints): -1 means point does not in this box, otherwise: encode (x_idxs,
// y_idxs, z_idxs) by binary bit
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
int box_idx = blockIdx.y;
if (pt_idx >= pts_num || box_idx >= boxes_num) return;
pts += pt_idx * 3;
rois += box_idx * 7;
pts_mask += box_idx * pts_num + pt_idx;
T local_x = 0, local_y = 0;
int cur_in_flag = check_pt_in_box3d(pts, rois, local_x, local_y);
pts_mask[0] = -1;
if (cur_in_flag > 0) {
T local_z = pts[2] - rois[2];
T x_size = rois[3], y_size = rois[4], z_size = rois[5];
T x_res = x_size / out_x;
T y_res = y_size / out_y;
T z_res = z_size / out_z;
unsigned int x_idx = int((local_x + x_size / 2) / x_res);
unsigned int y_idx = int((local_y + y_size / 2) / y_res);
unsigned int z_idx = int(local_z / z_res);
x_idx = min(max(x_idx, 0), out_x - 1);
y_idx = min(max(y_idx, 0), out_y - 1);
z_idx = min(max(z_idx, 0), out_z - 1);
unsigned int idx_encoding = (x_idx << 16) + (y_idx << 8) + z_idx;
pts_mask[0] = idx_encoding;
}
}
template <typename T>
__global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num,
int max_pts_each_voxel, int out_x,
int out_y, int out_z,
const int *pts_mask,
T *pts_idx_of_voxels) {
// params pts_mask: (N, npoints) 0 or 1
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
int box_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (box_idx >= boxes_num) return;
int max_num_pts = max_pts_each_voxel - 1; // index 0 is the counter
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel;
for (int k = 0; k < pts_num; k++) {
if (pts_mask[box_idx * pts_num + k] != -1) {
unsigned int idx_encoding = pts_mask[box_idx * pts_num + k];
unsigned int x_idx = (idx_encoding >> 16) & 0xFF;
unsigned int y_idx = (idx_encoding >> 8) & 0xFF;
unsigned int z_idx = idx_encoding & 0xFF;
unsigned int base_offset = x_idx * out_y * out_z * max_pts_each_voxel +
y_idx * out_z * max_pts_each_voxel +
z_idx * max_pts_each_voxel;
unsigned int cnt = pts_idx_of_voxels[base_offset];
if (cnt < max_num_pts) {
pts_idx_of_voxels[base_offset + cnt + 1] = k;
pts_idx_of_voxels[base_offset]++;
}
}
}
}
template <typename T>
__global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels,
int max_pts_each_voxel, int out_x, int out_y,
int out_z, const T *pts_feature,
const int *pts_idx_of_voxels,
T *pooled_features, int *argmax) {
// params pts_feature: (npoints, C)
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel),
// index 0 is the counter params pooled_features: (N, out_x, out_y, out_z, C)
// params argmax: (N, out_x, out_y, out_z, C)
int box_idx = blockIdx.z;
int channel_idx = blockIdx.y;
int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;
int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
y_idx >= out_y || z_idx >= out_z)
return;
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
offset_base * max_pts_each_voxel;
pooled_features += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
argmax += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
int argmax_idx = -1;
float max_val = -1e50;
int total_pts = pts_idx_of_voxels[0];
for (int k = 1; k <= total_pts; k++) {
if (pts_feature[pts_idx_of_voxels[k] * channels + channel_idx] > max_val) {
max_val = pts_feature[pts_idx_of_voxels[k] * channels + channel_idx];
argmax_idx = pts_idx_of_voxels[k];
}
}
if (argmax_idx != -1) {
pooled_features[0] = max_val;
}
argmax[0] = argmax_idx;
}
template <typename T>
__global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels,
int max_pts_each_voxel, int out_x, int out_y,
int out_z, const T *pts_feature,
const int *pts_idx_of_voxels,
T *pooled_features) {
// params pts_feature: (npoints, C)
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel),
// index 0 is the counter params pooled_features: (N, out_x, out_y, out_z, C)
// params argmax: (N, out_x, out_y, out_z, C)
int box_idx = blockIdx.z;
int channel_idx = blockIdx.y;
int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;
int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
y_idx >= out_y || z_idx >= out_z)
return;
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
offset_base * max_pts_each_voxel;
pooled_features += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
float sum_val = 0;
int total_pts = pts_idx_of_voxels[0];
for (int k = 1; k <= total_pts; k++) {
sum_val += pts_feature[pts_idx_of_voxels[k] * channels + channel_idx];
}
if (total_pts > 0) {
pooled_features[0] = sum_val / total_pts;
}
}
template <typename T>
__global__ void roiaware_maxpool3d_backward(int boxes_num, int channels,
int out_x, int out_y, int out_z,
const int *argmax,
const T *grad_out, T *grad_in) {
// params argmax: (N, out_x, out_y, out_z, C)
// params grad_out: (N, out_x, out_y, out_z, C)
// params grad_in: (npoints, C), return value
int box_idx = blockIdx.z;
int channel_idx = blockIdx.y;
int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;
int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
y_idx >= out_y || z_idx >= out_z)
return;
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
argmax += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
grad_out += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
if (argmax[0] == -1) return;
atomicAdd(grad_in + argmax[0] * channels + channel_idx, grad_out[0] * 1);
}
template <typename T>
__global__ void roiaware_avgpool3d_backward(int boxes_num, int channels,
int out_x, int out_y, int out_z,
int max_pts_each_voxel,
const int *pts_idx_of_voxels,
const T *grad_out, T *grad_in) {
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
// params grad_out: (N, out_x, out_y, out_z, C)
// params grad_in: (npoints, C), return value
int box_idx = blockIdx.z;
int channel_idx = blockIdx.y;
int voxel_idx_flat = blockIdx.x * blockDim.x + threadIdx.x;
int x_idx = voxel_idx_flat / (out_y * out_z);
int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z;
int z_idx = voxel_idx_flat % out_z;
if (box_idx >= boxes_num || channel_idx >= channels || x_idx >= out_x ||
y_idx >= out_y || z_idx >= out_z)
return;
int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx;
pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel +
offset_base * max_pts_each_voxel;
grad_out += box_idx * out_x * out_y * out_z * channels +
offset_base * channels + channel_idx;
int total_pts = pts_idx_of_voxels[0];
float cur_grad = 1 / fmaxf(float(total_pts), 1.0);
for (int k = 1; k <= total_pts; k++) {
atomicAdd(grad_in + pts_idx_of_voxels[k] * channels + channel_idx,
grad_out[0] * cur_grad);
}
}
#endif // ROIAWARE_POOL3D_CUDA_KERNEL_CUH
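generate_pts_mask_for_box3d packs each point's voxel coordinates into one int and collect_inside_pts_for_box3d unpacks them again. A minimal sketch of that encoding (assuming, as the op asserts elsewhere, that every output size is below 256 so each index fits in 8 bits):

def encode_voxel_idx(x_idx, y_idx, z_idx):
    # Matches (x_idx << 16) + (y_idx << 8) + z_idx in the kernel;
    # only valid while each index is in [0, 255].
    return (x_idx << 16) | (y_idx << 8) | z_idx

def decode_voxel_idx(code):
    return (code >> 16) & 0xFF, (code >> 8) & 0xFF, code & 0xFF

assert decode_voxel_idx(encode_voxel_idx(3, 1, 2)) == (3, 1, 2)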
// Modified from
// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
// Written by Shaoshuai Shi
// All Rights Reserved 2019.
#include <stdio.h>
#include "points_in_boxes_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void PointsInBoxesPartForwardCUDAKernelLauncher(int batch_size, int boxes_num,
int pts_num, const Tensor boxes,
const Tensor pts,
Tensor box_idx_of_points) {
// params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
// coordinate, z is
// the bottom center, each box DO NOT overlaps params pts: (B, npoints, 3) [x,
// y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), default
// -1
at::cuda::CUDAGuard device_guard(boxes.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), batch_size);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
boxes.scalar_type(), "points_in_boxes_part_forward_cuda_kernel", [&] {
points_in_boxes_part_forward_cuda_kernel<scalar_t>
<<<blocks, threads, 0, stream>>>(
batch_size, boxes_num, pts_num, boxes.data_ptr<scalar_t>(),
pts.data_ptr<scalar_t>(), box_idx_of_points.data_ptr<int>());
});
AT_CUDA_CHECK(cudaGetLastError());
}
void PointsInBoxesAllForwardCUDAKernelLauncher(int batch_size, int boxes_num,
int pts_num, const Tensor boxes,
const Tensor pts,
Tensor box_idx_of_points) {
// params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
// coordinate, z is the bottom center, each box params pts: (B, npoints, 3)
// [x, y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints),
// default -1
at::cuda::CUDAGuard device_guard(boxes.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), batch_size);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
boxes.scalar_type(), "points_in_boxes_all_forward_cuda_kernel", [&] {
points_in_boxes_all_forward_cuda_kernel<scalar_t>
<<<blocks, threads, 0, stream>>>(
batch_size, boxes_num, pts_num, boxes.data_ptr<scalar_t>(),
pts.data_ptr<scalar_t>(), box_idx_of_points.data_ptr<int>());
});
AT_CUDA_CHECK(cudaGetLastError());
}
// Modified from
// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
// Written by Shaoshuai Shi
// All Rights Reserved 2019.
#include <stdio.h>
#include "pytorch_cuda_helper.hpp"
#include "roiaware_pool3d_cuda_kernel.cuh"
void RoiawarePool3dForwardCUDAKernelLauncher(
int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x,
int out_y, int out_z, const Tensor rois, const Tensor pts,
const Tensor pts_feature, Tensor argmax, Tensor pts_idx_of_voxels,
Tensor pooled_features, int pool_method) {
// params rois: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
// coordinate params pts: (npoints, 3) [x, y, z] in LiDAR coordinate params
// pts_feature: (npoints, C) params argmax: (N, out_x, out_y, out_z, C) params
// pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) params
// pooled_features: (N, out_x, out_y, out_z, C) params pool_method: 0:
// max_pool 1: avg_pool
at::cuda::CUDAGuard device_guard(pts_feature.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
Tensor pts_mask =
-at::ones({boxes_num, pts_num}, pts_feature.options().dtype(at::kInt));
dim3 blocks_mask(DIVUP(pts_num, THREADS_PER_BLOCK), boxes_num);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
rois.scalar_type(), "generate_pts_mask_for_box3d", [&] {
generate_pts_mask_for_box3d<scalar_t>
<<<blocks_mask, threads, 0, stream>>>(
boxes_num, pts_num, out_x, out_y, out_z,
rois.data_ptr<scalar_t>(), pts.data_ptr<scalar_t>(),
pts_mask.data_ptr<int>());
});
AT_CUDA_CHECK(cudaGetLastError());
// TODO: Merge the collect and pool functions, SS
dim3 blocks_collect(DIVUP(boxes_num, THREADS_PER_BLOCK));
AT_DISPATCH_INTEGRAL_TYPES(
pts_idx_of_voxels.scalar_type(), "collect_inside_pts_for_box3d", [&] {
collect_inside_pts_for_box3d<scalar_t>
<<<blocks_collect, threads, 0, stream>>>(
boxes_num, pts_num, max_pts_each_voxel, out_x, out_y, out_z,
pts_mask.data_ptr<int>(),
pts_idx_of_voxels.data_ptr<scalar_t>());
});
AT_CUDA_CHECK(cudaGetLastError());
dim3 blocks_pool(DIVUP(out_x * out_y * out_z, THREADS_PER_BLOCK), channels,
boxes_num);
if (pool_method == 0) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
pts_feature.scalar_type(), "roiaware_maxpool3d", [&] {
roiaware_maxpool3d<scalar_t><<<blocks_pool, threads, 0, stream>>>(
boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y,
out_z, pts_feature.data_ptr<scalar_t>(),
pts_idx_of_voxels.data_ptr<int>(),
pooled_features.data_ptr<scalar_t>(), argmax.data_ptr<int>());
});
} else if (pool_method == 1) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
pts_feature.scalar_type(), "roiaware_avgpool3d", [&] {
roiaware_avgpool3d<scalar_t><<<blocks_pool, threads, 0, stream>>>(
boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y,
out_z, pts_feature.data_ptr<scalar_t>(),
pts_idx_of_voxels.data_ptr<int>(),
pooled_features.data_ptr<scalar_t>());
});
}
AT_CUDA_CHECK(cudaGetLastError());
}
void RoiawarePool3dBackwardCUDAKernelLauncher(
int boxes_num, int out_x, int out_y, int out_z, int channels,
int max_pts_each_voxel, const Tensor pts_idx_of_voxels, const Tensor argmax,
const Tensor grad_out, Tensor grad_in, int pool_method) {
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
// params argmax: (N, out_x, out_y, out_z, C)
// params grad_out: (N, out_x, out_y, out_z, C)
// params grad_in: (npoints, C), return value
// params pool_method: 0: max_pool, 1: avg_pool
at::cuda::CUDAGuard device_guard(grad_out.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(DIVUP(out_x * out_y * out_z, THREADS_PER_BLOCK), channels,
boxes_num);
dim3 threads(THREADS_PER_BLOCK);
if (pool_method == 0) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_in.scalar_type(), "roiaware_maxpool3d_backward", [&] {
roiaware_maxpool3d_backward<scalar_t><<<blocks, threads, 0, stream>>>(
boxes_num, channels, out_x, out_y, out_z, argmax.data_ptr<int>(),
grad_out.data_ptr<scalar_t>(), grad_in.data_ptr<scalar_t>());
});
} else if (pool_method == 1) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_in.scalar_type(), "roiaware_avgpool3d_backward", [&] {
roiaware_avgpool3d_backward<scalar_t><<<blocks, threads, 0, stream>>>(
boxes_num, channels, out_x, out_y, out_z, max_pts_each_voxel,
pts_idx_of_voxels.data_ptr<int>(), grad_out.data_ptr<scalar_t>(),
grad_in.data_ptr<scalar_t>());
});
}
AT_CUDA_CHECK(cudaGetLastError());
}
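As a reading aid, a NumPy reference (an illustrative sketch, not the shipped kernels) of what roiaware_maxpool3d / roiaware_avgpool3d compute per voxel from the pts_idx_of_voxels layout, where slot 0 of each voxel holds the point count and the following slots hold point indices:

import numpy as np

def pool_voxel_ref(pts_feature, voxel_slot, pool_method=0):
    # pts_feature: (npoints, C); voxel_slot: int array of length max_pts_each_voxel,
    # voxel_slot[0] = number of points in the voxel, voxel_slot[1:1 + n] = point indices.
    n = int(voxel_slot[0])
    if n == 0:
        # empty voxels keep the zero-initialized pooled feature
        return np.zeros(pts_feature.shape[1], dtype=pts_feature.dtype)
    feats = pts_feature[voxel_slot[1:1 + n]]        # (n, C)
    return feats.max(axis=0) if pool_method == 0 else feats.mean(axis=0)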
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
void PointsInBoxesPartForwardCUDAKernelLauncher(int batch_size, int boxes_num,
int pts_num, const Tensor boxes,
const Tensor pts,
Tensor box_idx_of_points);
void points_in_boxes_part_forward_cuda(int batch_size, int boxes_num,
int pts_num, const Tensor boxes,
const Tensor pts,
Tensor box_idx_of_points) {
PointsInBoxesPartForwardCUDAKernelLauncher(batch_size, boxes_num, pts_num,
boxes, pts, box_idx_of_points);
};
void PointsInBoxesAllForwardCUDAKernelLauncher(int batch_size, int boxes_num,
int pts_num, const Tensor boxes,
const Tensor pts,
Tensor box_idx_of_points);
void points_in_boxes_all_forward_cuda(int batch_size, int boxes_num,
int pts_num, const Tensor boxes,
const Tensor pts,
Tensor box_idx_of_points) {
PointsInBoxesAllForwardCUDAKernelLauncher(batch_size, boxes_num, pts_num,
boxes, pts, box_idx_of_points);
};
#endif
void points_in_boxes_part_forward(Tensor boxes_tensor, Tensor pts_tensor,
Tensor box_idx_of_points_tensor) {
// params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
// coordinate, z is the bottom center, each box params pts: (B, npoints, 3)
// [x, y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints),
// default -1
if (pts_tensor.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(boxes_tensor);
CHECK_CUDA_INPUT(pts_tensor);
CHECK_CUDA_INPUT(box_idx_of_points_tensor);
int batch_size = boxes_tensor.size(0);
int boxes_num = boxes_tensor.size(1);
int pts_num = pts_tensor.size(1);
const float *boxes = boxes_tensor.data_ptr<float>();
const float *pts = pts_tensor.data_ptr<float>();
int *box_idx_of_points = box_idx_of_points_tensor.data_ptr<int>();
points_in_boxes_part_forward_cuda(batch_size, boxes_num, pts_num,
boxes_tensor, pts_tensor,
box_idx_of_points_tensor);
#else
AT_ERROR("points_in_boxes_part is not compiled with GPU support");
#endif
} else {
AT_ERROR("points_in_boxes_part is not implemented on CPU");
}
}
void points_in_boxes_all_forward(Tensor boxes_tensor, Tensor pts_tensor,
Tensor box_idx_of_points_tensor) {
// params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
// coordinate, z is the bottom center. params pts: (B, npoints, 3) [x, y, z]
// in LiDAR coordinate params boxes_idx_of_points: (B, npoints), default -1
if (pts_tensor.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(boxes_tensor);
CHECK_CUDA_INPUT(pts_tensor);
CHECK_CUDA_INPUT(box_idx_of_points_tensor);
int batch_size = boxes_tensor.size(0);
int boxes_num = boxes_tensor.size(1);
int pts_num = pts_tensor.size(1);
const float *boxes = boxes_tensor.data_ptr<float>();
const float *pts = pts_tensor.data_ptr<float>();
int *box_idx_of_points = box_idx_of_points_tensor.data_ptr<int>();
points_in_boxes_all_forward_cuda(batch_size, boxes_num, pts_num,
boxes_tensor, pts_tensor,
box_idx_of_points_tensor);
#else
AT_ERROR("points_in_boxes_all is not compiled with GPU support");
#endif
} else {
AT_ERROR("points_in_boxes_all is not implemented on CPU");
}
}
#include "pytorch_cpp_helper.hpp"
inline void lidar_to_local_coords_cpu(float shift_x, float shift_y, float rz,
float &local_x, float &local_y) {
float cosa = cos(-rz), sina = sin(-rz);
local_x = shift_x * cosa + shift_y * (-sina);
local_y = shift_x * sina + shift_y * cosa;
}
inline int check_pt_in_box3d_cpu(const float *pt, const float *box3d,
float &local_x, float &local_y) {
// param pt: (x, y, z)
// param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate,
// cz in the bottom center
float x = pt[0], y = pt[1], z = pt[2];
float cx = box3d[0], cy = box3d[1], cz = box3d[2];
float x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6];
cz += z_size /
2.0; // shift to the center since cz in box3d is the bottom center
if (fabsf(z - cz) > z_size / 2.0) return 0;
lidar_to_local_coords_cpu(x - cx, y - cy, rz, local_x, local_y);
float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) &
(local_y > -y_size / 2.0) & (local_y < y_size / 2.0);
return in_flag;
}
void points_in_boxes_cpu_forward(Tensor boxes_tensor, Tensor pts_tensor,
Tensor pts_indices_tensor) {
// params boxes: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
// coordinate, z is the bottom center, each box DO NOT overlaps params pts:
// (npoints, 3) [x, y, z] in LiDAR coordinate params pts_indices: (N, npoints)
CHECK_CONTIGUOUS(boxes_tensor);
CHECK_CONTIGUOUS(pts_tensor);
CHECK_CONTIGUOUS(pts_indices_tensor);
int boxes_num = boxes_tensor.size(0);
int pts_num = pts_tensor.size(0);
const float *boxes = boxes_tensor.data_ptr<float>();
const float *pts = pts_tensor.data_ptr<float>();
int *pts_indices = pts_indices_tensor.data_ptr<int>();
float local_x = 0, local_y = 0;
for (int i = 0; i < boxes_num; i++) {
for (int j = 0; j < pts_num; j++) {
int cur_in_flag =
check_pt_in_box3d_cpu(pts + j * 3, boxes + i * 7, local_x, local_y);
pts_indices[i * pts_num + j] = cur_in_flag;
}
}
}
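The CPU forward above is a plain nested loop over boxes and points; a self-contained NumPy sketch of the same computation (indicator laid out as (boxes_num, pts_num), which the Python wrapper later transposes):

import numpy as np

def points_in_boxes_cpu_ref(boxes, pts):
    # boxes: (N, 7) [cx, cy, cz, dx, dy, dz, rz] with cz at the bottom center;
    # pts: (npoints, 3). Returns an int32 (N, npoints) 0/1 indicator.
    out = np.zeros((len(boxes), len(pts)), dtype=np.int32)
    for i, (cx, cy, cz, dx, dy, dz, rz) in enumerate(boxes):
        cz = cz + dz / 2.0                       # shift to the geometric center
        cosa, sina = np.cos(-rz), np.sin(-rz)
        for j, (x, y, z) in enumerate(pts):
            lx = (x - cx) * cosa - (y - cy) * sina
            ly = (x - cx) * sina + (y - cy) * cosa
            inside = (abs(z - cz) <= dz / 2.0 and abs(lx) < dx / 2.0
                      and abs(ly) < dy / 2.0)
            out[i, j] = int(inside)
    return out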
@@ -296,6 +296,22 @@ void border_align_backward(const Tensor &grad_output, const Tensor &boxes,
                           const Tensor &argmax_idx, Tensor grad_input,
                           const int pool_size);
void points_in_boxes_cpu_forward(Tensor boxes_tensor, Tensor pts_tensor,
Tensor pts_indices_tensor);
void points_in_boxes_part_forward(Tensor boxes_tensor, Tensor pts_tensor,
Tensor box_idx_of_points_tensor);
void points_in_boxes_all_forward(Tensor boxes_tensor, Tensor pts_tensor,
Tensor box_idx_of_points_tensor);
void roiaware_pool3d_forward(Tensor rois, Tensor pts, Tensor pts_feature,
Tensor argmax, Tensor pts_idx_of_voxels,
Tensor pooled_features, int pool_method);
void roiaware_pool3d_backward(Tensor pts_idx_of_voxels, Tensor argmax,
Tensor grad_out, Tensor grad_in, int pool_method);
void correlation_forward(Tensor input1, Tensor input2, Tensor output, int kH,
                         int kW, int patchH, int patchW, int padH, int padW,
                         int dilationH, int dilationW, int dilation_patchH,
@@ -599,6 +615,23 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        "backward function of border_align", py::arg("grad_output"),
        py::arg("boxes"), py::arg("argmax_idx"), py::arg("grad_input"),
        py::arg("pool_size"));
m.def("points_in_boxes_cpu_forward", &points_in_boxes_cpu_forward,
"points_in_boxes_cpu_forward", py::arg("boxes_tensor"),
py::arg("pts_tensor"), py::arg("pts_indices_tensor"));
m.def("points_in_boxes_part_forward", &points_in_boxes_part_forward,
"points_in_boxes_part_forward", py::arg("boxes_tensor"),
py::arg("pts_tensor"), py::arg("box_idx_of_points_tensor"));
m.def("points_in_boxes_all_forward", &points_in_boxes_all_forward,
"points_in_boxes_all_forward", py::arg("boxes_tensor"),
py::arg("pts_tensor"), py::arg("box_idx_of_points_tensor"));
m.def("roiaware_pool3d_forward", &roiaware_pool3d_forward,
"roiaware_pool3d_forward", py::arg("rois"), py::arg("pts"),
py::arg("pts_feature"), py::arg("argmax"), py::arg("pts_idx_of_voxels"),
py::arg("pooled_features"), py::arg("pool_method"));
m.def("roiaware_pool3d_backward", &roiaware_pool3d_backward,
"roiaware_pool3d_backward", py::arg("pts_idx_of_voxels"),
py::arg("argmax"), py::arg("grad_out"), py::arg("grad_in"),
py::arg("pool_method"));
m.def("correlation_forward", &correlation_forward, "Correlation forward"); m.def("correlation_forward", &correlation_forward, "Correlation forward");
m.def("correlation_backward", &correlation_backward, "Correlation backward"); m.def("correlation_backward", &correlation_backward, "Correlation backward");
} }
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
void RoiawarePool3dForwardCUDAKernelLauncher(
int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x,
int out_y, int out_z, const Tensor rois, const Tensor pts,
const Tensor pts_feature, Tensor argmax, Tensor pts_idx_of_voxels,
Tensor pooled_features, int pool_method);
void roiaware_pool3d_forward_cuda(int boxes_num, int pts_num, int channels,
int max_pts_each_voxel, int out_x, int out_y,
int out_z, const Tensor rois,
const Tensor pts, const Tensor pts_feature,
Tensor argmax, Tensor pts_idx_of_voxels,
Tensor pooled_features, int pool_method) {
RoiawarePool3dForwardCUDAKernelLauncher(
boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z,
rois, pts, pts_feature, argmax, pts_idx_of_voxels, pooled_features,
pool_method);
};
void RoiawarePool3dBackwardCUDAKernelLauncher(
int boxes_num, int out_x, int out_y, int out_z, int channels,
int max_pts_each_voxel, const Tensor pts_idx_of_voxels, const Tensor argmax,
const Tensor grad_out, Tensor grad_in, int pool_method);
void roiaware_pool3d_backward_cuda(int boxes_num, int out_x, int out_y,
int out_z, int channels,
int max_pts_each_voxel,
const Tensor pts_idx_of_voxels,
const Tensor argmax, const Tensor grad_out,
Tensor grad_in, int pool_method) {
RoiawarePool3dBackwardCUDAKernelLauncher(
boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel,
pts_idx_of_voxels, argmax, grad_out, grad_in, pool_method);
};
#endif
void roiaware_pool3d_forward(Tensor rois, Tensor pts, Tensor pts_feature,
Tensor argmax, Tensor pts_idx_of_voxels,
Tensor pooled_features, int pool_method) {
// params rois: (N, 7) [x, y, z, x_size, y_size, z_size, ry] in LiDAR
// coordinate
// params pts: (npoints, 3) [x, y, z] in LiDAR coordinate
// params pts_feature: (npoints, C)
// params argmax: (N, out_x, out_y, out_z, C)
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
// params pooled_features: (N, out_x, out_y, out_z, C)
// params pool_method: 0: max_pool 1: avg_pool
if (pts.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(rois);
CHECK_CUDA_INPUT(pts);
CHECK_CUDA_INPUT(pts_feature);
CHECK_CUDA_INPUT(argmax);
CHECK_CUDA_INPUT(pts_idx_of_voxels);
CHECK_CUDA_INPUT(pooled_features);
int boxes_num = rois.size(0);
int pts_num = pts.size(0);
int channels = pts_feature.size(1);
int max_pts_each_voxel =
pts_idx_of_voxels.size(4); // index 0 is the counter
int out_x = pts_idx_of_voxels.size(1);
int out_y = pts_idx_of_voxels.size(2);
int out_z = pts_idx_of_voxels.size(3);
assert((out_x < 256) && (out_y < 256) &&
(out_z < 256)); // we encode index with 8bit
roiaware_pool3d_forward_cuda(boxes_num, pts_num, channels,
max_pts_each_voxel, out_x, out_y, out_z, rois,
pts, pts_feature, argmax, pts_idx_of_voxels,
pooled_features, pool_method);
#else
AT_ERROR("roiaware_pool3d is not compiled with GPU support");
#endif
} else {
AT_ERROR("roiaware_pool3d is not implemented on CPU");
}
}
void roiaware_pool3d_backward(Tensor pts_idx_of_voxels, Tensor argmax,
Tensor grad_out, Tensor grad_in,
int pool_method) {
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel)
// params argmax: (N, out_x, out_y, out_z, C)
// params grad_out: (N, out_x, out_y, out_z, C)
// params grad_in: (npoints, C), return value
// params pool_method: 0: max_pool 1: avg_pool
if (grad_in.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(pts_idx_of_voxels);
CHECK_CUDA_INPUT(argmax);
CHECK_CUDA_INPUT(grad_out);
CHECK_CUDA_INPUT(grad_in);
int boxes_num = pts_idx_of_voxels.size(0);
int out_x = pts_idx_of_voxels.size(1);
int out_y = pts_idx_of_voxels.size(2);
int out_z = pts_idx_of_voxels.size(3);
int max_pts_each_voxel =
pts_idx_of_voxels.size(4); // index 0 is the counter
int channels = grad_out.size(4);
roiaware_pool3d_backward_cuda(boxes_num, out_x, out_y, out_z, channels,
max_pts_each_voxel, pts_idx_of_voxels, argmax,
grad_out, grad_in, pool_method);
#else
AT_ERROR("roiaware_pool3d is not compiled with GPU support");
#endif
} else {
AT_ERROR("roiaware_pool3d is not implemented on CPU");
}
}
import torch
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', [
'points_in_boxes_part_forward', 'points_in_boxes_cpu_forward',
'points_in_boxes_all_forward'
])
def points_in_boxes_part(points, boxes):
"""Find the box in which each point is (CUDA).
Args:
points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
boxes (torch.Tensor): [B, T, 7],
num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz] in
LiDAR/DEPTH coordinate, (x, y, z) is the bottom center
Returns:
box_idxs_of_pts (torch.Tensor): (B, M), default background = -1
"""
assert points.shape[0] == boxes.shape[0], \
'Points and boxes should have the same batch size, ' \
f'but got {points.shape[0]} and {boxes.shape[0]}'
assert boxes.shape[2] == 7, \
'boxes dimension should be 7, ' \
f'but got unexpected shape {boxes.shape[2]}'
assert points.shape[2] == 3, \
'points dimension should be 3, ' \
f'but got unexpected shape {points.shape[2]}'
batch_size, num_points, _ = points.shape
box_idxs_of_pts = points.new_zeros((batch_size, num_points),
dtype=torch.int).fill_(-1)
# If manually put the tensor 'points' or 'boxes' on a device
# which is not the current device, some temporary variables
# will be created on the current device in the cuda op,
# and the output will be incorrect.
# Therefore, we force the current device to be the same
# as the device of the tensors if it was not.
# Please refer to https://github.com/open-mmlab/mmdetection3d/issues/305
# for the incorrect output before the fix.
points_device = points.get_device()
assert points_device == boxes.get_device(), \
'Points and boxes should be put on the same device'
if torch.cuda.current_device() != points_device:
torch.cuda.set_device(points_device)
ext_module.points_in_boxes_part_forward(boxes.contiguous(),
points.contiguous(),
box_idxs_of_pts)
return box_idxs_of_pts
def points_in_boxes_cpu(points, boxes):
"""Find all boxes in which each point is (CPU). The CPU version of
:meth:`points_in_boxes_all`.
Args:
points (torch.Tensor): [B, M, 3], [x, y, z] in
LiDAR/DEPTH coordinate
boxes (torch.Tensor): [B, T, 7],
num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz],
(x, y, z) is the bottom center.
Returns:
box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0.
"""
assert points.shape[0] == boxes.shape[0], \
'Points and boxes should have the same batch size, ' \
f'but got {points.shape[0]} and {boxes.shape[0]}'
assert boxes.shape[2] == 7, \
'boxes dimension should be 7, ' \
f'but got unexpected shape {boxes.shape[2]}'
assert points.shape[2] == 3, \
'points dimension should be 3, ' \
f'but got unexpected shape {points.shape[2]}'
batch_size, num_points, _ = points.shape
num_boxes = boxes.shape[1]
point_indices = points.new_zeros((batch_size, num_boxes, num_points),
dtype=torch.int)
for b in range(batch_size):
ext_module.points_in_boxes_cpu_forward(boxes[b].float().contiguous(),
points[b].float().contiguous(),
point_indices[b])
point_indices = point_indices.transpose(1, 2)
return point_indices
def points_in_boxes_all(points, boxes):
"""Find all boxes in which each point is (CUDA).
Args:
points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
boxes (torch.Tensor): [B, T, 7],
num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz],
(x, y, z) is the bottom center.
Returns:
box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0.
"""
assert boxes.shape[0] == points.shape[0], \
'Points and boxes should have the same batch size, ' \
        f'but got {points.shape[0]} and {boxes.shape[0]}'
assert boxes.shape[2] == 7, \
'boxes dimension should be 7, ' \
f'but got unexpected shape {boxes.shape[2]}'
assert points.shape[2] == 3, \
'points dimension should be 3, ' \
f'but got unexpected shape {points.shape[2]}'
batch_size, num_points, _ = points.shape
num_boxes = boxes.shape[1]
box_idxs_of_pts = points.new_zeros((batch_size, num_points, num_boxes),
dtype=torch.int).fill_(0)
# Same reason as line 25-32
points_device = points.get_device()
assert points_device == boxes.get_device(), \
'Points and boxes should be put on the same device'
if torch.cuda.current_device() != points_device:
torch.cuda.set_device(points_device)
ext_module.points_in_boxes_all_forward(boxes.contiguous(),
points.contiguous(),
box_idxs_of_pts)
return box_idxs_of_pts
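A small usage sketch for the three Python entry points (shapes follow the docstrings above; the box and point values here are only illustrative, and the CUDA variants assume a GPU build):

import torch
from mmcv.ops import points_in_boxes_all, points_in_boxes_cpu, points_in_boxes_part

boxes = torch.tensor([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.3]]])   # (B, T, 7), bottom-center boxes
points = torch.tensor([[[1.0, 2.0, 3.3], [20.0, 20.0, 20.0]]])  # (B, M, 3)

idx_cpu = points_in_boxes_cpu(points=points, boxes=boxes)        # (B, M, T), 0/1 indicator
if torch.cuda.is_available():
    idx_part = points_in_boxes_part(points=points.cuda(), boxes=boxes.cuda())  # (B, M), -1 = background
    idx_all = points_in_boxes_all(points=points.cuda(), boxes=boxes.cuda())    # (B, M, T), 0/1 indicator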
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn as nn
from torch.autograd import Function
import mmcv
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext', ['roiaware_pool3d_forward', 'roiaware_pool3d_backward'])
class RoIAwarePool3d(nn.Module):
"""Encode the geometry-specific features of each 3D proposal.
Please refer to `PartA2 <https://arxiv.org/pdf/1907.03670.pdf>`_ for more
details.
Args:
out_size (int or tuple): The size of output features. n or
[n1, n2, n3].
max_pts_per_voxel (int, optional): The maximum number of points per
voxel. Default: 128.
mode (str, optional): Pooling method of RoIAware, 'max' or 'avg'.
Default: 'max'.
"""
def __init__(self, out_size, max_pts_per_voxel=128, mode='max'):
super().__init__()
self.out_size = out_size
self.max_pts_per_voxel = max_pts_per_voxel
assert mode in ['max', 'avg']
pool_mapping = {'max': 0, 'avg': 1}
self.mode = pool_mapping[mode]
def forward(self, rois, pts, pts_feature):
"""
Args:
rois (torch.Tensor): [N, 7], in LiDAR coordinate,
(x, y, z) is the bottom center of rois.
pts (torch.Tensor): [npoints, 3], coordinates of input points.
pts_feature (torch.Tensor): [npoints, C], features of input points.
Returns:
pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
"""
return RoIAwarePool3dFunction.apply(rois, pts, pts_feature,
self.out_size,
self.max_pts_per_voxel, self.mode)
class RoIAwarePool3dFunction(Function):
@staticmethod
def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
mode):
"""
Args:
rois (torch.Tensor): [N, 7], in LiDAR coordinate,
(x, y, z) is the bottom center of rois.
pts (torch.Tensor): [npoints, 3], coordinates of input points.
pts_feature (torch.Tensor): [npoints, C], features of input points.
out_size (int or tuple): The size of output features. n or
[n1, n2, n3].
max_pts_per_voxel (int): The maximum number of points per voxel.
Default: 128.
mode (int): Pooling method of RoIAware, 0 (max pool) or 1 (average
pool).
Returns:
pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C], output
pooled features.
"""
if isinstance(out_size, int):
out_x = out_y = out_z = out_size
else:
assert len(out_size) == 3
assert mmcv.is_tuple_of(out_size, int)
out_x, out_y, out_z = out_size
num_rois = rois.shape[0]
num_channels = pts_feature.shape[-1]
num_pts = pts.shape[0]
pooled_features = pts_feature.new_zeros(
(num_rois, out_x, out_y, out_z, num_channels))
argmax = pts_feature.new_zeros(
(num_rois, out_x, out_y, out_z, num_channels), dtype=torch.int)
pts_idx_of_voxels = pts_feature.new_zeros(
(num_rois, out_x, out_y, out_z, max_pts_per_voxel),
dtype=torch.int)
ext_module.roiaware_pool3d_forward(rois, pts, pts_feature, argmax,
pts_idx_of_voxels, pooled_features,
mode)
ctx.roiaware_pool3d_for_backward = (pts_idx_of_voxels, argmax, mode,
num_pts, num_channels)
return pooled_features
@staticmethod
def backward(ctx, grad_out):
ret = ctx.roiaware_pool3d_for_backward
pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret
grad_in = grad_out.new_zeros((num_pts, num_channels))
ext_module.roiaware_pool3d_backward(pts_idx_of_voxels, argmax,
grad_out.contiguous(), grad_in,
mode)
return None, None, grad_in, None, None, None
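A minimal usage sketch for the module (random inputs purely for shape illustration; requires a CUDA build, as the op has no CPU path):

import torch
from mmcv.ops import RoIAwarePool3d

if torch.cuda.is_available():
    pool = RoIAwarePool3d(out_size=4, max_pts_per_voxel=128, mode='max')
    rois = torch.rand(2, 7).cuda()            # (N, 7), bottom-center boxes in LiDAR coordinates
    pts = torch.rand(100, 3).cuda()           # (npoints, 3)
    pts_feature = torch.rand(100, 16).cuda()  # (npoints, C)
    pooled = pool(rois, pts, pts_feature)     # (N, 4, 4, 4, C)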
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from mmcv.ops import (RoIAwarePool3d, points_in_boxes_all, points_in_boxes_cpu,
points_in_boxes_part)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_RoIAwarePool3d():
roiaware_pool3d_max = RoIAwarePool3d(
out_size=4, max_pts_per_voxel=128, mode='max')
roiaware_pool3d_avg = RoIAwarePool3d(
out_size=4, max_pts_per_voxel=128, mode='avg')
rois = torch.tensor(
[[1.0, 2.0, 3.0, 5.0, 4.0, 6.0, -0.3 - np.pi / 2],
[-10.0, 23.0, 16.0, 20.0, 10.0, 20.0, -0.5 - np.pi / 2]],
dtype=torch.float32).cuda(
) # boxes (m, 7) with bottom center in lidar coordinate
pts = torch.tensor(
[[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], [1.6, 2.6, 3.6],
[0.8, 1.2, 3.9], [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3],
[4.7, 3.5, -12.2], [3.8, 7.6, -2], [-10.6, -12.9, -20], [-16, -18, 9],
[-21.3, -52, -5], [0, 0, 0], [6, 7, 8], [-2, -3, -4]],
dtype=torch.float32).cuda() # points (n, 3) in lidar coordinate
pts_feature = pts.clone()
pooled_features_max = roiaware_pool3d_max(
rois=rois, pts=pts, pts_feature=pts_feature)
assert pooled_features_max.shape == torch.Size([2, 4, 4, 4, 3])
assert torch.allclose(pooled_features_max.sum(),
torch.tensor(51.100).cuda(), 1e-3)
pooled_features_avg = roiaware_pool3d_avg(
rois=rois, pts=pts, pts_feature=pts_feature)
assert pooled_features_avg.shape == torch.Size([2, 4, 4, 4, 3])
assert torch.allclose(pooled_features_avg.sum(),
torch.tensor(49.750).cuda(), 1e-3)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_points_in_boxes_part():
boxes = torch.tensor(
[[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.3]],
[[-10.0, 23.0, 16.0, 10, 20, 20, 0.5]]],
dtype=torch.float32).cuda(
) # boxes (b, t, 7) with bottom center in lidar coordinate
pts = torch.tensor(
[[[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], [1.6, 2.6, 3.6],
[0.8, 1.2, 3.9], [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3],
[4.7, 3.5, -12.2]],
[[3.8, 7.6, -2], [-10.6, -12.9, -20], [-16, -18, 9], [-21.3, -52, -5],
[0, 0, 0], [6, 7, 8], [-2, -3, -4], [6, 4, 9]]],
dtype=torch.float32).cuda() # points (b, m, 3) in lidar coordinate
point_indices = points_in_boxes_part(points=pts, boxes=boxes)
expected_point_indices = torch.tensor(
[[0, 0, 0, 0, 0, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1]],
dtype=torch.int32).cuda()
assert point_indices.shape == torch.Size([2, 8])
assert (point_indices == expected_point_indices).all()
boxes = torch.tensor([[[0.0, 0.0, 0.0, 1.0, 20.0, 1.0, 0.523598]]],
dtype=torch.float32).cuda() # 30 degrees
pts = torch.tensor(
[[[4, 6.928, 0], [6.928, 4, 0], [4, -6.928, 0], [6.928, -4, 0],
[-4, 6.928, 0], [-6.928, 4, 0], [-4, -6.928, 0], [-6.928, -4, 0]]],
dtype=torch.float32).cuda()
point_indices = points_in_boxes_part(points=pts, boxes=boxes)
expected_point_indices = torch.tensor([[-1, -1, 0, -1, 0, -1, -1, -1]],
dtype=torch.int32).cuda()
assert (point_indices == expected_point_indices).all()
def test_points_in_boxes_cpu():
boxes = torch.tensor(
[[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.3],
[-10.0, 23.0, 16.0, 10, 20, 20, 0.5]]],
dtype=torch.float32
) # boxes (m, 7) with bottom center in lidar coordinate
pts = torch.tensor(
[[[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], [1.6, 2.6, 3.6],
[0.8, 1.2, 3.9], [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3],
[4.7, 3.5, -12.2], [3.8, 7.6, -2], [-10.6, -12.9, -20], [
-16, -18, 9
], [-21.3, -52, -5], [0, 0, 0], [6, 7, 8], [-2, -3, -4]]],
dtype=torch.float32) # points (n, 3) in lidar coordinate
point_indices = points_in_boxes_cpu(points=pts, boxes=boxes)
expected_point_indices = torch.tensor(
[[[1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [0, 1], [0, 0], [0, 0],
[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]],
dtype=torch.int32)
assert point_indices.shape == torch.Size([1, 15, 2])
assert (point_indices == expected_point_indices).all()
boxes = torch.tensor([[[0.0, 0.0, 0.0, 1.0, 20.0, 1.0, 0.523598]]],
dtype=torch.float32) # 30 degrees
pts = torch.tensor(
[[[4, 6.928, 0], [6.928, 4, 0], [4, -6.928, 0], [6.928, -4, 0],
[-4, 6.928, 0], [-6.928, 4, 0], [-4, -6.928, 0], [-6.928, -4, 0]]],
dtype=torch.float32)
point_indices = points_in_boxes_cpu(points=pts, boxes=boxes)
expected_point_indices = torch.tensor(
[[[0], [0], [1], [0], [1], [0], [0], [0]]], dtype=torch.int32)
assert (point_indices == expected_point_indices).all()
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_points_in_boxes_all():
boxes = torch.tensor(
[[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.3],
[-10.0, 23.0, 16.0, 10, 20, 20, 0.5]]],
dtype=torch.float32).cuda(
) # boxes (m, 7) with bottom center in lidar coordinate
pts = torch.tensor(
[[[1, 2, 3.3], [1.2, 2.5, 3.0], [0.8, 2.1, 3.5], [1.6, 2.6, 3.6],
[0.8, 1.2, 3.9], [-9.2, 21.0, 18.2], [3.8, 7.9, 6.3],
[4.7, 3.5, -12.2], [3.8, 7.6, -2], [-10.6, -12.9, -20], [
-16, -18, 9
], [-21.3, -52, -5], [0, 0, 0], [6, 7, 8], [-2, -3, -4]]],
dtype=torch.float32).cuda() # points (n, 3) in lidar coordinate
point_indices = points_in_boxes_all(points=pts, boxes=boxes)
expected_point_indices = torch.tensor(
[[[1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [0, 1], [0, 0], [0, 0],
[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]],
dtype=torch.int32).cuda()
assert point_indices.shape == torch.Size([1, 15, 2])
assert (point_indices == expected_point_indices).all()