Unverified Commit 93597a53 authored by zhanggefan's avatar zhanggefan Committed by GitHub
Browse files

A faster & more memory-efficient implementation of DynamicScatter (#318)



* a faster & more memory-efficient implementation of DynamicScatter

* fix format issues and add pytest skip code for tests on machines without cuda support

* some trivial changes:

decrease the number of kernel threads per block to 512, to enable inference on GPUs with computing capability lower than 2.0

change the backpropagation behavior of max-reduction. when there are multiple points shares the same maximum feature value, only the first point (with lowest row index) among them is chosen to propagate the output gradient back. before this change, all points with the same maximum feature value can propagate the output gradient back. this change makes the max-reduction behaves in consistence with torch.max. this change may cause gradcheck failure in test_dynamic_scatter.py. please do not worry about it because torch.max fails the gradcheck too.

* fix typo
Co-authored-by: default avatarzhanggefan <1152009@tongji.edu.cn>
parent 8214a977
...@@ -9,57 +9,41 @@ from .voxel_layer import (dynamic_point_to_voxel_backward, ...@@ -9,57 +9,41 @@ from .voxel_layer import (dynamic_point_to_voxel_backward,
class _dynamic_scatter(Function): class _dynamic_scatter(Function):
@staticmethod @staticmethod
def forward(ctx, points, coors, voxel_size, coors_range): def forward(ctx, feats, coors, reduce_type='max'):
"""convert kitti points(N, >=3) to voxels. """convert kitti points(N, >=3) to voxels.
Args: Args:
points: [N, ndim] float tensor. points[:, :3] contain xyz feats: [N, C] float tensor. points features to be reduced
points and points[:, 3:] contain other information into voxels.
such as reflectivity. coors: [N, ndim] int tensor. corresponding voxel coordinates
voxel_size: [3] list/tuple or array, float. xyz, indicate (specifically multi-dim voxel index) of each points.
voxel size reduce_type: str. reduce op. support 'max', 'sum' and 'mean'
coors_range: [6] list/tuple or array, float. indicate voxel range.
format: xyzxyz, minmax
max_points: int. indicate maximum points contained in a voxel.
if max_points=-1, it means using dynamic_voxelize
max_voxels: int. indicate maximum voxels this function create.
for second, 20000 is a good choice. you should shuffle
points before call this function because max_voxels may
drop some points.
Returns: Returns:
tuple tuple
voxels: [M, max_points, ndim] float tensor. only contain points voxel_feats: [M, C] float tensor. reduced features. input features
and returned when max_points != -1. that shares the same voxel coordinates are reduced to one row
coordinates: [M, 3] int32 tensor, always returned. coordinates: [M, ndim] int tensor, voxel coordinates.
num_points_per_voxel: [M] int32 tensor. Only returned when
max_points != -1.
""" """
results = dynamic_point_to_voxel_forward(points, coors, voxel_size, results = dynamic_point_to_voxel_forward(feats, coors, reduce_type)
coors_range) (voxel_feats, voxel_coors, point2voxel_map,
(voxels, voxel_coors, num_points_per_voxel, point_to_voxelidx, voxel_points_count) = results
coor_to_voxelidx) = results ctx.reduce_type = reduce_type
ctx.save_for_backward(num_points_per_voxel, point_to_voxelidx, ctx.save_for_backward(feats, voxel_feats, point2voxel_map,
coor_to_voxelidx) voxel_points_count)
return voxels, voxel_coors, num_points_per_voxel.float() return voxel_feats, voxel_coors
@staticmethod @staticmethod
def backward(ctx, def backward(ctx, grad_voxel_feats, grad_voxel_coors=None):
grad_output_voxel, (feats, voxel_feats, point2voxel_map,
grad_output_voxel_coors=None, voxel_points_count) = ctx.saved_tensors
grad_output_num_points=None): grad_feats = torch.zeros_like(feats)
(num_points_per_voxel, point_to_voxelidx,
coor_to_voxelidx) = ctx.saved_tensors
# grad_output_voxel shape: NxMxC
num_points = point_to_voxelidx.size(0)
num_features = grad_output_voxel.size(-1)
grad_points = grad_output_voxel.new_zeros(
size=(num_points, num_features))
# TODO: whether to use index put or use cuda_backward # TODO: whether to use index put or use cuda_backward
# To use index put, need point to voxel index # To use index put, need point to voxel index
dynamic_point_to_voxel_backward(grad_points, dynamic_point_to_voxel_backward(grad_feats,
grad_output_voxel.contiguous(), grad_voxel_feats.contiguous(), feats,
point_to_voxelidx, coor_to_voxelidx) voxel_feats, point2voxel_map,
return grad_points, None, None, None voxel_points_count, ctx.reduce_type)
return grad_feats, None, None
dynamic_scatter = _dynamic_scatter.apply dynamic_scatter = _dynamic_scatter.apply
...@@ -87,15 +71,8 @@ class DynamicScatter(nn.Module): ...@@ -87,15 +71,8 @@ class DynamicScatter(nn.Module):
self.average_points = average_points self.average_points = average_points
def forward_single(self, points, coors): def forward_single(self, points, coors):
voxels, voxel_coors, num_points = dynamic_scatter( reduce = 'mean' if self.average_points else 'max'
points.contiguous(), coors.contiguous(), self.voxel_size, return dynamic_scatter(points.contiguous(), coors.contiguous(), reduce)
self.point_cloud_range)
if not self.average_points:
voxels = torch.max(voxels, dim=1)[0] # voxels: NxMxC -> NxC
else:
voxels = (
voxels.sum(dim=1, keepdim=False).div(num_points.view(-1, 1)))
return voxels, voxel_coors
def forward(self, points, coors): def forward(self, points, coors):
""" """
......
#pragma once #pragma once
#include <torch/extension.h> #include <torch/extension.h>
typedef enum { SUM, MEAN, MAX } reduce_t;
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
namespace voxelization { namespace voxelization {
int hard_voxelize_cpu(const at::Tensor& points, at::Tensor& voxels, int hard_voxelize_cpu(const at::Tensor &points, at::Tensor &voxels,
at::Tensor& coors, at::Tensor& num_points_per_voxel, at::Tensor &coors, at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const std::vector<float> coors_range,
const int max_points, const int max_voxels, const int max_points, const int max_voxels,
const int NDim = 3); const int NDim = 3);
void dynamic_voxelize_cpu(const at::Tensor& points, at::Tensor& coors, void dynamic_voxelize_cpu(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const std::vector<float> coors_range,
const int NDim = 3); const int NDim = 3);
std::vector<at::Tensor> dynamic_point_to_voxel_cpu( std::vector<at::Tensor> dynamic_point_to_voxel_cpu(
const at::Tensor& points, const at::Tensor& voxel_mapping, const at::Tensor &points, const at::Tensor &voxel_mapping,
const std::vector<float> voxel_size, const std::vector<float> coors_range); const std::vector<float> voxel_size, const std::vector<float> coors_range);
#ifdef WITH_CUDA #ifdef WITH_CUDA
int hard_voxelize_gpu(const at::Tensor& points, at::Tensor& voxels, int hard_voxelize_gpu(const at::Tensor &points, at::Tensor &voxels,
at::Tensor& coors, at::Tensor& num_points_per_voxel, at::Tensor &coors, at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const std::vector<float> coors_range,
const int max_points, const int max_voxels, const int max_points, const int max_voxels,
const int NDim = 3); const int NDim = 3);
void dynamic_voxelize_gpu(const at::Tensor& points, at::Tensor& coors, void dynamic_voxelize_gpu(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const std::vector<float> coors_range,
const int NDim = 3); const int NDim = 3);
std::vector<at::Tensor> dynamic_point_to_voxel_forward_gpu( std::vector<torch::Tensor> dynamic_point_to_voxel_forward_gpu(const torch::Tensor &feats,
const at::Tensor& points, const at::Tensor& voxel_mapping, const torch::Tensor &coors,
const std::vector<float> voxel_size, const std::vector<float> coors_range); const reduce_t reduce_type);
void dynamic_point_to_voxel_backward_gpu(at::Tensor& grad_input_points, void dynamic_point_to_voxel_backward_gpu(torch::Tensor &grad_feats,
const at::Tensor& grad_output_voxels, const torch::Tensor &grad_reduced_feats,
const at::Tensor& point_to_voxelidx, const torch::Tensor &feats,
const at::Tensor& coor_to_voxelidx); const torch::Tensor &reduced_feats,
const torch::Tensor &coors_idx,
const torch::Tensor &reduce_count,
const reduce_t reduce_type);
#endif #endif
// Interface for Python // Interface for Python
inline int hard_voxelize(const at::Tensor& points, at::Tensor& voxels, inline int hard_voxelize(const at::Tensor &points, at::Tensor &voxels,
at::Tensor& coors, at::Tensor& num_points_per_voxel, at::Tensor &coors, at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const std::vector<float> coors_range,
const int max_points, const int max_voxels, const int max_points, const int max_voxels,
...@@ -63,7 +69,7 @@ inline int hard_voxelize(const at::Tensor& points, at::Tensor& voxels, ...@@ -63,7 +69,7 @@ inline int hard_voxelize(const at::Tensor& points, at::Tensor& voxels,
NDim); NDim);
} }
inline void dynamic_voxelize(const at::Tensor& points, at::Tensor& coors, inline void dynamic_voxelize(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const std::vector<float> coors_range,
const int NDim = 3) { const int NDim = 3) {
...@@ -77,37 +83,49 @@ inline void dynamic_voxelize(const at::Tensor& points, at::Tensor& coors, ...@@ -77,37 +83,49 @@ inline void dynamic_voxelize(const at::Tensor& points, at::Tensor& coors,
return dynamic_voxelize_cpu(points, coors, voxel_size, coors_range, NDim); return dynamic_voxelize_cpu(points, coors, voxel_size, coors_range, NDim);
} }
inline std::vector<torch::Tensor> dynamic_point_to_voxel_forward( inline reduce_t convert_reduce_type(const std::string &reduce_type) {
const at::Tensor& points, const at::Tensor& voxel_mapping, if (reduce_type == "max")
const std::vector<float> voxel_size, const std::vector<float> coors_range) { return reduce_t::MAX;
if (points.device().is_cuda()) { else if (reduce_type == "sum")
return reduce_t::SUM;
else if (reduce_type == "mean")
return reduce_t::MEAN;
else TORCH_CHECK(false, "do not support reduce type " + reduce_type)
return reduce_t::SUM;
}
inline std::vector<torch::Tensor> dynamic_point_to_voxel_forward(const torch::Tensor &feats,
const torch::Tensor &coors,
const std::string &reduce_type) {
if (feats.device().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
return dynamic_point_to_voxel_forward_gpu(points, voxel_mapping, voxel_size, return dynamic_point_to_voxel_forward_gpu(feats, coors, convert_reduce_type(reduce_type));
coors_range);
#else #else
AT_ERROR("Not compiled with GPU support"); TORCH_CHECK(false, "Not compiled with GPU support");
#endif #endif
} }
return dynamic_point_to_voxel_cpu(points, voxel_mapping, voxel_size, TORCH_CHECK(false, "do not support cpu yet");
coors_range); return std::vector<torch::Tensor>();
} }
inline void dynamic_point_to_voxel_backward( inline void dynamic_point_to_voxel_backward(torch::Tensor &grad_feats,
at::Tensor& grad_input_points, const at::Tensor& grad_output_voxels, const torch::Tensor &grad_reduced_feats,
const at::Tensor& point_to_voxelidx, const at::Tensor& coor_to_voxelidx) { const torch::Tensor &feats,
if (grad_input_points.device().is_cuda()) { const torch::Tensor &reduced_feats,
const torch::Tensor &coors_idx,
const torch::Tensor &reduce_count,
const std::string &reduce_type) {
if (grad_feats.device().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
return dynamic_point_to_voxel_backward_gpu( dynamic_point_to_voxel_backward_gpu(
grad_input_points, grad_output_voxels, point_to_voxelidx, grad_feats, grad_reduced_feats, feats, reduced_feats, coors_idx, reduce_count,
coor_to_voxelidx); convert_reduce_type(reduce_type));
return;
#else #else
AT_ERROR("Not compiled with GPU support"); TORCH_CHECK(false, "Not compiled with GPU support");
#endif #endif
} }
// return dynamic_point_to_voxel_cpu(points, TORCH_CHECK(false, "do not support cpu yet");
// voxel_mapping,
// voxel_size,
// coors_range);
} }
} // namespace voxelization } // namespace voxelization
import pytest
import torch
from torch.autograd import gradcheck
from mmdet3d.ops import DynamicScatter
def test_dynamic_scatter():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
feats = torch.rand(
size=(200000, 3), dtype=torch.float32, device='cuda') * 100 - 50
coors = torch.randint(
low=-1, high=20, size=(200000, 3), dtype=torch.int32, device='cuda')
coors[coors.min(dim=-1).values < 0] = -1
dsmean = DynamicScatter([0.32, 0.32, 6],
[-74.88, -74.88, -2, 74.88, 74.88, 4], True)
dsmax = DynamicScatter([0.32, 0.32, 6],
[-74.88, -74.88, -2, 74.88, 74.88, 4], False)
ref_voxel_coors = coors.unique(dim=0, sorted=True)
ref_voxel_coors = ref_voxel_coors[ref_voxel_coors.min(dim=-1).values >= 0]
ref_voxel_feats_mean = []
ref_voxel_feats_max = []
for ref_voxel_coor in ref_voxel_coors:
voxel_mask = (coors == ref_voxel_coor).all(dim=-1)
ref_voxel_feats_mean.append(feats[voxel_mask].mean(dim=0))
ref_voxel_feats_max.append(feats[voxel_mask].max(dim=0).values)
ref_voxel_feats_mean = torch.stack(ref_voxel_feats_mean)
ref_voxel_feats_max = torch.stack(ref_voxel_feats_max)
feats_out_mean, coors_out_mean = dsmean(feats, coors)
seq_mean = (coors_out_mean[:, 0] * 400 + coors_out_mean[:, 1] * 20 +
coors_out_mean[:, 2]).argsort()
feats_out_mean = feats_out_mean[seq_mean]
coors_out_mean = coors_out_mean[seq_mean]
feats_out_max, coors_out_max = dsmax(feats, coors)
seq_max = (coors_out_max[:, 0] * 400 + coors_out_max[:, 1] * 20 +
coors_out_max[:, 2]).argsort()
feats_out_max = feats_out_max[seq_max]
coors_cout_max = coors_out_max[seq_max]
assert (coors_out_mean == ref_voxel_coors).all()
assert torch.allclose(
feats_out_mean, ref_voxel_feats_mean, atol=1e-2, rtol=1e-5)
assert (coors_cout_max == ref_voxel_coors).all()
assert torch.allclose(
feats_out_max, ref_voxel_feats_max, atol=1e-2, rtol=1e-5)
# test grad #
feats = torch.rand(
size=(100, 4), dtype=torch.float32, device='cuda') * 100 - 50
coors = torch.randint(
low=-1, high=3, size=(100, 3), dtype=torch.int32, device='cuda')
feats.requires_grad_()
gradcheck(dsmean, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5)
gradcheck(dsmax, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment