Unverified Commit e3e1dba2 authored by dingchang, committed by GitHub

[Feature] Add voxel ops from mmdet3d (#1381)

* add ops (voxel) in mmdet3d

* add ops (voxel) in mmdet3d

* add ops (voxel) in mmdet3d

* refactor code

* update test

* update test

* update test

* refactor code

* refactor code

* refactor code

* fix typo

* fix typo
parent e8489a7b
@@ -11,6 +11,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- CornerPool
- Deformable Convolution v1/v2
- Deformable RoIPool
- DynamicScatter
- GatherPoints
- FurthestPointSample
- FurthestPointSampleWithDist
@@ -27,6 +28,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- SoftmaxFocalLoss
- SoftNMS
- Synchronized BatchNorm
- Voxelization
- ThreeInterpolate
- ThreeNN
- Weight standardization
......
@@ -11,6 +11,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- CornerPool
- Deformable Convolution v1/v2
- Deformable RoIPool
- DynamicScatter
- GatherPoints
- FurthestPointSample
- FurthestPointSampleWithDist
@@ -27,6 +28,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- SoftmaxFocalLoss
- SoftNMS
- Synchronized BatchNorm
- Voxelization
- ThreeInterpolate
- ThreeNN
- Weight standardization
......
@@ -41,11 +41,13 @@ from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
from .roi_pool import RoIPool, roi_pool
from .roipoint_pool3d import RoIPointPool3d
from .saconv import SAConv2d
from .scatter_points import DynamicScatter, dynamic_scatter
from .sync_bn import SyncBatchNorm
from .three_interpolate import three_interpolate
from .three_nn import three_nn
from .tin_shift import TINShift, tin_shift
from .upfirdn2d import upfirdn2d
from .voxelize import Voxelization, voxelization
__all__ = [
'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
@@ -65,6 +67,7 @@ __all__ = [
'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand',
'three_nn', 'three_interpolate', 'MultiScaleDeformableAttention',
'Voxelization', 'voxelization', 'dynamic_scatter', 'DynamicScatter',
'BorderAlign', 'border_align', 'gather_points', 'furthest_point_sample',
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation'
]
// Copyright (c) OpenMMLab. All rights reserved
#ifndef SCATTER_POINTS_CUDA_KERNEL_CUH
#define SCATTER_POINTS_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t;
int const maxGridDim = 50000;
__device__ __forceinline__ static void reduceMax(float *address, float val) {
int *address_as_i = reinterpret_cast<int *>(address);
int old = *address_as_i, assumed;
do {
assumed = old;
old = atomicCAS(address_as_i, assumed,
__float_as_int(fmaxf(val, __int_as_float(assumed))));
} while (assumed != old || __int_as_float(old) < val);
}
__device__ __forceinline__ static void reduceMax(double *address, double val) {
unsigned long long *address_as_ull =
reinterpret_cast<unsigned long long *>(address);
unsigned long long old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(
address_as_ull, assumed,
__double_as_longlong(fmax(val, __longlong_as_double(assumed))));
} while (assumed != old || __longlong_as_double(old) < val);
}
// get rid of meaningless warnings when compiling host code
#ifdef __CUDA_ARCH__
__device__ __forceinline__ static void reduceAdd(float *address, float val) {
#if (__CUDA_ARCH__ < 200)
#warning \
"compute capability lower than 2.x. fall back to use CAS version of atomicAdd for float32"
int *address_as_i = reinterpret_cast<int *>(address);
int old = *address_as_i, assumed;
do {
assumed = old;
old = atomicCAS(address_as_i, assumed,
__float_as_int(val + __int_as_float(assumed)));
} while (assumed != old);
#else
atomicAdd(address, val);
#endif
}
__device__ __forceinline__ static void reduceAdd(double *address, double val) {
#if (__CUDA_ARCH__ < 600)
#warning \
"compute capability lower than 6.x. fall back to use CAS version of atomicAdd for float64"
unsigned long long *address_as_ull =
reinterpret_cast<unsigned long long *>(address);
unsigned long long old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(val + __longlong_as_double(assumed)));
} while (assumed != old);
#else
atomicAdd(address, val);
#endif
}
#endif
template <typename T>
__global__ void feats_reduce_kernel(
const T *feats, const int32_t *coors_map,
T *reduced_feats, // shall be 0 at initialization
const int num_input, const int num_feats, const reduce_t reduce_type) {
CUDA_1D_KERNEL_LOOP(x, num_input) {
int32_t reduce_to = coors_map[x];
if (reduce_to == -1) continue;
const T *feats_offset = feats + x * num_feats;
T *reduced_feats_offset = reduced_feats + reduce_to * num_feats;
if (reduce_type == reduce_t::MAX) {
for (int i = 0; i < num_feats; i++) {
reduceMax(&reduced_feats_offset[i], feats_offset[i]);
}
} else {
for (int i = 0; i < num_feats; i++) {
reduceAdd(&reduced_feats_offset[i], feats_offset[i]);
}
}
}
}
template <typename T>
__global__ void add_reduce_traceback_grad_kernel(
T *grad_feats, const T *grad_reduced_feats, const int32_t *coors_map,
const int32_t *reduce_count, const int num_input, const int num_feats,
const reduce_t reduce_type) {
CUDA_1D_KERNEL_LOOP(x, num_input) {
int32_t reduce_to = coors_map[x];
if (reduce_to == -1) {
continue;
}
const int input_offset = x * num_feats;
T *grad_feats_offset = grad_feats + input_offset;
const int reduced_offset = reduce_to * num_feats;
const T *grad_reduced_feats_offset = grad_reduced_feats + reduced_offset;
if (reduce_type == reduce_t::SUM) {
for (int i = 0; i < num_feats; i++) {
grad_feats_offset[i] = grad_reduced_feats_offset[i];
}
} else if (reduce_type == reduce_t::MEAN) {
for (int i = 0; i < num_feats; i++) {
grad_feats_offset[i] = grad_reduced_feats_offset[i] /
static_cast<T>(reduce_count[reduce_to]);
}
}
}
}
template <typename T>
__global__ void max_reduce_traceback_scatter_idx_kernel(
const T *feats, const T *reduced_feats, int32_t *reduce_from,
const int32_t *coors_map, const int num_input, const int num_feats) {
CUDA_1D_KERNEL_LOOP(x, num_input) {
int32_t reduce_to = coors_map[x];
const int input_offset = x * num_feats;
const T *feats_offset = feats + input_offset;
if (reduce_to == -1) {
continue;
}
const int reduced_offset = reduce_to * num_feats;
const T *reduced_feats_offset = reduced_feats + reduced_offset;
int32_t *reduce_from_offset = reduce_from + reduced_offset;
for (int i = 0; i < num_feats; i++) {
if (feats_offset[i] == reduced_feats_offset[i]) {
atomicMin(&reduce_from_offset[i], static_cast<int32_t>(x));
}
}
}
}
template <typename T>
__global__ void max_reduce_scatter_grad_kernel(T *grad_feats,
const T *grad_reduced_feats,
const int32_t *reduce_from,
const int num_reduced,
const int num_feats) {
CUDA_1D_KERNEL_LOOP(x, num_reduced) {
const int reduced_offset = x * num_feats;
const int32_t *scatter_to_offset = reduce_from + reduced_offset;
const T *grad_reduced_feats_offset = grad_reduced_feats + reduced_offset;
for (int i = 0; i < num_feats; i++) {
grad_feats[scatter_to_offset[i] * num_feats + i] =
grad_reduced_feats_offset[i];
}
}
}
#endif // SCATTER_POINTS_CUDA_KERNEL_CUH
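For readers unfamiliar with the kernels above, the following is a minimal host-side sketch (illustrative only, not part of this commit; the function name and looping strategy are assumptions) of the reduction that `feats_reduce_kernel` performs: every point whose `coors_map` entry is non-negative is folded into its voxel row by max or sum, and the launcher later divides by the per-voxel count to obtain the mean. It assumes every voxel id has at least one contributing point.

```python
# Host-side reference of the per-voxel reduction done by feats_reduce_kernel.
import torch


def reduce_feats_reference(feats, coors_map, num_voxels, reduce_type='max'):
    # feats: (N, C) point features; coors_map: (N,) voxel id per point, -1 = dropped.
    rows = []
    for v in range(num_voxels):
        group = feats[coors_map == v]          # all points falling into voxel v
        if reduce_type == 'max':
            rows.append(group.max(dim=0).values)
        elif reduce_type == 'mean':
            rows.append(group.mean(dim=0))
        else:  # 'sum'
            rows.append(group.sum(dim=0))
    return torch.stack(rows)
```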
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef VOXELIZATION_CUDA_KERNEL_CUH
#define VOXELIZATION_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t;
template <typename T, typename T_int>
__global__ void dynamic_voxelize_kernel(
const T* points, T_int* coors, const float voxel_x, const float voxel_y,
const float voxel_z, const float coors_x_min, const float coors_y_min,
const float coors_z_min, const float coors_x_max, const float coors_y_max,
const float coors_z_max, const int grid_x, const int grid_y,
const int grid_z, const int num_points, const int num_features,
const int NDim) {
// const int index = blockIdx.x * threadsPerBlock + threadIdx.x;
CUDA_1D_KERNEL_LOOP(index, num_points) {
// To save some computation
auto points_offset = points + index * num_features;
auto coors_offset = coors + index * NDim;
int c_x = floor((points_offset[0] - coors_x_min) / voxel_x);
if (c_x < 0 || c_x >= grid_x) {
coors_offset[0] = -1;
continue;
}
int c_y = floor((points_offset[1] - coors_y_min) / voxel_y);
if (c_y < 0 || c_y >= grid_y) {
coors_offset[0] = -1;
coors_offset[1] = -1;
continue;
}
int c_z = floor((points_offset[2] - coors_z_min) / voxel_z);
if (c_z < 0 || c_z >= grid_z) {
coors_offset[0] = -1;
coors_offset[1] = -1;
coors_offset[2] = -1;
} else {
coors_offset[0] = c_z;
coors_offset[1] = c_y;
coors_offset[2] = c_x;
}
}
}
template <typename T, typename T_int>
__global__ void assign_point_to_voxel(const int nthreads, const T* points,
T_int* point_to_voxelidx,
T_int* coor_to_voxelidx, T* voxels,
const int max_points,
const int num_features,
const int num_points, const int NDim) {
CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) {
// const int index = blockIdx.x * threadsPerBlock + threadIdx.x;
int index = thread_idx / num_features;
int num = point_to_voxelidx[index];
int voxelidx = coor_to_voxelidx[index];
if (num > -1 && voxelidx > -1) {
auto voxels_offset =
voxels + voxelidx * max_points * num_features + num * num_features;
int k = thread_idx % num_features;
voxels_offset[k] = points[thread_idx];
}
}
}
template <typename T, typename T_int>
__global__ void assign_voxel_coors(const int nthreads, T_int* coor,
T_int* point_to_voxelidx,
T_int* coor_to_voxelidx, T_int* voxel_coors,
const int num_points, const int NDim) {
CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) {
// const int index = blockIdx.x * threadsPerBlock + threadIdx.x;
// if (index >= num_points) return;
int index = thread_idx / NDim;
int num = point_to_voxelidx[index];
int voxelidx = coor_to_voxelidx[index];
if (num == 0 && voxelidx > -1) {
auto coors_offset = voxel_coors + voxelidx * NDim;
int k = thread_idx % NDim;
coors_offset[k] = coor[thread_idx];
}
}
}
template <typename T_int>
__global__ void point_to_voxelidx_kernel(const T_int* coor,
T_int* point_to_voxelidx,
T_int* point_to_pointidx,
const int max_points,
const int max_voxels,
const int num_points, const int NDim) {
CUDA_1D_KERNEL_LOOP(index, num_points) {
auto coor_offset = coor + index * NDim;
// skip invalid points
if ((index >= num_points) || (coor_offset[0] == -1)) return;
int num = 0;
int coor_x = coor_offset[0];
int coor_y = coor_offset[1];
int coor_z = coor_offset[2];
// only calculate the coors before this coor[index]
for (int i = 0; i < index; ++i) {
auto prev_coor = coor + i * NDim;
if (prev_coor[0] == -1) continue;
// Find all previous points that have the same coors;
// if the same coor is found, record it
if ((prev_coor[0] == coor_x) && (prev_coor[1] == coor_y) &&
(prev_coor[2] == coor_z)) {
num++;
if (num == 1) {
// point to the first point that has the same coor
point_to_pointidx[index] = i;
} else if (num >= max_points) {
// out of boundary
return;
}
}
}
if (num == 0) {
point_to_pointidx[index] = index;
}
if (num < max_points) {
point_to_voxelidx[index] = num;
}
}
}
template <typename T_int>
__global__ void determin_voxel_num(
// const T_int* coor,
T_int* num_points_per_voxel, T_int* point_to_voxelidx,
T_int* point_to_pointidx, T_int* coor_to_voxelidx, T_int* voxel_num,
const int max_points, const int max_voxels, const int num_points) {
// only calculate the coors before this coor[index]
for (int i = 0; i < num_points; ++i) {
int point_pos_in_voxel = point_to_voxelidx[i];
// record voxel
if (point_pos_in_voxel == -1) {
// out of max_points or invalid point
continue;
} else if (point_pos_in_voxel == 0) {
// record new voxel
int voxelidx = voxel_num[0];
if (voxel_num[0] >= max_voxels) continue;
voxel_num[0] += 1;
coor_to_voxelidx[i] = voxelidx;
num_points_per_voxel[voxelidx] = 1;
} else {
int point_idx = point_to_pointidx[i];
int voxelidx = coor_to_voxelidx[point_idx];
if (voxelidx != -1) {
coor_to_voxelidx[i] = voxelidx;
num_points_per_voxel[voxelidx] += 1;
}
}
}
}
#endif // VOXELIZATION_CUDA_KERNEL_CUH
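As a plain-Python reference for `dynamic_voxelize_kernel` (a sketch, not part of the commit): each point's voxel index per axis is `floor((p - coors_min) / voxel_size)`, the indices are stored in reversed (z, y, x) order, and any out-of-range axis marks the point as invalid. The kernel writes -1 only into the axes it has already checked; the sketch simply invalidates the whole row.

```python
import math


def dynamic_voxelize_reference(points, voxel_size, coors_range):
    # points: iterable of (x, y, z, ...); voxel_size: [vx, vy, vz];
    # coors_range: [x_min, y_min, z_min, x_max, y_max, z_max].
    grid = [round((coors_range[i + 3] - coors_range[i]) / voxel_size[i])
            for i in range(3)]
    coors = []
    for p in points:
        c = [int(math.floor((p[i] - coors_range[i]) / voxel_size[i]))
             for i in range(3)]
        if any(not 0 <= c[i] < grid[i] for i in range(3)):
            coors.append((-1, -1, -1))        # out of range
        else:
            coors.append((c[2], c[1], c[0]))  # stored as (z, y, x), matching the kernel
    return coors
```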
// Copyright (c) OpenMMLab. All rights reserved.
#include <stdio.h>
#include <stdlib.h>
#include <torch/types.h>
#include "pytorch_cuda_helper.hpp"
#include "scatter_points_cuda_kernel.cuh"
std::vector<at::Tensor> DynamicPointToVoxelForwardCUDAKernelLauncher(
const at::Tensor &feats, const at::Tensor &coors,
const reduce_t reduce_type) {
const int num_input = feats.size(0);
const int num_feats = feats.size(1);
if (num_input == 0)
return {feats.clone().detach(), coors.clone().detach(),
coors.new_empty({0}, torch::kInt32),
coors.new_empty({0}, torch::kInt32)};
at::Tensor out_coors;
at::Tensor coors_map;
at::Tensor reduce_count;
auto coors_clean = coors.masked_fill(coors.lt(0).any(-1, true), -1);
std::tie(out_coors, coors_map, reduce_count) =
at::unique_dim(coors_clean, 0, true, true, true);
// the first element of out_coors is always (-1,-1,-1) and should be removed
out_coors = out_coors.slice(0, 1);
reduce_count = reduce_count.slice(0, 1).to(torch::kInt32);
coors_map = coors_map.to(torch::kInt32) - 1;
auto reduced_feats =
at::empty({out_coors.size(0), num_feats}, feats.options());
at::cuda::CUDAGuard device_guard(feats.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(
feats.scalar_type(), "feats_reduce_kernel", ([&] {
if (reduce_type == reduce_t::MAX)
reduced_feats.fill_(-std::numeric_limits<scalar_t>::infinity());
else
reduced_feats.fill_(static_cast<scalar_t>(0));
dim3 blocks(std::min(
at::cuda::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim));
dim3 threads(THREADS_PER_BLOCK);
feats_reduce_kernel<<<blocks, threads, 0, stream>>>(
feats.data_ptr<scalar_t>(), coors_map.data_ptr<int32_t>(),
reduced_feats.data_ptr<scalar_t>(), num_input, num_feats,
reduce_type);
if (reduce_type == reduce_t::MEAN)
reduced_feats /= reduce_count.unsqueeze(-1).to(reduced_feats.dtype());
}));
AT_CUDA_CHECK(cudaGetLastError());
return {reduced_feats, out_coors, coors_map, reduce_count};
}
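The host-side bookkeeping above relies on `at::unique_dim`: invalid rows are collapsed to (-1, -1, -1), identical coordinate rows are grouped, and the leading all -1 row is sliced off while the inverse map is shifted so that dropped points map to -1. A hedged Python equivalent follows (the guard on the leading row is a small generalisation of the launcher, which always slices it off):

```python
import torch


def group_coors_reference(coors):
    # coors: (N, 3) integer voxel coordinates; rows with any negative entry are invalid.
    coors_clean = coors.masked_fill(coors.lt(0).any(-1, keepdim=True), -1)
    out_coors, coors_map, reduce_count = torch.unique(
        coors_clean, dim=0, sorted=True, return_inverse=True, return_counts=True)
    if out_coors.numel() and out_coors[0].eq(-1).all():
        # the all -1 row sorts first; drop it and shift the inverse map
        out_coors, reduce_count = out_coors[1:], reduce_count[1:]
        coors_map = coors_map - 1  # invalid points now map to -1
    return out_coors, coors_map.int(), reduce_count.int()
```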
void DynamicPointToVoxelBackwardCUDAKernelLauncher(
at::Tensor &grad_feats, const at::Tensor &grad_reduced_feats,
const at::Tensor &feats, const at::Tensor &reduced_feats,
const at::Tensor &coors_map, const at::Tensor &reduce_count,
const reduce_t reduce_type) {
const int num_input = feats.size(0);
const int num_reduced = reduced_feats.size(0);
const int num_feats = feats.size(1);
grad_feats.fill_(0);
// copy voxel grad to points
if (num_input == 0 || num_reduced == 0) return;
at::cuda::CUDAGuard device_guard(feats.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (reduce_type == reduce_t::MEAN || reduce_type == reduce_t::SUM) {
AT_DISPATCH_FLOATING_TYPES(
grad_reduced_feats.scalar_type(), "add_reduce_traceback_grad_kernel",
([&] {
dim3 blocks(std::min(
at::cuda::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim));
dim3 threads(THREADS_PER_BLOCK);
add_reduce_traceback_grad_kernel<<<blocks, threads, 0, stream>>>(
grad_feats.data_ptr<scalar_t>(),
grad_reduced_feats.data_ptr<scalar_t>(),
coors_map.data_ptr<int32_t>(), reduce_count.data_ptr<int32_t>(),
num_input, num_feats, reduce_type);
}));
AT_CUDA_CHECK(cudaGetLastError());
} else {
auto reduce_from = at::full({num_reduced, num_feats}, num_input,
coors_map.options().dtype(torch::kInt32));
AT_DISPATCH_FLOATING_TYPES(
grad_reduced_feats.scalar_type(),
"max_reduce_traceback_scatter_idx_kernel", ([&] {
dim3 blocks(std::min(
at::cuda::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim));
dim3 threads(THREADS_PER_BLOCK);
max_reduce_traceback_scatter_idx_kernel<<<blocks, threads, 0,
stream>>>(
feats.data_ptr<scalar_t>(), reduced_feats.data_ptr<scalar_t>(),
reduce_from.data_ptr<int32_t>(), coors_map.data_ptr<int32_t>(),
num_input, num_feats);
}));
AT_CUDA_CHECK(cudaGetLastError());
AT_DISPATCH_FLOATING_TYPES(
grad_reduced_feats.scalar_type(),
"max_reduce_traceback_scatter_idx_kernel", ([&] {
dim3 blocks(
std::min(at::cuda::ATenCeilDiv(num_reduced, THREADS_PER_BLOCK),
maxGridDim));
dim3 threads(THREADS_PER_BLOCK);
max_reduce_scatter_grad_kernel<<<blocks, threads, 0, stream>>>(
grad_feats.data_ptr<scalar_t>(),
grad_reduced_feats.data_ptr<scalar_t>(),
reduce_from.data_ptr<int32_t>(), num_reduced, num_feats);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
}
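To make the gradient routing concrete, here is a hedged Python reference of the backward pass (illustrative only, not part of this commit): for 'sum' and 'mean' every contributing point receives the (scaled) voxel gradient, while for 'max' each feature channel's gradient goes only to the first point that attains the voxel maximum, mirroring the `atomicMin` tie-break on the point index.

```python
import torch


def scatter_backward_reference(grad_voxel_feats, feats, voxel_feats,
                               coors_map, reduce_count, reduce_type):
    # grad_voxel_feats: (M, C); feats: (N, C); voxel_feats: (M, C);
    # coors_map: (N,) voxel id per point (-1 = dropped); reduce_count: (M,).
    grad_feats = torch.zeros_like(feats)
    num_voxels, num_feats = grad_voxel_feats.shape
    if reduce_type in ('sum', 'mean'):
        for x in range(feats.size(0)):
            v = int(coors_map[x])
            if v < 0:
                continue
            g = grad_voxel_feats[v]
            grad_feats[x] = g if reduce_type == 'sum' else g / reduce_count[v]
    else:  # 'max'
        claimed = torch.zeros(num_voxels, num_feats,
                              dtype=torch.bool, device=feats.device)
        for x in range(feats.size(0)):  # ascending index == atomicMin winner
            v = int(coors_map[x])
            if v < 0:
                continue
            winner = (feats[x] == voxel_feats[v]) & ~claimed[v]
            grad_feats[x][winner] = grad_voxel_feats[v][winner]
            claimed[v] |= winner
    return grad_feats
```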
// Copyright (c) OpenMMLab. All rights reserved.
#include <stdio.h>
#include <stdlib.h>
#include "pytorch_cuda_helper.hpp"
#include "voxelization_cuda_kernel.cuh"
int HardVoxelizeForwardCUDAKernelLauncher(
const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors,
at::Tensor& num_points_per_voxel, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim = 3) {
// the current version takes about 0.04s for one frame on cpu
// check device
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int num_points = points.size(0);
const int num_features = points.size(1);
const float voxel_x = voxel_size[0];
const float voxel_y = voxel_size[1];
const float voxel_z = voxel_size[2];
const float coors_x_min = coors_range[0];
const float coors_y_min = coors_range[1];
const float coors_z_min = coors_range[2];
const float coors_x_max = coors_range[3];
const float coors_y_max = coors_range[4];
const float coors_z_max = coors_range[5];
const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);
// map points to voxel coors
at::Tensor temp_coors =
at::zeros({num_points, NDim}, points.options().dtype(at::kInt));
dim3 grid(std::min(at::cuda::ATenCeilDiv(num_points, 512), 4096));
dim3 block(512);
// 1. link point to corresponding voxel coors
AT_DISPATCH_ALL_TYPES(
points.scalar_type(), "hard_voxelize_kernel", ([&] {
dynamic_voxelize_kernel<scalar_t, int><<<grid, block, 0, stream>>>(
points.contiguous().data_ptr<scalar_t>(),
temp_coors.contiguous().data_ptr<int>(), voxel_x, voxel_y, voxel_z,
coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max,
coors_z_max, grid_x, grid_y, grid_z, num_points, num_features,
NDim);
}));
AT_CUDA_CHECK(cudaGetLastError());
// 2. map point to the idx of the corresponding voxel, find duplicate coor
// create some temporary variables
auto point_to_pointidx = -at::ones(
{
num_points,
},
points.options().dtype(at::kInt));
auto point_to_voxelidx = -at::ones(
{
num_points,
},
points.options().dtype(at::kInt));
dim3 map_grid(std::min(at::cuda::ATenCeilDiv(num_points, 512), 4096));
dim3 map_block(512);
AT_DISPATCH_ALL_TYPES(
temp_coors.scalar_type(), "determin_duplicate", ([&] {
point_to_voxelidx_kernel<int><<<map_grid, map_block, 0, stream>>>(
temp_coors.contiguous().data_ptr<int>(),
point_to_voxelidx.contiguous().data_ptr<int>(),
point_to_pointidx.contiguous().data_ptr<int>(), max_points,
max_voxels, num_points, NDim);
}));
AT_CUDA_CHECK(cudaGetLastError());
// 3. determine voxel num and voxel's coor index
// running this sequential step on the CUDA device (instead of the CPU) is roughly 10x faster
auto coor_to_voxelidx = -at::ones(
{
num_points,
},
points.options().dtype(at::kInt));
auto voxel_num = at::zeros(
{
1,
},
points.options().dtype(at::kInt)); // must be zero from the beginning
AT_DISPATCH_ALL_TYPES(temp_coors.scalar_type(), "determin_duplicate", ([&] {
determin_voxel_num<int><<<1, 1, 0, stream>>>(
num_points_per_voxel.contiguous().data_ptr<int>(),
point_to_voxelidx.contiguous().data_ptr<int>(),
point_to_pointidx.contiguous().data_ptr<int>(),
coor_to_voxelidx.contiguous().data_ptr<int>(),
voxel_num.contiguous().data_ptr<int>(),
max_points, max_voxels, num_points);
}));
AT_CUDA_CHECK(cudaGetLastError());
// 4. copy point features to voxels
// Step 4 & 5 could be parallel
auto pts_output_size = num_points * num_features;
dim3 cp_grid(std::min(at::cuda::ATenCeilDiv(pts_output_size, 512), 4096));
dim3 cp_block(512);
AT_DISPATCH_ALL_TYPES(
points.scalar_type(), "assign_point_to_voxel", ([&] {
assign_point_to_voxel<scalar_t, int><<<cp_grid, cp_block, 0, stream>>>(
pts_output_size, points.contiguous().data_ptr<scalar_t>(),
point_to_voxelidx.contiguous().data_ptr<int>(),
coor_to_voxelidx.contiguous().data_ptr<int>(),
voxels.contiguous().data_ptr<scalar_t>(), max_points, num_features,
num_points, NDim);
}));
// cudaDeviceSynchronize();
// AT_CUDA_CHECK(cudaGetLastError());
// 5. copy coors of each voxels
auto coors_output_size = num_points * NDim;
dim3 coors_cp_grid(
std::min(at::cuda::ATenCeilDiv(coors_output_size, 512), 4096));
dim3 coors_cp_block(512);
AT_DISPATCH_ALL_TYPES(
points.scalar_type(), "assign_point_to_voxel", ([&] {
assign_voxel_coors<float, int>
<<<coors_cp_grid, coors_cp_block, 0, stream>>>(
coors_output_size, temp_coors.contiguous().data_ptr<int>(),
point_to_voxelidx.contiguous().data_ptr<int>(),
coor_to_voxelidx.contiguous().data_ptr<int>(),
coors.contiguous().data_ptr<int>(), num_points, NDim);
}));
AT_CUDA_CHECK(cudaGetLastError());
auto voxel_num_cpu = voxel_num.to(at::kCPU);
int voxel_num_int = voxel_num_cpu.data_ptr<int>()[0];
return voxel_num_int;
}
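The five CUDA steps above amount to the following sequential logic (a hedged Python sketch, not part of the commit, and without the zero-padding to `max_points` that the real op applies to its output tensors):

```python
def hard_voxelize_reference(points, coors, max_points, max_voxels):
    # points: (N, C) rows; coors: per-point (z, y, x) voxel coordinates from
    # dynamic voxelization, with (-1, -1, -1) marking out-of-range points.
    voxels, voxel_coors, num_points_per_voxel = [], [], []
    coor_to_voxelidx = {}
    for p, c in zip(points, coors):
        c = tuple(int(v) for v in c)
        if c[0] == -1:
            continue                                # invalid point
        if c not in coor_to_voxelidx:
            if len(voxel_coors) >= max_voxels:
                continue                            # voxel budget exhausted, drop point
            coor_to_voxelidx[c] = len(voxel_coors)  # open a new voxel
            voxel_coors.append(c)
            voxels.append([])
            num_points_per_voxel.append(0)
        v = coor_to_voxelidx[c]
        if num_points_per_voxel[v] < max_points:    # keep at most max_points per voxel
            voxels[v].append(list(p))
            num_points_per_voxel[v] += 1
    return voxels, voxel_coors, num_points_per_voxel
```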
void DynamicVoxelizeForwardCUDAKernelLauncher(
const at::Tensor& points, at::Tensor& coors,
const std::vector<float> voxel_size, const std::vector<float> coors_range,
const int NDim = 3) {
// the current version takes about 0.04s for one frame on cpu
// check device
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int num_points = points.size(0);
const int num_features = points.size(1);
const float voxel_x = voxel_size[0];
const float voxel_y = voxel_size[1];
const float voxel_z = voxel_size[2];
const float coors_x_min = coors_range[0];
const float coors_y_min = coors_range[1];
const float coors_z_min = coors_range[2];
const float coors_x_max = coors_range[3];
const float coors_y_max = coors_range[4];
const float coors_z_max = coors_range[5];
const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);
const int col_blocks = at::cuda::ATenCeilDiv(num_points, THREADS_PER_BLOCK);
dim3 blocks(col_blocks);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_ALL_TYPES(points.scalar_type(), "dynamic_voxelize_kernel", [&] {
dynamic_voxelize_kernel<scalar_t, int><<<blocks, threads, 0, stream>>>(
points.contiguous().data_ptr<scalar_t>(),
coors.contiguous().data_ptr<int>(), voxel_x, voxel_y, voxel_z,
coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max,
coors_z_max, grid_x, grid_y, grid_z, num_points, num_features, NDim);
});
AT_CUDA_CHECK(cudaGetLastError());
}
@@ -264,6 +264,30 @@ void roi_align_rotated_backward(Tensor grad_output, Tensor rois,
int pooled_width, float spatial_scale,
int sample_num, bool aligned, bool clockwise);
std::vector<torch::Tensor> dynamic_point_to_voxel_forward(
const torch::Tensor &feats, const torch::Tensor &coors,
const std::string &reduce_type);
void dynamic_point_to_voxel_backward(torch::Tensor &grad_feats,
const torch::Tensor &grad_reduced_feats,
const torch::Tensor &feats,
const torch::Tensor &reduced_feats,
const torch::Tensor &coors_idx,
const torch::Tensor &reduce_count,
const std::string &reduce_type);
int hard_voxelize_forward(const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors, at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim);
void dynamic_voxelize_forward(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int NDim);
void border_align_forward(const Tensor &input, const Tensor &boxes,
Tensor output, Tensor argmax_idx,
const int pool_size);
@@ -540,6 +564,22 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("grad_input"), py::arg("pooled_height"),
py::arg("pooled_width"), py::arg("spatial_scale"),
py::arg("sample_num"), py::arg("aligned"), py::arg("clockwise"));
m.def("dynamic_point_to_voxel_forward", &dynamic_point_to_voxel_forward,
"dynamic_point_to_voxel_forward", py::arg("feats"), py::arg("coors"),
py::arg("reduce_type"));
m.def("dynamic_point_to_voxel_backward", &dynamic_point_to_voxel_backward,
"dynamic_point_to_voxel_backward", py::arg("grad_feats"),
py::arg("grad_reduced_feats"), py::arg("feats"),
py::arg("reduced_feats"), py::arg("coors_idx"), py::arg("reduce_count"),
py::arg("reduce_type"));
m.def("hard_voxelize_forward", &hard_voxelize_forward,
"hard_voxelize_forward", py::arg("points"), py::arg("voxels"),
py::arg("coors"), py::arg("num_points_per_voxel"),
py::arg("voxel_size"), py::arg("coors_range"), py::arg("max_points"),
py::arg("max_voxels"), py::arg("NDim"));
m.def("dynamic_voxelize_forward", &dynamic_voxelize_forward,
"dynamic_voxelize_forward", py::arg("points"), py::arg("coors"),
py::arg("voxel_size"), py::arg("coors_range"), py::arg("NDim"));
m.def("ms_deform_attn_forward", &ms_deform_attn_forward,
"forward function of multi-scale deformable attention",
py::arg("value"), py::arg("value_spatial_shapes"),
......
// Copyright (c) OpenMMLab. All rights reserved.
#include "pytorch_cpp_helper.hpp"
typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t;
#ifdef MMCV_WITH_CUDA
std::vector<torch::Tensor> DynamicPointToVoxelForwardCUDAKernelLauncher(
const torch::Tensor &feats, const torch::Tensor &coors,
const reduce_t reduce_type);
std::vector<torch::Tensor> dynamic_point_to_voxel_forward_cuda(
const torch::Tensor &feats, const torch::Tensor &coors,
const reduce_t reduce_type) {
return DynamicPointToVoxelForwardCUDAKernelLauncher(feats, coors,
reduce_type);
};
void DynamicPointToVoxelBackwardCUDAKernelLauncher(
torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats,
const torch::Tensor &feats, const torch::Tensor &reduced_feats,
const torch::Tensor &coors_idx, const torch::Tensor &reduce_count,
const reduce_t reduce_type);
void dynamic_point_to_voxel_backward_cuda(
torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats,
const torch::Tensor &feats, const torch::Tensor &reduced_feats,
const torch::Tensor &coors_idx, const torch::Tensor &reduce_count,
const reduce_t reduce_type) {
DynamicPointToVoxelBackwardCUDAKernelLauncher(grad_feats, grad_reduced_feats,
feats, reduced_feats, coors_idx,
reduce_count, reduce_type);
};
#endif
std::vector<at::Tensor> dynamic_point_to_voxel_forward_cpu(
const at::Tensor &points, const at::Tensor &voxel_mapping,
const std::vector<float> voxel_size, const std::vector<float> coors_range);
inline reduce_t convert_reduce_type(const std::string &reduce_type) {
if (reduce_type == "max")
return reduce_t::MAX;
else if (reduce_type == "sum")
return reduce_t::SUM;
else if (reduce_type == "mean")
return reduce_t::MEAN;
else
TORCH_CHECK(false, "do not support reduce type " + reduce_type)
return reduce_t::SUM;
}
std::vector<torch::Tensor> dynamic_point_to_voxel_forward(
const torch::Tensor &feats, const torch::Tensor &coors,
const std::string &reduce_type) {
if (feats.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(feats);
CHECK_CUDA_INPUT(coors);
return dynamic_point_to_voxel_forward_cuda(
feats, coors, convert_reduce_type(reduce_type));
#else
AT_ERROR("dynamic_point_to_voxel is not compiled with GPU support");
#endif
} else {
AT_ERROR("dynamic_point_to_voxel is not implemented on CPU");
return std::vector<torch::Tensor>();
}
}
void dynamic_point_to_voxel_backward(torch::Tensor &grad_feats,
const torch::Tensor &grad_reduced_feats,
const torch::Tensor &feats,
const torch::Tensor &reduced_feats,
const torch::Tensor &coors_idx,
const torch::Tensor &reduce_count,
const std::string &reduce_type) {
if (grad_feats.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(grad_feats);
CHECK_CUDA_INPUT(grad_reduced_feats);
CHECK_CUDA_INPUT(feats);
CHECK_CUDA_INPUT(reduced_feats);
CHECK_CUDA_INPUT(coors_idx);
CHECK_CUDA_INPUT(reduce_count);
dynamic_point_to_voxel_backward_cuda(grad_feats, grad_reduced_feats, feats,
reduced_feats, coors_idx, reduce_count,
convert_reduce_type(reduce_type));
#else
AT_ERROR("dynamic_point_to_voxel is not compiled with GPU support");
#endif
} else {
AT_ERROR("dynamic_point_to_voxel is not implemented on CPU");
}
}
// Copyright (c) OpenMMLab. All rights reserved.
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
int HardVoxelizeForwardCUDAKernelLauncher(
const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim = 3);
int hard_voxelize_forward_cuda(const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors,
at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim = 3) {
return HardVoxelizeForwardCUDAKernelLauncher(
points, voxels, coors, num_points_per_voxel, voxel_size, coors_range,
max_points, max_voxels, NDim);
};
void DynamicVoxelizeForwardCUDAKernelLauncher(
const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size, const std::vector<float> coors_range,
const int NDim = 3);
void dynamic_voxelize_forward_cuda(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int NDim = 3) {
DynamicVoxelizeForwardCUDAKernelLauncher(points, coors, voxel_size,
coors_range, NDim);
};
#endif
int hard_voxelize_forward_cpu(const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors,
at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim = 3);
void dynamic_voxelize_forward_cpu(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int NDim = 3);
int hard_voxelize_forward(const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors, at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim = 3) {
if (points.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(points);
return hard_voxelize_forward_cuda(
points, voxels, coors, num_points_per_voxel, voxel_size, coors_range,
max_points, max_voxels, NDim);
#else
AT_ERROR("hard_voxelize is not compiled with GPU support");
#endif
} else {
return hard_voxelize_forward_cpu(points, voxels, coors,
num_points_per_voxel, voxel_size,
coors_range, max_points, max_voxels, NDim);
}
}
void dynamic_voxelize_forward(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int NDim = 3) {
if (points.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(points);
dynamic_voxelize_forward_cuda(points, coors, voxel_size, coors_range, NDim);
#else
AT_ERROR("dynamic_voxelize is not compiled with GPU support");
#endif
} else {
dynamic_voxelize_forward_cpu(points, coors, voxel_size, coors_range, NDim);
}
}
// Copyright (c) OpenMMLab. All rights reserved.
#include "pytorch_cpp_helper.hpp"
template <typename T, typename T_int>
void dynamic_voxelize_forward_cpu_kernel(
const torch::TensorAccessor<T, 2> points,
torch::TensorAccessor<T_int, 2> coors, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const std::vector<int> grid_size,
const int num_points, const int num_features, const int NDim) {
const int ndim_minus_1 = NDim - 1;
bool failed = false;
// int coor[NDim];
int* coor = new int[NDim]();
int c;
for (int i = 0; i < num_points; ++i) {
failed = false;
for (int j = 0; j < NDim; ++j) {
c = floor((points[i][j] - coors_range[j]) / voxel_size[j]);
// necessary to rm points out of range
if ((c < 0 || c >= grid_size[j])) {
failed = true;
break;
}
coor[ndim_minus_1 - j] = c;
}
if (failed)
memset(&coors[i][0], -1, NDim * sizeof(T_int));
else
memcpy(&coors[i][0], &coor[0], NDim * sizeof(T_int));
}
delete[] coor;
}
template <typename T, typename T_int>
void hard_voxelize_forward_cpu_kernel(
const torch::TensorAccessor<T, 2> points,
torch::TensorAccessor<T, 3> voxels, torch::TensorAccessor<T_int, 2> coors,
torch::TensorAccessor<T_int, 1> num_points_per_voxel,
torch::TensorAccessor<T_int, 3> coor_to_voxelidx, int& voxel_num,
const std::vector<float> voxel_size, const std::vector<float> coors_range,
const std::vector<int> grid_size, const int max_points,
const int max_voxels, const int num_points, const int num_features,
const int NDim) {
// declare a temp coors
at::Tensor temp_coors = at::zeros(
{num_points, NDim}, at::TensorOptions().dtype(at::kInt).device(at::kCPU));
// First use dynamic voxelization to get coors,
// then check max points/voxels constraints
dynamic_voxelize_forward_cpu_kernel<T, int>(
points, temp_coors.accessor<int, 2>(), voxel_size, coors_range, grid_size,
num_points, num_features, NDim);
int voxelidx, num;
auto coor = temp_coors.accessor<int, 2>();
for (int i = 0; i < num_points; ++i) {
// T_int* coor = temp_coors.data_ptr<int>() + i * NDim;
if (coor[i][0] == -1) continue;
voxelidx = coor_to_voxelidx[coor[i][0]][coor[i][1]][coor[i][2]];
// record voxel
if (voxelidx == -1) {
voxelidx = voxel_num;
if (max_voxels != -1 && voxel_num >= max_voxels) continue;
voxel_num += 1;
coor_to_voxelidx[coor[i][0]][coor[i][1]][coor[i][2]] = voxelidx;
memcpy(&coors[voxelidx][0], &coor[i][0], NDim * sizeof(T_int));
}
// put points into voxel
num = num_points_per_voxel[voxelidx];
if (max_points == -1 || num < max_points) {
memcpy(&voxels[voxelidx][num][0], &points[i][0],
num_features * sizeof(T));
num_points_per_voxel[voxelidx] += 1;
}
}
return;
}
void dynamic_voxelize_forward_cpu(const at::Tensor& points, at::Tensor& coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int NDim = 3) {
// check device
AT_ASSERTM(points.device().is_cpu(), "points must be a CPU tensor");
std::vector<int> grid_size(NDim);
const int num_points = points.size(0);
const int num_features = points.size(1);
for (int i = 0; i < NDim; ++i) {
grid_size[i] =
round((coors_range[NDim + i] - coors_range[i]) / voxel_size[i]);
}
// coors, num_points_per_voxel, coor_to_voxelidx are int Tensor
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
points.scalar_type(), "dynamic_voxelize_forward_cpu_kernel", [&] {
dynamic_voxelize_forward_cpu_kernel<scalar_t, int>(
points.accessor<scalar_t, 2>(), coors.accessor<int, 2>(),
voxel_size, coors_range, grid_size, num_points, num_features, NDim);
});
}
int hard_voxelize_forward_cpu(const at::Tensor& points, at::Tensor& voxels,
at::Tensor& coors,
at::Tensor& num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim = 3) {
// the current version takes about 0.02s~0.03s for one frame on cpu
// check device
AT_ASSERTM(points.device().is_cpu(), "points must be a CPU tensor");
std::vector<int> grid_size(NDim);
const int num_points = points.size(0);
const int num_features = points.size(1);
for (int i = 0; i < NDim; ++i) {
grid_size[i] =
round((coors_range[NDim + i] - coors_range[i]) / voxel_size[i]);
}
// coors, num_points_per_voxel, coor_to_voxelidx are int Tensor
// printf("cpu coor_to_voxelidx size: [%d, %d, %d]\n", grid_size[2],
// grid_size[1], grid_size[0]);
at::Tensor coor_to_voxelidx =
-at::ones({grid_size[2], grid_size[1], grid_size[0]}, coors.options());
int voxel_num = 0;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
points.scalar_type(), "hard_voxelize_forward_cpu_kernel", [&] {
hard_voxelize_forward_cpu_kernel<scalar_t, int>(
points.accessor<scalar_t, 2>(), voxels.accessor<scalar_t, 3>(),
coors.accessor<int, 2>(), num_points_per_voxel.accessor<int, 1>(),
coor_to_voxelidx.accessor<int, 3>(), voxel_num, voxel_size,
coors_range, grid_size, max_points, max_voxels, num_points,
num_features, NDim);
});
return voxel_num;
}
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn
from torch.autograd import Function
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext',
['dynamic_point_to_voxel_forward', 'dynamic_point_to_voxel_backward'])
class _DynamicScatter(Function):
@staticmethod
def forward(ctx, feats, coors, reduce_type='max'):
"""convert kitti points(N, >=3) to voxels.
Args:
feats (torch.Tensor): [N, C]. Points features to be reduced
into voxels.
coors (torch.Tensor): [N, ndim]. Corresponding voxel coordinates
(specifically multi-dim voxel index) of each points.
reduce_type (str, optional): Reduce op. support 'max', 'sum' and
'mean'. Default: 'max'.
Returns:
voxel_feats (torch.Tensor): [M, C]. Reduced features, input
features that shares the same voxel coordinates are reduced to
one row.
voxel_coors (torch.Tensor): [M, ndim]. Voxel coordinates.
"""
results = ext_module.dynamic_point_to_voxel_forward(
feats, coors, reduce_type)
(voxel_feats, voxel_coors, point2voxel_map,
voxel_points_count) = results
ctx.reduce_type = reduce_type
ctx.save_for_backward(feats, voxel_feats, point2voxel_map,
voxel_points_count)
ctx.mark_non_differentiable(voxel_coors)
return voxel_feats, voxel_coors
@staticmethod
def backward(ctx, grad_voxel_feats, grad_voxel_coors=None):
(feats, voxel_feats, point2voxel_map,
voxel_points_count) = ctx.saved_tensors
grad_feats = torch.zeros_like(feats)
# TODO: whether to use index put or use cuda_backward
# To use index put, need point to voxel index
ext_module.dynamic_point_to_voxel_backward(
grad_feats, grad_voxel_feats.contiguous(), feats, voxel_feats,
point2voxel_map, voxel_points_count, ctx.reduce_type)
return grad_feats, None, None
dynamic_scatter = _DynamicScatter.apply
class DynamicScatter(nn.Module):
"""Scatters points into voxels, used in the voxel encoder with dynamic
voxelization.
Note:
The CPU and GPU implementation get the same output, but have numerical
difference after summation and division (e.g., 5e-7).
Args:
voxel_size (list): list [x, y, z] size of three dimension.
point_cloud_range (list): The coordinate range of points, [x_min,
y_min, z_min, x_max, y_max, z_max].
average_points (bool): whether to use avg pooling to scatter points
into voxel.
"""
def __init__(self, voxel_size, point_cloud_range, average_points: bool):
super().__init__()
self.voxel_size = voxel_size
self.point_cloud_range = point_cloud_range
self.average_points = average_points
def forward_single(self, points, coors):
"""Scatters points into voxels.
Args:
points (torch.Tensor): Points to be reduced into voxels.
coors (torch.Tensor): Corresponding voxel coordinates (specifically
multi-dim voxel index) of each points.
Returns:
voxel_feats (torch.Tensor): Reduced features, input features that
shares the same voxel coordinates are reduced to one row.
voxel_coors (torch.Tensor): Voxel coordinates.
"""
reduce = 'mean' if self.average_points else 'max'
return dynamic_scatter(points.contiguous(), coors.contiguous(), reduce)
def forward(self, points, coors):
"""Scatters points/features into voxels.
Args:
points (torch.Tensor): Points to be reduced into voxels.
coors (torch.Tensor): Corresponding voxel coordinates (specifically
multi-dim voxel index) of each points.
Returns:
voxel_feats (torch.Tensor): Reduced features, input features that
shares the same voxel coordinates are reduced to one row.
voxel_coors (torch.Tensor): Voxel coordinates.
"""
if coors.size(-1) == 3:
return self.forward_single(points, coors)
else:
batch_size = coors[-1, 0] + 1
voxels, voxel_coors = [], []
for i in range(batch_size):
inds = torch.where(coors[:, 0] == i)
voxel, voxel_coor = self.forward_single(
points[inds], coors[inds][:, 1:])
coor_pad = nn.functional.pad(
voxel_coor, (1, 0), mode='constant', value=i)
voxel_coors.append(coor_pad)
voxels.append(voxel)
features = torch.cat(voxels, dim=0)
feature_coors = torch.cat(voxel_coors, dim=0)
return features, feature_coors
def __repr__(self):
s = self.__class__.__name__ + '('
s += 'voxel_size=' + str(self.voxel_size)
s += ', point_cloud_range=' + str(self.point_cloud_range)
s += ', average_points=' + str(self.average_points)
s += ')'
return s
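A minimal usage sketch of `DynamicScatter` (the op is CUDA-only in this commit; the tensor sizes and voxel settings below are illustrative assumptions mirroring the test further down):

```python
import torch
from mmcv.ops import DynamicScatter

feats = torch.rand(1000, 4, device='cuda')  # per-point features
coors = torch.randint(-1, 20, (1000, 3), device='cuda', dtype=torch.int32)

scatter_mean = DynamicScatter(voxel_size=[0.32, 0.32, 6],
                              point_cloud_range=[-74.88, -74.88, -2,
                                                 74.88, 74.88, 4],
                              average_points=True)
voxel_feats, voxel_coors = scatter_mean(feats, coors)  # (M, 4), (M, 3)
```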
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn
from torch.autograd import Function
from torch.nn.modules.utils import _pair
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext', ['dynamic_voxelize_forward', 'hard_voxelize_forward'])
class _Voxelization(Function):
@staticmethod
def forward(ctx,
points,
voxel_size,
coors_range,
max_points=35,
max_voxels=20000):
"""Convert kitti points(N, >=3) to voxels.
Args:
points (torch.Tensor): [N, ndim]. Points[:, :3] contain xyz points
and points[:, 3:] contain other information like reflectivity.
voxel_size (tuple or float): The size of voxel with the shape of
[3].
coors_range (tuple or float): The coordinate range of voxel with
the shape of [6].
max_points (int, optional): maximum points contained in a voxel. if
max_points=-1, it means using dynamic_voxelize. Default: 35.
max_voxels (int, optional): maximum voxels this function create.
for second, 20000 is a good choice. Users should shuffle points
before call this function because max_voxels may drop points.
Default: 20000.
Returns:
voxels_out (torch.Tensor): Output voxels with the shape of [M,
max_points, ndim]. Only contain points and returned when
max_points != -1.
coors_out (torch.Tensor): Output coordinates with the shape of
[M, 3].
num_points_per_voxel_out (torch.Tensor): Num points per voxel with
the shape of [M]. Only returned when max_points != -1.
"""
if max_points == -1 or max_voxels == -1:
coors = points.new_zeros(size=(points.size(0), 3), dtype=torch.int)
ext_module.dynamic_voxelize_forward(points, coors, voxel_size,
coors_range, 3)
return coors
else:
voxels = points.new_zeros(
size=(max_voxels, max_points, points.size(1)))
coors = points.new_zeros(size=(max_voxels, 3), dtype=torch.int)
num_points_per_voxel = points.new_zeros(
size=(max_voxels, ), dtype=torch.int)
voxel_num = ext_module.hard_voxelize_forward(
points, voxels, coors, num_points_per_voxel, voxel_size,
coors_range, max_points, max_voxels, 3)
# select the valid voxels
voxels_out = voxels[:voxel_num]
coors_out = coors[:voxel_num]
num_points_per_voxel_out = num_points_per_voxel[:voxel_num]
return voxels_out, coors_out, num_points_per_voxel_out
voxelization = _Voxelization.apply
class Voxelization(nn.Module):
"""Convert kitti points(N, >=3) to voxels.
Please refer to `PVCNN <https://arxiv.org/abs/1907.03739>`_ for more
details.
Args:
voxel_size (tuple or float): The size of voxel with the shape of [3].
point_cloud_range (tuple or float): The coordinate range of voxel with
the shape of [6].
max_num_points (int): maximum points contained in a voxel. if
max_points=-1, it means using dynamic_voxelize.
max_voxels (int, optional): maximum voxels this function create.
for second, 20000 is a good choice. Users should shuffle points
before call this function because max_voxels may drop points.
Default: 20000.
"""
def __init__(self,
voxel_size,
point_cloud_range,
max_num_points,
max_voxels=20000):
super().__init__()
self.voxel_size = voxel_size
self.point_cloud_range = point_cloud_range
self.max_num_points = max_num_points
if isinstance(max_voxels, tuple):
self.max_voxels = max_voxels
else:
self.max_voxels = _pair(max_voxels)
point_cloud_range = torch.tensor(
point_cloud_range, dtype=torch.float32)
voxel_size = torch.tensor(voxel_size, dtype=torch.float32)
grid_size = (point_cloud_range[3:] -
point_cloud_range[:3]) / voxel_size
grid_size = torch.round(grid_size).long()
input_feat_shape = grid_size[:2]
self.grid_size = grid_size
# the original shape is [x-len, y-len, z-len]
# [w, h, d] -> [d, h, w]
self.pcd_shape = [*input_feat_shape, 1][::-1]
def forward(self, input):
if self.training:
max_voxels = self.max_voxels[0]
else:
max_voxels = self.max_voxels[1]
return voxelization(input, self.voxel_size, self.point_cloud_range,
self.max_num_points, max_voxels)
def __repr__(self):
s = self.__class__.__name__ + '('
s += 'voxel_size=' + str(self.voxel_size)
s += ', point_cloud_range=' + str(self.point_cloud_range)
s += ', max_num_points=' + str(self.max_num_points)
s += ', max_voxels=' + str(self.max_voxels)
s += ')'
return s
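And a usage sketch of `Voxelization` in both hard and dynamic modes (the point cloud is random and the settings are illustrative, matching the values used in the test below):

```python
import torch
from mmcv.ops import Voxelization

points = torch.rand(1000, 4)  # (N, >=3): xyz + e.g. reflectivity
voxel_size = [0.5, 0.5, 0.5]
point_cloud_range = [0, -40, -3, 70.4, 40, 1]

# hard voxelization: at most 35 points per voxel, at most 20000 voxels
hard_voxelize = Voxelization(voxel_size, point_cloud_range, max_num_points=35)
voxels, coors, num_points_per_voxel = hard_voxelize(points)

# dynamic voxelization: only per-point voxel coordinates are returned
dynamic_voxelize = Voxelization(voxel_size, point_cloud_range, max_num_points=-1)
coors = dynamic_voxelize(points)  # (N, 3) in (z, y, x) order; -1 if out of range
```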
import pytest
import torch
from torch.autograd import gradcheck
from mmcv.ops import DynamicScatter
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_dynamic_scatter():
feats = torch.rand(
size=(200000, 3), dtype=torch.float32, device='cuda') * 100 - 50
coors = torch.randint(
low=-1, high=20, size=(200000, 3), dtype=torch.int32, device='cuda')
dsmean = DynamicScatter([0.32, 0.32, 6],
[-74.88, -74.88, -2, 74.88, 74.88, 4], True)
dsmax = DynamicScatter([0.32, 0.32, 6],
[-74.88, -74.88, -2, 74.88, 74.88, 4], False)
# test empty input
empty_feats = torch.empty(size=(0, 3), dtype=torch.float32, device='cuda')
empty_coors = torch.empty(size=(0, 3), dtype=torch.int32, device='cuda')
empty_feats.requires_grad_()
empty_feats_out_mean, empty_coors_out_mean = dsmean(
empty_feats, empty_coors)
empty_feats_out_mean.sum().backward()
empty_feats_out_max, empty_coors_out_max = dsmax(empty_feats, empty_coors)
empty_feats_out_max.sum().backward()
assert empty_feats_out_mean.shape == empty_feats.shape
assert empty_feats_out_max.shape == empty_feats.shape
assert empty_coors_out_mean.shape == empty_coors.shape
assert empty_coors_out_max.shape == empty_coors.shape
# test empty reduced output
empty_o_feats = torch.rand(
size=(200000, 3), dtype=torch.float32, device='cuda') * 100 - 50
empty_o_coors = torch.randint(
low=-1, high=0, size=(200000, 3), dtype=torch.int32, device='cuda')
empty_o_feats.requires_grad_()
empty_o_feats_out_mean, empty_o_coors_out_mean = dsmean(
empty_o_feats, empty_o_coors)
empty_o_feats_out_mean.sum().backward()
assert (empty_o_feats.grad == 0).all()
empty_o_feats_out_max, empty_o_coors_out_max = dsmax(
empty_o_feats, empty_o_coors)
empty_o_feats_out_max.sum().backward()
assert (empty_o_feats.grad == 0).all()
# test non-empty input
ref_voxel_coors = coors.unique(dim=0, sorted=True)
ref_voxel_coors = ref_voxel_coors[ref_voxel_coors.min(dim=-1).values >= 0]
ref_voxel_feats_mean = []
ref_voxel_feats_max = []
for ref_voxel_coor in ref_voxel_coors:
voxel_mask = (coors == ref_voxel_coor).all(dim=-1)
ref_voxel_feats_mean.append(feats[voxel_mask].mean(dim=0))
ref_voxel_feats_max.append(feats[voxel_mask].max(dim=0).values)
ref_voxel_feats_mean = torch.stack(ref_voxel_feats_mean)
ref_voxel_feats_max = torch.stack(ref_voxel_feats_max)
feats_out_mean, coors_out_mean = dsmean(feats, coors)
seq_mean = (coors_out_mean[:, 0] * 400 + coors_out_mean[:, 1] * 20 +
coors_out_mean[:, 2]).argsort()
feats_out_mean = feats_out_mean[seq_mean]
coors_out_mean = coors_out_mean[seq_mean]
feats_out_max, coors_out_max = dsmax(feats, coors)
seq_max = (coors_out_max[:, 0] * 400 + coors_out_max[:, 1] * 20 +
coors_out_max[:, 2]).argsort()
feats_out_max = feats_out_max[seq_max]
coors_out_max = coors_out_max[seq_max]
assert (coors_out_mean == ref_voxel_coors).all()
assert torch.allclose(
feats_out_mean, ref_voxel_feats_mean, atol=1e-2, rtol=1e-5)
assert (coors_out_max == ref_voxel_coors).all()
assert torch.allclose(
feats_out_max, ref_voxel_feats_max, atol=1e-2, rtol=1e-5)
# test grad #
feats = torch.rand(
size=(100, 4), dtype=torch.float32, device='cuda') * 100 - 50
coors = torch.randint(
low=-1, high=3, size=(100, 3), dtype=torch.int32, device='cuda')
feats.requires_grad_()
gradcheck(dsmean, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5)
gradcheck(dsmax, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5)
import numpy as np
import pytest
import torch
from mmcv.ops import Voxelization
def _get_voxel_points_indices(points, coors, voxel):
result_form = np.equal(coors, voxel)
return result_form[:, 0] & result_form[:, 1] & result_form[:, 2]
@pytest.mark.parametrize('device_type', [
'cpu',
pytest.param(
'cuda:0',
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support'))
])
def test_voxelization(device_type):
voxel_size = [0.5, 0.5, 0.5]
point_cloud_range = [0, -40, -3, 70.4, 40, 1]
voxel_dict = np.load(
'tests/data/for_3d_ops/test_voxel.npy', allow_pickle=True).item()
expected_coors = voxel_dict['coors']
expected_voxels = voxel_dict['voxels']
expected_num_points_per_voxel = voxel_dict['num_points_per_voxel']
points = voxel_dict['points']
points = torch.tensor(points)
max_num_points = -1
dynamic_voxelization = Voxelization(voxel_size, point_cloud_range,
max_num_points)
max_num_points = 1000
hard_voxelization = Voxelization(voxel_size, point_cloud_range,
max_num_points)
device = torch.device(device_type)
# test hard_voxelization on cpu/gpu
points = points.contiguous().to(device)
voxels, coors, num_points_per_voxel = hard_voxelization.forward(points)
coors = coors.cpu().detach().numpy()
voxels = voxels.cpu().detach().numpy()
num_points_per_voxel = num_points_per_voxel.cpu().detach().numpy()
assert np.all(coors == expected_coors)
assert np.all(voxels == expected_voxels)
assert np.all(num_points_per_voxel == expected_num_points_per_voxel)
# test dynamic_voxelization on cpu/gpu
coors = dynamic_voxelization.forward(points)
coors = coors.cpu().detach().numpy()
points = points.cpu().detach().numpy()
for i in range(expected_voxels.shape[0]):
indices = _get_voxel_points_indices(points, coors, expected_coors[i])
num_points_current_voxel = points[indices].shape[0]
assert num_points_current_voxel > 0
assert np.all(
points[indices] == expected_voxels[i][:num_points_current_voxel])
assert num_points_current_voxel == expected_num_points_per_voxel[i]