Unverified Commit 93597a53 authored by zhanggefan, committed by GitHub

A faster & more memory-efficient implementation of DynamicScatter (#318)



* a faster & more memory-efficient implementation of DynamicScatter

* fix format issues and add pytest skip code for tests on machines without cuda support

* some trivial changes:

decrease the number of kernel threads per block to 512, to enable inference on GPUs with compute capability lower than 2.0

change the backpropagation behavior of max-reduction: when multiple points share the same maximum feature value, only the first of them (the point with the lowest row index) propagates the output gradient back. Before this change, every point holding the maximum feature value propagated the output gradient back. This makes max-reduction behave consistently with torch.max (see the short sketch below). The change may cause a gradcheck failure in test_dynamic_scatter.py; do not worry about it, because torch.max fails gradcheck too.

* fix typo
Co-authored-by: zhanggefan <1152009@tongji.edu.cn>
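For reference, a minimal sketch of the torch.max tie-breaking behavior mentioned above (illustrative values only; which tied index wins is backend-dependent, but only one index ever receives gradient):

import torch

x = torch.tensor([[1.0, 5.0],
                  [5.0, 5.0]], requires_grad=True)
out, idx = torch.max(x, dim=0)   # column 1 has two tied maxima
out.sum().backward()
print(idx)     # one winning row index per column, even for the tie
print(x.grad)  # gradient is 1 only at the winning entries, 0 everywhere else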
parent 8214a977
......@@ -9,57 +9,41 @@ from .voxel_layer import (dynamic_point_to_voxel_backward,
class _dynamic_scatter(Function):
@staticmethod
def forward(ctx, points, coors, voxel_size, coors_range):
def forward(ctx, feats, coors, reduce_type='max'):
"""convert kitti points(N, >=3) to voxels.
Args:
points: [N, ndim] float tensor. points[:, :3] contain xyz
points and points[:, 3:] contain other information
such as reflectivity.
voxel_size: [3] list/tuple or array, float. xyz, indicate
voxel size
coors_range: [6] list/tuple or array, float. indicate voxel range.
format: xyzxyz, minmax
max_points: int. indicate maximum points contained in a voxel.
if max_points=-1, it means using dynamic_voxelize
max_voxels: int. indicate maximum voxels this function create.
for second, 20000 is a good choice. you should shuffle
points before call this function because max_voxels may
drop some points.
feats: [N, C] float tensor. point features to be reduced
into voxels.
coors: [N, ndim] int tensor. corresponding voxel coordinates
(specifically multi-dim voxel index) of each point.
reduce_type: str. reduce op. supports 'max', 'sum' and 'mean'.
Returns:
tuple
voxels: [M, max_points, ndim] float tensor. only contain points
and returned when max_points != -1.
coordinates: [M, 3] int32 tensor, always returned.
num_points_per_voxel: [M] int32 tensor. Only returned when
max_points != -1.
voxel_feats: [M, C] float tensor. reduced features. input features
that share the same voxel coordinates are reduced to one row.
coordinates: [M, ndim] int tensor, voxel coordinates.
"""
results = dynamic_point_to_voxel_forward(points, coors, voxel_size,
coors_range)
(voxels, voxel_coors, num_points_per_voxel, point_to_voxelidx,
coor_to_voxelidx) = results
ctx.save_for_backward(num_points_per_voxel, point_to_voxelidx,
coor_to_voxelidx)
return voxels, voxel_coors, num_points_per_voxel.float()
results = dynamic_point_to_voxel_forward(feats, coors, reduce_type)
(voxel_feats, voxel_coors, point2voxel_map,
voxel_points_count) = results
ctx.reduce_type = reduce_type
ctx.save_for_backward(feats, voxel_feats, point2voxel_map,
voxel_points_count)
return voxel_feats, voxel_coors
@staticmethod
def backward(ctx,
grad_output_voxel,
grad_output_voxel_coors=None,
grad_output_num_points=None):
(num_points_per_voxel, point_to_voxelidx,
coor_to_voxelidx) = ctx.saved_tensors
# grad_output_voxel shape: NxMxC
num_points = point_to_voxelidx.size(0)
num_features = grad_output_voxel.size(-1)
grad_points = grad_output_voxel.new_zeros(
size=(num_points, num_features))
def backward(ctx, grad_voxel_feats, grad_voxel_coors=None):
(feats, voxel_feats, point2voxel_map,
voxel_points_count) = ctx.saved_tensors
grad_feats = torch.zeros_like(feats)
# TODO: whether to use index put or use cuda_backward
# To use index put, need point to voxel index
dynamic_point_to_voxel_backward(grad_points,
grad_output_voxel.contiguous(),
point_to_voxelidx, coor_to_voxelidx)
return grad_points, None, None, None
dynamic_point_to_voxel_backward(grad_feats,
grad_voxel_feats.contiguous(), feats,
voxel_feats, point2voxel_map,
voxel_points_count, ctx.reduce_type)
return grad_feats, None, None
dynamic_scatter = _dynamic_scatter.apply
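A minimal usage sketch of the autograd function defined above (assumes a CUDA build of this patch and the module path mmdet3d.ops.voxel.scatter_points; shapes and values are illustrative):

import torch
from mmdet3d.ops.voxel.scatter_points import dynamic_scatter

feats = torch.rand(8, 4, device='cuda', requires_grad=True)            # 8 points, 4 features each
coors = torch.randint(0, 2, (8, 3), dtype=torch.int32, device='cuda')  # voxel index of each point
voxel_feats, voxel_coors = dynamic_scatter(feats, coors, 'max')        # one row per occupied voxel
voxel_feats.sum().backward()                                           # gradients flow back into feats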
......@@ -87,15 +71,8 @@ class DynamicScatter(nn.Module):
self.average_points = average_points
def forward_single(self, points, coors):
voxels, voxel_coors, num_points = dynamic_scatter(
points.contiguous(), coors.contiguous(), self.voxel_size,
self.point_cloud_range)
if not self.average_points:
voxels = torch.max(voxels, dim=1)[0] # voxels: NxMxC -> NxC
else:
voxels = (
voxels.sum(dim=1, keepdim=False).div(num_points.view(-1, 1)))
return voxels, voxel_coors
reduce = 'mean' if self.average_points else 'max'
return dynamic_scatter(points.contiguous(), coors.contiguous(), reduce)
def forward(self, points, coors):
"""
......
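For readers of the diff below, the semantics of the new reduction can be sketched in plain PyTorch (a slow, loop-based reference in the spirit of the unit test at the end of this commit, not the actual implementation):

import torch

def scatter_reference(feats, coors, reduce='max'):
    # drop points whose voxel coordinates contain a negative entry (treated as invalid)
    valid = (coors >= 0).all(dim=-1)
    feats, coors = feats[valid], coors[valid]
    voxel_coors, inverse = coors.unique(dim=0, return_inverse=True)
    voxel_feats = []
    for i in range(voxel_coors.size(0)):
        group = feats[inverse == i]
        if reduce == 'max':
            voxel_feats.append(group.max(dim=0).values)
        elif reduce == 'mean':
            voxel_feats.append(group.mean(dim=0))
        else:  # 'sum'
            voxel_feats.append(group.sum(dim=0))
    return torch.stack(voxel_feats), voxel_coors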
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/types.h>
#include "voxelization.h"
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#define CHECK_CUDA(x) \
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
namespace {
int const threadsPerBlock = sizeof(unsigned long long) * 8;
int const threadsPerBlock = 512;
int const maxGridDim = 50000;
} // namespace
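// Atomic max for floats, emulated with a compare-and-swap loop: reinterpret the
// float bits as an int and retry atomicCAS until the stored value is >= val.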
__device__ __forceinline__ static void reduceMax(float *address, float val) {
int *address_as_i = reinterpret_cast<int *>(address);
int old = *address_as_i, assumed;
do {
assumed = old;
old = atomicCAS(address_as_i, assumed,
__float_as_int(fmaxf(val, __int_as_float(assumed))));
} while (assumed != old || __int_as_float(old) < val);
}
template <typename T, typename T_int>
__global__ void scatter_point_to_voxel_kernel(
const T* points, T_int* coor, T_int* point_to_voxelidx,
T_int* coor_to_voxelidx, T* voxels, T_int* coors, const int num_features,
const int num_points, const int max_points, const int NDim) {
const int index = blockIdx.x * threadsPerBlock + threadIdx.x;
if (index >= num_points) return;
int num = point_to_voxelidx[index];
int voxelidx = coor_to_voxelidx[index];
if (num > -1 && voxelidx > -1) {
const int feature_per_thread = 1;
int start = threadIdx.y * feature_per_thread;
auto voxels_offset =
voxels + voxelidx * max_points * num_features + num * num_features;
auto points_offset = points + index * num_features;
for (int k = start; k < start + feature_per_thread; k++) {
voxels_offset[k] = points_offset[k];
__device__ __forceinline__ static void reduceMax(double *address, double val) {
unsigned long long *address_as_ull =
reinterpret_cast<unsigned long long *>(address);
unsigned long long old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(
address_as_ull, assumed,
__double_as_longlong(fmax(val, __longlong_as_double(assumed))));
} while (assumed != old || __longlong_as_double(old) < val);
}
// get rid of meaningless warnings when compiling host code
#ifdef __CUDA_ARCH__
__device__ __forceinline__ static void reduceAdd(float *address, float val) {
#if (__CUDA_ARCH__ < 200)
#warning \
"compute capability lower than 2.x. fall back to use CAS version of atomicAdd for float32"
int *address_as_i = reinterpret_cast<int *>(address);
int old = *address_as_i, assumed;
do {
assumed = old;
old = atomicCAS(address_as_i, assumed,
__float_as_int(val + __int_as_float(assumed)));
} while (assumed != old);
#else
atomicAdd(address, val);
#endif
}
__device__ __forceinline__ static void reduceAdd(double *address, double val) {
#if (__CUDA_ARCH__ < 600)
#warning \
"compute capability lower than 6.x. fall back to use CAS version of atomicAdd for float64"
unsigned long long *address_as_ull =
reinterpret_cast<unsigned long long *>(address);
unsigned long long old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(val + __longlong_as_double(assumed)));
} while (assumed != old);
#else
atomicAdd(address, val);
#endif
}
#endif
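// Flatten each point's multi-dimensional voxel coordinate into a single scalar id
// (row-major over the per-dimension extents in `dim`); any negative coordinate
// marks the point as invalid by assigning id -1.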
template <typename T_int>
__global__ void coors_id_kernel(const T_int *coors, const T_int *dim,
int64_t *coors_id, const int num_input,
const int NDim) {
for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < num_input;
x += gridDim.x * blockDim.x) {
const T_int *coor_x = coors + x * NDim;
auto coor_id = 0;
for (int i = 0; i < NDim && coor_id != -1; i++) {
coor_id *= dim[i];
auto t = static_cast<int64_t>(coor_x[i]);
coor_id = (t < 0) ? -1 : coor_id + t;
}
if (num == 0 && start < NDim) {
auto coors_offset = coors + voxelidx * NDim;
auto coor_offset = coor + index * NDim;
for (int k = start; k < NDim; k++) {
coors_offset[k] = coor_offset[k];
coors_id[x] = coor_id;
}
}
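// In sorted-id order, mark every position where a new voxel id starts with 1
// (the very first slot gets 0, or -1 if that point is invalid); a host-side
// cumsum over these marks later yields a compact voxel index per point.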
template <typename T_int>
__global__ void coors_map_init_kernel(const int64_t *coors_id,
const T_int *coors_id_argsort,
int32_t *coors_map, const int num_input) {
for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < num_input;
x += gridDim.x * blockDim.x) {
auto here = coors_id[coors_id_argsort[x]];
if (x == 0) {
if (here == -1) { // there is invalid points
coors_map[0] = -1;
} else {
coors_map[0] = 0;
}
return;
}
auto left = coors_id[coors_id_argsort[x - 1]];
coors_map[x] = (left < here) ? 1 : 0;
}
}
template <typename T, typename T_int>
__global__ void map_voxel_to_point_kernel(
T* points, T* voxels, T_int* point_to_voxelidx, T_int* coor_to_voxelidx,
const int num_features, const int num_points, const int max_points) {
const int index = blockIdx.x * threadsPerBlock + threadIdx.x;
if (index >= num_points) return;
auto num = point_to_voxelidx[index];
if (num > -1) {
const int feature_per_thread = 1;
auto voxelidx = coor_to_voxelidx[index];
int start = threadIdx.y * feature_per_thread;
auto voxels_offset =
voxels + voxelidx * max_points * num_features + num * num_features;
auto points_offset = points + index * num_features;
for (int k = start; k < start + feature_per_thread; k++) {
points_offset[k] = voxels_offset[k];
__global__ void
feats_reduce_kernel(const T *feats, const T_int *coors, int32_t *coors_map,
int32_t *reduce_count, // shall be 0 at initialization
T *reduced_feats, // shall be 0 at initialization
T_int *out_coors, const int num_input, const int num_feats,
const int NDim, const reduce_t reduce_type) {
for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < num_input;
x += gridDim.x * blockDim.x) {
int32_t reduce_to = coors_map[x];
if (reduce_to == -1)
continue;
const T_int *coors_offset = coors + x * NDim;
T_int *out_coors_offset = out_coors + reduce_to * NDim;
for (int i = 0; i < NDim; i++) {
out_coors_offset[i] = coors_offset[i];
}
const T *feats_offset = feats + x * num_feats;
T *reduced_feats_offset = reduced_feats + reduce_to * num_feats;
if (reduce_type == reduce_t::MAX) {
for (int i = 0; i < num_feats; i++) {
reduceMax(&reduced_feats_offset[i], feats_offset[i]);
}
} else {
if (reduce_type == reduce_t::MEAN) {
atomicAdd(&reduce_count[reduce_to], static_cast<int32_t>(1));
}
for (int i = 0; i < num_feats; i++) {
reduceAdd(&reduced_feats_offset[i], feats_offset[i]);
}
}
}
}
template <typename T_int>
__global__ void point_to_voxelidx_kernel(const T_int* coor,
T_int* point_to_voxelidx,
T_int* point_to_pointidx,
const int num_points, const int NDim) {
const int index = blockIdx.x * threadsPerBlock + threadIdx.x;
auto coor_offset = coor + index * NDim;
// skip invalid points
if ((index >= num_points) || (coor_offset[0] == -1)) return;
int num = 0;
int coor_x = coor_offset[0];
int coor_y = coor_offset[1];
int coor_z = coor_offset[2];
// only calculate the coors before this coor[index]
for (int i = 0; i < index; ++i) {
auto prev_coor = coor + i * NDim;
if (prev_coor[0] == -1) continue;
// record voxel
if ((prev_coor[0] == coor_x) && (prev_coor[1] == coor_y) &&
(prev_coor[2] == coor_z)) {
num++;
if (num == 1) {
point_to_pointidx[index] = i;
template <typename T>
__global__ void add_reduce_traceback_grad_kernel(
T *grad_feats, const T *grad_reduced_feats, const int32_t *coors_map,
const int32_t *reduce_count, const int num_input, const int num_feats,
const reduce_t reduce_type) {
for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < num_input;
x += gridDim.x * blockDim.x) {
int32_t reduce_to = coors_map[x];
if (reduce_to == -1) {
continue;
}
const int input_offset = x * num_feats;
T *grad_feats_offset = grad_feats + input_offset;
const int reduced_offset = reduce_to * num_feats;
const T *grad_reduced_feats_offset = grad_reduced_feats + reduced_offset;
if (reduce_type == reduce_t::SUM) {
for (int i = 0; i < num_feats; i++) {
grad_feats_offset[i] = grad_reduced_feats_offset[i];
}
} else if (reduce_type == reduce_t::MEAN) {
for (int i = 0; i < num_feats; i++) {
grad_feats_offset[i] = grad_reduced_feats_offset[i] /
static_cast<T>(reduce_count[reduce_to]);
}
}
}
if (num == 0) {
point_to_pointidx[index] = index;
}
point_to_voxelidx[index] = num;
}
template <typename T_int>
__global__ void determin_voxel_num(
const T_int* coor, T_int* num_points_per_voxel, T_int* point_to_voxelidx,
T_int* point_to_pointidx, T_int* coor_to_voxelidx, T_int* voxel_num,
T_int* max_points, const int num_points, const int NDim) {
// only calculate the coors before this coor[index]
for (int i = 0; i < num_points; ++i) {
auto coor_offset = coor + i * NDim;
if (coor_offset[0] == -1) continue;
int point_pos_in_voxel = point_to_voxelidx[i];
// record voxel
if (point_pos_in_voxel == -1) {
// out of max_points or invalid point
printf("point_pos_in_voxel == -1, point:%d", i);
continue;
} else if (point_pos_in_voxel == 0) {
// record new voxel
int voxelidx = voxel_num[0];
voxel_num[0] += 1;
coor_to_voxelidx[i] = voxelidx;
num_points_per_voxel[voxelidx] = 1;
} else {
int point_idx = point_to_pointidx[i];
int voxelidx = coor_to_voxelidx[point_idx];
if (voxelidx != -1) {
num_points_per_voxel[voxelidx] += 1;
coor_to_voxelidx[i] = voxelidx;
max_points[0] = max(max_points[0], point_pos_in_voxel + 1);
} else {
printf("voxelidx = -1, point:%d", i);
template <typename T>
__global__ void max_reduce_traceback_scatter_idx_kernel(
const T *feats, const T *reduced_feats, int32_t *reduce_from,
const int32_t *coors_map, const int num_input, const int num_feats) {
for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < num_input;
x += gridDim.x * blockDim.x) {
int32_t reduce_to = coors_map[x];
const int input_offset = x * num_feats;
const T *feats_offset = feats + input_offset;
if (reduce_to == -1) {
continue;
}
const int reduced_offset = reduce_to * num_feats;
const T *reduced_feats_offset = reduced_feats + reduced_offset;
int32_t *reduce_from_offset = reduce_from + reduced_offset;
for (int i = 0; i < num_feats; i++) {
if (feats_offset[i] == reduced_feats_offset[i]) {
atomicMin(&reduce_from_offset[i], static_cast<int32_t>(x));
}
}
}
}
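// Scatter the gradient of every reduced (max) feature back to the single input
// row recorded in reduce_from; all other rows keep a zero gradient.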
template <typename T>
__global__ void
max_reduce_scatter_grad_kernel(T *grad_feats, const T *grad_reduced_feats,
const int32_t *reduce_from,
const int num_reduced, const int num_feats) {
for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < num_reduced;
x += gridDim.x * blockDim.x) {
const int reduced_offset = x * num_feats;
const int32_t *scatter_to_offset = reduce_from + reduced_offset;
const T *grad_reduced_feats_offset = grad_reduced_feats + reduced_offset;
for (int i = 0; i < num_feats; i++) {
grad_feats[scatter_to_offset[i] * num_feats + i] =
grad_reduced_feats_offset[i];
}
}
}
namespace voxelization {
std::vector<at::Tensor> dynamic_point_to_voxel_forward_gpu(
const at::Tensor& points, const at::Tensor& voxel_mapping,
const std::vector<float> voxel_size, const std::vector<float> coors_range) {
CHECK_INPUT(points);
at::cuda::CUDAGuard device_guard(points.device());
std::vector<torch::Tensor>
dynamic_point_to_voxel_forward_gpu(const torch::Tensor &feats,
const torch::Tensor &coors,
const reduce_t reduce_type) {
CHECK_INPUT(feats);
CHECK_INPUT(coors);
const int NDim = voxel_mapping.size(1);
const int num_points = points.size(0);
const int num_features = points.size(1);
const int NDim = coors.size(1);
const int num_input = feats.size(0);
const int num_feats = feats.size(1);
std::vector<int> grid_size(NDim);
for (int i = 0; i < NDim; ++i) {
grid_size[i] =
round((coors_range[NDim + i] - coors_range[i]) / voxel_size[i]);
}
auto coors_id = torch::empty({num_input}, coors.options().dtype(torch::kI64));
auto coor_space_dim = coors.max_values(0) + 1;
auto coors_map_sorted =
torch::empty({num_input}, coors.options().dtype(torch::kI32));
auto coors_map =
torch::empty({num_input}, coors.options().dtype(torch::kI32));
auto num_coors = at::zeros({1}, coors.options().dtype(torch::kI32));
// assume the mapping is already given
auto point_to_pointidx = -at::ones(
{
num_points,
},
voxel_mapping.options());
auto point_to_voxelidx = -at::ones(
{
num_points,
},
voxel_mapping.options());
auto max_points = at::zeros(
{
1,
},
voxel_mapping.options()); // must be zero from the beginning
int col_blocks = at::cuda::ATenCeilDiv(num_points, threadsPerBlock);
dim3 blocks(col_blocks);
dim3 threads(threadsPerBlock);
cudaStream_t map_stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(
voxel_mapping.scalar_type(), "determin_duplicate", ([&] {
point_to_voxelidx_kernel<int><<<blocks, threads, 0, map_stream>>>(
voxel_mapping.data_ptr<int>(), point_to_voxelidx.data_ptr<int>(),
point_to_pointidx.data_ptr<int>(), num_points, NDim);
AT_DISPATCH_INTEGRAL_TYPES(
coors.scalar_type(), "coors_id_kernel", ([&] {
dim3 blocks(std::min(DIVUP(num_input, threadsPerBlock), maxGridDim));
dim3 threads(threadsPerBlock);
coors_id_kernel<<<blocks, threads>>>(
coors.data_ptr<scalar_t>(), coor_space_dim.data_ptr<scalar_t>(),
coors_id.data_ptr<int64_t>(), num_input, NDim);
}));
cudaDeviceSynchronize();
AT_CUDA_CHECK(cudaGetLastError());
// keeping this logic on the CUDA device speeds it up by roughly 10x
auto num_points_per_voxel = at::zeros(
{
num_points,
},
voxel_mapping.options());
auto coor_to_voxelidx = -at::ones(
{
num_points,
},
voxel_mapping.options());
auto voxel_num = at::zeros(
{
1,
},
voxel_mapping.options()); // must be zero from the beginning
cudaStream_t logic_stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(
voxel_mapping.scalar_type(), "determin_duplicate", ([&] {
determin_voxel_num<int><<<1, 1, 0, logic_stream>>>(
voxel_mapping.data_ptr<int>(), num_points_per_voxel.data_ptr<int>(),
point_to_voxelidx.data_ptr<int>(),
point_to_pointidx.data_ptr<int>(), coor_to_voxelidx.data_ptr<int>(),
voxel_num.data_ptr<int>(), max_points.data_ptr<int>(), num_points,
NDim);
auto coors_id_argsort = coors_id.argsort();
AT_DISPATCH_INTEGRAL_TYPES(
coors_id_argsort.scalar_type(), "coors_map_init_kernel", ([&] {
dim3 blocks(std::min(DIVUP(num_input, threadsPerBlock), maxGridDim));
dim3 threads(threadsPerBlock);
coors_map_init_kernel<<<blocks, threads>>>(
coors_id.data_ptr<int64_t>(), coors_id_argsort.data_ptr<scalar_t>(),
coors_map_sorted.data_ptr<int32_t>(), num_input);
}));
cudaDeviceSynchronize();
AT_CUDA_CHECK(cudaGetLastError());
// some temporary data
auto max_points_cpu = max_points.to(at::kCPU);
int max_points_int = max_points_cpu.data_ptr<int>()[0];
auto voxel_num_cpu = voxel_num.to(at::kCPU);
int voxel_num_int = voxel_num_cpu.data_ptr<int>()[0];
at::Tensor coors =
at::zeros({voxel_num_int, NDim}, points.options().dtype(at::kInt));
at::Tensor voxels = at::zeros({voxel_num_int, max_points_int, num_features},
points.options());
// copy point features to voxels
dim3 cp_threads(threadsPerBlock, num_features);
cudaStream_t cp_stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(
points.scalar_type(), "scatter_point_to_voxel", ([&] {
scatter_point_to_voxel_kernel<float, int>
<<<blocks, cp_threads, 0, cp_stream>>>(
points.data_ptr<float>(), voxel_mapping.data_ptr<int>(),
point_to_voxelidx.data_ptr<int>(),
coor_to_voxelidx.data_ptr<int>(), voxels.data_ptr<float>(),
coors.data_ptr<int>(), num_features, num_points, max_points_int,
NDim);
coors_map_sorted = coors_map_sorted.cumsum(0, torch::kI32);
coors_map.index_put_(coors_id_argsort, coors_map_sorted);
const int num_coors_cpu =
coors_map_sorted[-1].cpu().data_ptr<int32_t>()[0] + 1;
auto out_coors = torch::empty({num_coors_cpu, NDim}, coors.options());
auto reduced_feats =
torch::empty({num_coors_cpu, num_feats}, feats.options());
auto reduce_count =
torch::zeros({num_coors_cpu}, coors.options().dtype(torch::kI32));
AT_DISPATCH_FLOATING_TYPES(
feats.scalar_type(), "feats_reduce_kernel", ([&] {
using F_t = scalar_t;
AT_DISPATCH_INTEGRAL_TYPES(
coors.scalar_type(), "feats_reduce_kernel", ([&] {
using I_t = scalar_t;
if (reduce_type == reduce_t::MAX)
reduced_feats.fill_(-std::numeric_limits<F_t>::infinity());
else
reduced_feats.fill_(static_cast<F_t>(0));
dim3 blocks(
std::min(DIVUP(num_input, threadsPerBlock), maxGridDim));
dim3 threads(threadsPerBlock);
feats_reduce_kernel<<<blocks, threads>>>(
feats.data_ptr<F_t>(), coors.data_ptr<I_t>(),
coors_map.data_ptr<int32_t>(),
reduce_count.data_ptr<int32_t>(),
reduced_feats.data_ptr<F_t>(), out_coors.data_ptr<I_t>(),
num_input, num_feats, NDim, reduce_type);
if (reduce_type == reduce_t::MEAN)
reduced_feats /=
reduce_count.unsqueeze(-1).to(reduced_feats.dtype());
}));
}));
cudaDeviceSynchronize();
AT_CUDA_CHECK(cudaGetLastError());
at::Tensor num_points_per_voxel_out =
num_points_per_voxel.slice(/*dim=*/0, /*start=*/0, /*end=*/voxel_num_int);
return {voxels, coors, num_points_per_voxel_out, point_to_voxelidx,
coor_to_voxelidx};
return {reduced_feats, out_coors, coors_map, reduce_count};
}
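Read as a whole, the forward pass above follows a flatten -> argsort -> mark-firsts -> cumsum recipe to build the point-to-voxel map. A PyTorch sketch of that indexing logic (assumption: coors_id already holds the flattened per-point voxel ids, with -1 for invalid points):

import torch

def point_to_voxel_map(coors_id):
    order = coors_id.argsort()
    sorted_id = coors_id[order]
    marks = torch.empty_like(sorted_id, dtype=torch.int64)
    marks[0] = -1 if sorted_id[0] == -1 else 0            # first slot: invalid -> -1, else 0
    marks[1:] = (sorted_id[1:] > sorted_id[:-1]).long()   # 1 wherever a new id starts
    voxel_idx_sorted = marks.cumsum(0)                    # compact voxel indices in sorted order
    coors_map = torch.empty_like(voxel_idx_sorted)
    coors_map[order] = voxel_idx_sorted                   # undo the sort
    return coors_map                                      # per point: voxel index, or -1 if invalid

coors_id = torch.tensor([3, -1, 3, 7])                    # two points share id 3, one invalid
print(point_to_voxel_map(coors_id))                       # tensor([ 0, -1,  0,  1])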
void dynamic_point_to_voxel_backward_gpu(at::Tensor& grad_input_points,
const at::Tensor& grad_output_voxels,
const at::Tensor& point_to_voxelidx,
const at::Tensor& coor_to_voxelidx) {
CHECK_INPUT(grad_input_points);
CHECK_INPUT(grad_output_voxels);
CHECK_INPUT(point_to_voxelidx);
CHECK_INPUT(coor_to_voxelidx);
at::cuda::CUDAGuard device_guard(grad_input_points.device());
void dynamic_point_to_voxel_backward_gpu(
torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats,
const torch::Tensor &feats, const torch::Tensor &reduced_feats,
const torch::Tensor &coors_map, const torch::Tensor &reduce_count,
const reduce_t reduce_type) {
CHECK_INPUT(grad_feats);
CHECK_INPUT(grad_reduced_feats);
CHECK_INPUT(feats);
CHECK_INPUT(reduced_feats);
CHECK_INPUT(coors_map);
CHECK_INPUT(reduce_count);
const int num_points = grad_input_points.size(0);
const int num_features = grad_input_points.size(1);
const int max_points = grad_output_voxels.size(1);
const int num_input = feats.size(0);
const int num_reduced = reduced_feats.size(0);
const int num_feats = feats.size(1);
grad_feats.fill_(0);
// copy voxel grad to points
int col_blocks = at::cuda::ATenCeilDiv(num_points, threadsPerBlock);
dim3 blocks(col_blocks);
dim3 cp_threads(threadsPerBlock, num_features);
cudaStream_t cp_stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(grad_input_points.scalar_type(),
"scatter_point_to_voxel", ([&] {
map_voxel_to_point_kernel<float, int>
<<<blocks, cp_threads, 0, cp_stream>>>(
grad_input_points.data_ptr<float>(),
grad_output_voxels.data_ptr<float>(),
point_to_voxelidx.data_ptr<int>(),
coor_to_voxelidx.data_ptr<int>(),
num_features, num_points, max_points);
}));
cudaDeviceSynchronize();
AT_CUDA_CHECK(cudaGetLastError());
if (reduce_type == reduce_t::MEAN || reduce_type == reduce_t::SUM) {
AT_DISPATCH_FLOATING_TYPES(
grad_reduced_feats.scalar_type(), "add_reduce_traceback_grad_kernel",
([&] {
dim3 blocks(std::min(DIVUP(num_input, threadsPerBlock), maxGridDim));
dim3 threads(threadsPerBlock);
add_reduce_traceback_grad_kernel<<<blocks, threads>>>(
grad_feats.data_ptr<scalar_t>(),
grad_reduced_feats.data_ptr<scalar_t>(),
coors_map.data_ptr<int32_t>(), reduce_count.data_ptr<int32_t>(),
num_input, num_feats, reduce_type);
}));
AT_CUDA_CHECK(cudaGetLastError());
} else {
auto reduce_from = torch::full({num_reduced, num_feats}, num_input,
coors_map.options().dtype(torch::kI32));
AT_DISPATCH_FLOATING_TYPES(
grad_reduced_feats.scalar_type(),
"max_reduce_traceback_scatter_idx_kernel", ([&] {
dim3 blocks(std::min(DIVUP(num_input, threadsPerBlock), maxGridDim));
dim3 threads(threadsPerBlock);
max_reduce_traceback_scatter_idx_kernel<<<blocks, threads>>>(
feats.data_ptr<scalar_t>(), reduced_feats.data_ptr<scalar_t>(),
reduce_from.data_ptr<int32_t>(), coors_map.data_ptr<int32_t>(),
num_input, num_feats);
}));
AT_CUDA_CHECK(cudaGetLastError());
AT_DISPATCH_FLOATING_TYPES(
grad_reduced_feats.scalar_type(),
"max_reduce_traceback_scatter_idx_kernel", ([&] {
dim3 blocks(
std::min(DIVUP(num_reduced, threadsPerBlock), maxGridDim));
dim3 threads(threadsPerBlock);
max_reduce_scatter_grad_kernel<<<blocks, threads>>>(
grad_feats.data_ptr<scalar_t>(),
grad_reduced_feats.data_ptr<scalar_t>(),
reduce_from.data_ptr<int32_t>(), num_reduced, num_feats);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
return;
}
} // namespace voxelization
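The max-reduction backward above amounts to a two-step traceback, sketched here in plain PyTorch (a loop-based illustration under the same tie-breaking rule, not the CUDA implementation; coors_map is the per-point voxel index from the forward, -1 for invalid points):

import torch

def max_reduce_backward(grad_reduced_feats, feats, reduced_feats, coors_map):
    num_input, num_feats = feats.shape
    grad_feats = torch.zeros_like(feats)
    # step 1: for each (voxel, feature), record the lowest input row whose value
    # equals the reduced maximum (every reduced value has at least one such row)
    reduce_from = torch.full(reduced_feats.shape, num_input, dtype=torch.long)
    for x in range(num_input):                       # increasing row index: first match wins
        v = int(coors_map[x])
        if v < 0:
            continue
        hit = (feats[x] == reduced_feats[v]) & (reduce_from[v] == num_input)
        reduce_from[v][hit] = x
    # step 2: route each reduced gradient to the single recorded row
    for v in range(reduced_feats.size(0)):
        for i in range(num_feats):
            grad_feats[reduce_from[v, i], i] = grad_reduced_feats[v, i]
    return grad_feats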
#pragma once
#include <torch/extension.h>
typedef enum { SUM, MEAN, MAX } reduce_t;
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
namespace voxelization {
int hard_voxelize_cpu(const at::Tensor& points, at::Tensor& voxels,
at::Tensor& coors, at::Tensor& num_points_per_voxel,
int hard_voxelize_cpu(const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors, at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim = 3);
void dynamic_voxelize_cpu(const at::Tensor& points, at::Tensor& coors,
void dynamic_voxelize_cpu(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int NDim = 3);
std::vector<at::Tensor> dynamic_point_to_voxel_cpu(
const at::Tensor& points, const at::Tensor& voxel_mapping,
const at::Tensor &points, const at::Tensor &voxel_mapping,
const std::vector<float> voxel_size, const std::vector<float> coors_range);
#ifdef WITH_CUDA
int hard_voxelize_gpu(const at::Tensor& points, at::Tensor& voxels,
at::Tensor& coors, at::Tensor& num_points_per_voxel,
int hard_voxelize_gpu(const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors, at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim = 3);
void dynamic_voxelize_gpu(const at::Tensor& points, at::Tensor& coors,
void dynamic_voxelize_gpu(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int NDim = 3);
std::vector<at::Tensor> dynamic_point_to_voxel_forward_gpu(
const at::Tensor& points, const at::Tensor& voxel_mapping,
const std::vector<float> voxel_size, const std::vector<float> coors_range);
std::vector<torch::Tensor> dynamic_point_to_voxel_forward_gpu(const torch::Tensor &feats,
const torch::Tensor &coors,
const reduce_t reduce_type);
void dynamic_point_to_voxel_backward_gpu(at::Tensor& grad_input_points,
const at::Tensor& grad_output_voxels,
const at::Tensor& point_to_voxelidx,
const at::Tensor& coor_to_voxelidx);
void dynamic_point_to_voxel_backward_gpu(torch::Tensor &grad_feats,
const torch::Tensor &grad_reduced_feats,
const torch::Tensor &feats,
const torch::Tensor &reduced_feats,
const torch::Tensor &coors_idx,
const torch::Tensor &reduce_count,
const reduce_t reduce_type);
#endif
// Interface for Python
inline int hard_voxelize(const at::Tensor& points, at::Tensor& voxels,
at::Tensor& coors, at::Tensor& num_points_per_voxel,
inline int hard_voxelize(const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors, at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
......@@ -63,7 +69,7 @@ inline int hard_voxelize(const at::Tensor& points, at::Tensor& voxels,
NDim);
}
inline void dynamic_voxelize(const at::Tensor& points, at::Tensor& coors,
inline void dynamic_voxelize(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int NDim = 3) {
......@@ -77,37 +83,49 @@ inline void dynamic_voxelize(const at::Tensor& points, at::Tensor& coors,
return dynamic_voxelize_cpu(points, coors, voxel_size, coors_range, NDim);
}
inline std::vector<torch::Tensor> dynamic_point_to_voxel_forward(
const at::Tensor& points, const at::Tensor& voxel_mapping,
const std::vector<float> voxel_size, const std::vector<float> coors_range) {
if (points.device().is_cuda()) {
inline reduce_t convert_reduce_type(const std::string &reduce_type) {
if (reduce_type == "max")
return reduce_t::MAX;
else if (reduce_type == "sum")
return reduce_t::SUM;
else if (reduce_type == "mean")
return reduce_t::MEAN;
else TORCH_CHECK(false, "do not support reduce type " + reduce_type);
return reduce_t::SUM;
}
inline std::vector<torch::Tensor> dynamic_point_to_voxel_forward(const torch::Tensor &feats,
const torch::Tensor &coors,
const std::string &reduce_type) {
if (feats.device().is_cuda()) {
#ifdef WITH_CUDA
return dynamic_point_to_voxel_forward_gpu(points, voxel_mapping, voxel_size,
coors_range);
return dynamic_point_to_voxel_forward_gpu(feats, coors, convert_reduce_type(reduce_type));
#else
AT_ERROR("Not compiled with GPU support");
TORCH_CHECK(false, "Not compiled with GPU support");
#endif
}
return dynamic_point_to_voxel_cpu(points, voxel_mapping, voxel_size,
coors_range);
TORCH_CHECK(false, "do not support cpu yet");
return std::vector<torch::Tensor>();
}
inline void dynamic_point_to_voxel_backward(
at::Tensor& grad_input_points, const at::Tensor& grad_output_voxels,
const at::Tensor& point_to_voxelidx, const at::Tensor& coor_to_voxelidx) {
if (grad_input_points.device().is_cuda()) {
inline void dynamic_point_to_voxel_backward(torch::Tensor &grad_feats,
const torch::Tensor &grad_reduced_feats,
const torch::Tensor &feats,
const torch::Tensor &reduced_feats,
const torch::Tensor &coors_idx,
const torch::Tensor &reduce_count,
const std::string &reduce_type) {
if (grad_feats.device().is_cuda()) {
#ifdef WITH_CUDA
return dynamic_point_to_voxel_backward_gpu(
grad_input_points, grad_output_voxels, point_to_voxelidx,
coor_to_voxelidx);
dynamic_point_to_voxel_backward_gpu(
grad_feats, grad_reduced_feats, feats, reduced_feats, coors_idx, reduce_count,
convert_reduce_type(reduce_type));
return;
#else
AT_ERROR("Not compiled with GPU support");
TORCH_CHECK(false, "Not compiled with GPU support");
#endif
}
// return dynamic_point_to_voxel_cpu(points,
// voxel_mapping,
// voxel_size,
// coors_range);
TORCH_CHECK(false, "do not support cpu yet");
}
} // namespace voxelization
import pytest
import torch
from torch.autograd import gradcheck
from mmdet3d.ops import DynamicScatter
def test_dynamic_scatter():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
feats = torch.rand(
size=(200000, 3), dtype=torch.float32, device='cuda') * 100 - 50
coors = torch.randint(
low=-1, high=20, size=(200000, 3), dtype=torch.int32, device='cuda')
coors[coors.min(dim=-1).values < 0] = -1
dsmean = DynamicScatter([0.32, 0.32, 6],
[-74.88, -74.88, -2, 74.88, 74.88, 4], True)
dsmax = DynamicScatter([0.32, 0.32, 6],
[-74.88, -74.88, -2, 74.88, 74.88, 4], False)
ref_voxel_coors = coors.unique(dim=0, sorted=True)
ref_voxel_coors = ref_voxel_coors[ref_voxel_coors.min(dim=-1).values >= 0]
ref_voxel_feats_mean = []
ref_voxel_feats_max = []
for ref_voxel_coor in ref_voxel_coors:
voxel_mask = (coors == ref_voxel_coor).all(dim=-1)
ref_voxel_feats_mean.append(feats[voxel_mask].mean(dim=0))
ref_voxel_feats_max.append(feats[voxel_mask].max(dim=0).values)
ref_voxel_feats_mean = torch.stack(ref_voxel_feats_mean)
ref_voxel_feats_max = torch.stack(ref_voxel_feats_max)
feats_out_mean, coors_out_mean = dsmean(feats, coors)
seq_mean = (coors_out_mean[:, 0] * 400 + coors_out_mean[:, 1] * 20 +
coors_out_mean[:, 2]).argsort()
feats_out_mean = feats_out_mean[seq_mean]
coors_out_mean = coors_out_mean[seq_mean]
feats_out_max, coors_out_max = dsmax(feats, coors)
seq_max = (coors_out_max[:, 0] * 400 + coors_out_max[:, 1] * 20 +
coors_out_max[:, 2]).argsort()
feats_out_max = feats_out_max[seq_max]
coors_out_max = coors_out_max[seq_max]
assert (coors_out_mean == ref_voxel_coors).all()
assert torch.allclose(
feats_out_mean, ref_voxel_feats_mean, atol=1e-2, rtol=1e-5)
assert (coors_out_max == ref_voxel_coors).all()
assert torch.allclose(
feats_out_max, ref_voxel_feats_max, atol=1e-2, rtol=1e-5)
# test grad #
feats = torch.rand(
size=(100, 4), dtype=torch.float32, device='cuda') * 100 - 50
coors = torch.randint(
low=-1, high=3, size=(100, 3), dtype=torch.int32, device='cuda')
feats.requires_grad_()
gradcheck(dsmean, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5)
gradcheck(dsmax, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5)