Unverified Commit 234bf975 authored by zhanggefan's avatar zhanggefan Committed by GitHub
Browse files

[Enhancement] faster but nondeterministic version of hard voxelization (#904)

parent 21e4b12c
...@@ -29,6 +29,13 @@ int hard_voxelize_gpu(const at::Tensor &points, at::Tensor &voxels, ...@@ -29,6 +29,13 @@ int hard_voxelize_gpu(const at::Tensor &points, at::Tensor &voxels,
const int max_points, const int max_voxels, const int max_points, const int max_voxels,
const int NDim = 3); const int NDim = 3);
// GPU hard voxelization, non-deterministic variant. hard_voxelize() calls this
// instead of hard_voxelize_gpu() when its `deterministic` flag is false.
//
// points               : input point tensor (read only).
// voxels               : output buffer receiving the points grouped per voxel.
// coors                : output buffer receiving the coordinates of each voxel.
// num_points_per_voxel : output buffer receiving the point count of each voxel.
// voxel_size           : voxel edge lengths; entries 0..2 are used.
// coors_range          : range as (min0, min1, min2, max0, max1, max2); entries
//                        0..5 are used.
// max_points           : maximum number of points stored per voxel.
// max_voxels           : maximum number of voxels written to the outputs.
// NDim                 : number of coordinate dimensions per voxel.
//
// Returns the number of voxels found, capped at max_voxels.
int nondisterministic_hard_voxelize_gpu(const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors, at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim = 3);
void dynamic_voxelize_gpu(const at::Tensor &points, at::Tensor &coors, void dynamic_voxelize_gpu(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const std::vector<float> coors_range,
...@@ -53,12 +60,17 @@ inline int hard_voxelize(const at::Tensor &points, at::Tensor &voxels, ...@@ -53,12 +60,17 @@ inline int hard_voxelize(const at::Tensor &points, at::Tensor &voxels,
const std::vector<float> voxel_size, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const std::vector<float> coors_range,
const int max_points, const int max_voxels, const int max_points, const int max_voxels,
const int NDim = 3) { const int NDim = 3, const bool deterministic = true) {
if (points.device().is_cuda()) { if (points.device().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
return hard_voxelize_gpu(points, voxels, coors, num_points_per_voxel, if (deterministic) {
voxel_size, coors_range, max_points, max_voxels, return hard_voxelize_gpu(points, voxels, coors, num_points_per_voxel,
NDim); voxel_size, coors_range, max_points, max_voxels,
NDim);
}
return nondisterministic_hard_voxelize_gpu(points, voxels, coors, num_points_per_voxel,
voxel_size, coors_range, max_points, max_voxels,
NDim);
#else #else
AT_ERROR("Not compiled with GPU support"); AT_ERROR("Not compiled with GPU support");
#endif #endif
......
...@@ -179,6 +179,53 @@ __global__ void determin_voxel_num( ...@@ -179,6 +179,53 @@ __global__ void determin_voxel_num(
} }
} }
// For every point, reserves a slot inside its voxel and, for the first point
// that reaches a voxel, reserves an output row for that voxel.
//
//   nthreads     : number of points to process.
//   coors_map    : per point, index of its unique voxel coordinate; negative
//                  values mark points that are skipped.
//   pts_id       : per point (out), slot index of the point inside its voxel.
//   coors_count  : single counter (in/out) handing out voxel output rows.
//   reduce_count : per unique voxel (in/out), running number of points seen.
//   coors_order  : per unique voxel (out), output row assigned to the voxel.
//
// Both counters are advanced with atomicAdd, so the slot and row each
// point/voxel receives depends on thread scheduling.
__global__ void nondisterministic_get_assign_pos(
    const int nthreads, const int32_t *coors_map, int32_t *pts_id,
    int32_t *coors_count, int32_t *reduce_count, int32_t *coors_order) {
  CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) {
    const int voxel_key = coors_map[thread_idx];
    if (voxel_key >= 0) {
      // Claim the next free slot of this voxel.
      const int32_t slot = atomicAdd(reduce_count + voxel_key, 1);
      pts_id[thread_idx] = slot;
      const bool first_point_of_voxel = (slot == 0);
      if (first_point_of_voxel) {
        // Exactly one point per voxel gets slot 0; it claims the voxel's row.
        coors_order[voxel_key] = atomicAdd(coors_count, 1);
      }
    }
  }
}
// Copies every accepted point into its voxel slot and, once per voxel, writes
// the voxel's point count and coordinates.
//
//   nthreads     : number of points to process.
//   points       : point features, num_features values per point.
//   coors_map    : per point, index of its unique voxel coordinate; negative
//                  values mark points that are skipped.
//   pts_id       : per point, slot index of the point inside its voxel.
//   coors_in     : unique voxel coordinates, NDim values per voxel.
//   reduce_count : per unique voxel, total number of points mapped to it.
//   coors_order  : per unique voxel, output row assigned to the voxel.
//   voxels       : (out) indexed as [row][slot][feature] with strides
//                  max_points * num_features and num_features.
//   coors        : (out) NDim coordinate values per output row.
//   pts_count    : (out) number of stored points per output row.
//
// Points whose row is >= max_voxels or whose slot is >= max_points are dropped.
template <typename T>
__global__ void nondisterministic_assign_point_voxel(
    const int nthreads, const T *points, const int32_t *coors_map,
    const int32_t *pts_id, const int32_t *coors_in,
    const int32_t *reduce_count, const int32_t *coors_order,
    T *voxels, int32_t *coors, int32_t *pts_count, const int max_voxels,
    const int max_points, const int num_features, const int NDim) {
  CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) {
    const int voxel_key = coors_map[thread_idx];
    const int slot = pts_id[thread_idx];
    if (voxel_key >= 0) {
      const int row = coors_order[voxel_key];
      const bool row_fits = row < max_voxels;
      const bool slot_fits = slot < max_points;
      if (row_fits && slot_fits) {
        // Copy the feature vector of this point into its slot.
        T *dst = voxels + (row * max_points + slot) * num_features;
        const T *src = points + thread_idx * num_features;
        for (int f = 0; f < num_features; f++) {
          dst[f] = src[f];
        }
        if (slot == 0) {
          // The point holding slot 0 also publishes the per-voxel metadata.
          pts_count[row] = min(reduce_count[voxel_key], max_points);
          int32_t *coor_dst = coors + row * NDim;
          const int32_t *coor_src = coors_in + voxel_key * NDim;
          for (int d = 0; d < NDim; d++) {
            coor_dst[d] = coor_src[d];
          }
        }
      }
    }
  }
}
namespace voxelization { namespace voxelization {
int hard_voxelize_gpu(const at::Tensor& points, at::Tensor& voxels, int hard_voxelize_gpu(const at::Tensor& points, at::Tensor& voxels,
...@@ -325,6 +372,116 @@ int hard_voxelize_gpu(const at::Tensor& points, at::Tensor& voxels, ...@@ -325,6 +372,116 @@ int hard_voxelize_gpu(const at::Tensor& points, at::Tensor& voxels,
return voxel_num_int; return voxel_num_int;
} }
// Hard voxelization on the GPU whose voxel order and per-voxel point order are
// assigned with atomics and therefore may differ between runs.
//
// Steps:
//   1. dynamic_voxelize_kernel computes a voxel coordinate for every point.
//   2. Rows containing a negative coordinate are collapsed to (-1, ..., -1),
//      the coordinates are de-duplicated with at::unique_dim, and the invalid
//      entry (if present) is dropped.
//   3. nondisterministic_get_assign_pos hands out voxel rows and point slots.
//   4. nondisterministic_assign_point_voxel fills the output buffers.
//
// Args:
//   points               : input points, shape (num_points, num_features).
//   voxels               : (out) contiguous buffer for the grouped points.
//   coors                : (out) contiguous int32 buffer for voxel coordinates.
//   num_points_per_voxel : (out) contiguous int32 buffer for per-voxel counts.
//   voxel_size           : voxel edge lengths, at least 3 entries.
//   coors_range          : (min0, min1, min2, max0, max1, max2), at least 6
//                          entries.
//   max_points           : maximum number of points stored per voxel.
//   max_voxels           : maximum number of voxels written to the outputs.
//   NDim                 : number of coordinate dimensions per voxel.
//
// Returns:
//   min(max_voxels, number of distinct valid voxels); 0 for empty input.
//
// Raises (c10::Error):
//   if voxel_size / coors_range are too short, if an output buffer is not
//   contiguous, or if a kernel launch fails.
int nondisterministic_hard_voxelize_gpu(
    const at::Tensor &points, at::Tensor &voxels,
    at::Tensor &coors, at::Tensor &num_points_per_voxel,
    const std::vector<float> voxel_size,
    const std::vector<float> coors_range,
    const int max_points, const int max_voxels,
    const int NDim = 3) {
  CHECK_INPUT(points);

  at::cuda::CUDAGuard device_guard(points.device());

  const int num_points = points.size(0);
  const int num_features = points.size(1);

  if (num_points == 0)
    return 0;

  // Indexing below reads voxel_size[0..2] and coors_range[0..5].
  TORCH_CHECK(voxel_size.size() >= 3,
              "voxel_size must contain at least 3 values, got ",
              voxel_size.size());
  TORCH_CHECK(coors_range.size() >= 6,
              "coors_range must contain at least 6 values, got ",
              coors_range.size());
  // The kernels write through raw pointers. Calling .contiguous() on a
  // non-contiguous output would make them write into a temporary copy and the
  // results would be silently lost, so reject that case explicitly.
  TORCH_CHECK(voxels.is_contiguous(), "voxels must be contiguous");
  TORCH_CHECK(coors.is_contiguous(), "coors must be contiguous");
  TORCH_CHECK(num_points_per_voxel.is_contiguous(),
              "num_points_per_voxel must be contiguous");

  const float voxel_x = voxel_size[0];
  const float voxel_y = voxel_size[1];
  const float voxel_z = voxel_size[2];
  const float coors_x_min = coors_range[0];
  const float coors_y_min = coors_range[1];
  const float coors_z_min = coors_range[2];
  const float coors_x_max = coors_range[3];
  const float coors_y_max = coors_range[4];
  const float coors_z_max = coors_range[5];

  const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
  const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
  const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);

  // All three kernels iterate over the points with the same launch shape.
  const dim3 grid(std::min(at::cuda::ATenCeilDiv(num_points, 512), 4096));
  const dim3 block(512);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const at::Tensor points_c = points.contiguous();

  // map points to voxel coors
  at::Tensor temp_coors =
      at::zeros({num_points, NDim}, points.options().dtype(torch::kInt32));

  // 1. link point to corresponding voxel coors
  AT_DISPATCH_ALL_TYPES(
      points.scalar_type(), "hard_voxelize_kernel", ([&] {
        dynamic_voxelize_kernel<scalar_t, int>
            <<<grid, block, 0, stream>>>(
                points_c.data_ptr<scalar_t>(),
                temp_coors.data_ptr<int>(), voxel_x, voxel_y,
                voxel_z, coors_x_min, coors_y_min, coors_z_min, coors_x_max,
                coors_y_max, coors_z_max, grid_x, grid_y, grid_z, num_points,
                num_features, NDim);
      }));
  // Report a failed launch here rather than inside the ATen ops that follow.
  AT_CUDA_CHECK(cudaGetLastError());

  // 2. de-duplicate the voxel coordinates; any row with a negative component
  //    is treated as invalid and collapsed to (-1, ..., -1) first.
  at::Tensor coors_map;
  at::Tensor reduce_count;
  auto coors_clean = temp_coors.masked_fill(temp_coors.lt(0).any(-1, true), -1);
  std::tie(temp_coors, coors_map, reduce_count) =
      at::unique_dim(coors_clean, 0, true, true, false);

  if (temp_coors.index({0, 0}).lt(0).item<bool>()) {
    // the first element of temp_coors is (-1,-1,-1) and should be removed;
    // points mapped to it end up with index -1 and are skipped by the kernels
    temp_coors = temp_coors.slice(0, 1);
    coors_map = coors_map - 1;
  }

  const int num_coors = temp_coors.size(0);
  temp_coors = temp_coors.to(torch::kInt32).contiguous();
  coors_map = coors_map.to(torch::kInt32).contiguous();

  at::Tensor coors_count = coors_map.new_zeros(1);
  at::Tensor coors_order = coors_map.new_empty(num_coors);
  reduce_count = coors_map.new_zeros(num_coors);
  at::Tensor pts_id = coors_map.new_zeros(num_points);

  // 3. assign a voxel row to every voxel and a slot to every point.
  //    This kernel is not templated, so no dtype dispatch is required.
  nondisterministic_get_assign_pos<<<grid, block, 0, stream>>>(
      num_points,
      coors_map.data_ptr<int32_t>(),
      pts_id.data_ptr<int32_t>(),
      coors_count.data_ptr<int32_t>(),
      reduce_count.data_ptr<int32_t>(),
      coors_order.data_ptr<int32_t>());
  AT_CUDA_CHECK(cudaGetLastError());

  // 4. scatter points, coordinates and counts into the output buffers
  AT_DISPATCH_ALL_TYPES(
      points.scalar_type(), "assign_point_to_voxel", ([&] {
        nondisterministic_assign_point_voxel<scalar_t>
            <<<grid, block, 0, stream>>>(
                num_points, points_c.data_ptr<scalar_t>(),
                coors_map.data_ptr<int32_t>(),
                pts_id.data_ptr<int32_t>(),
                temp_coors.data_ptr<int32_t>(),
                reduce_count.data_ptr<int32_t>(),
                coors_order.data_ptr<int32_t>(),
                voxels.data_ptr<scalar_t>(),
                coors.data_ptr<int32_t>(),
                num_points_per_voxel.data_ptr<int32_t>(),
                max_voxels, max_points,
                num_features, NDim);
      }));
  AT_CUDA_CHECK(cudaGetLastError());

  return std::min(max_voxels, num_coors);
}
void dynamic_voxelize_gpu(const at::Tensor& points, at::Tensor& coors, void dynamic_voxelize_gpu(const at::Tensor& points, at::Tensor& coors,
const std::vector<float> voxel_size, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const std::vector<float> coors_range,
......
...@@ -15,7 +15,8 @@ class _Voxelization(Function): ...@@ -15,7 +15,8 @@ class _Voxelization(Function):
voxel_size, voxel_size,
coors_range, coors_range,
max_points=35, max_points=35,
max_voxels=20000): max_voxels=20000,
deterministic=True):
"""convert kitti points(N, >=3) to voxels. """convert kitti points(N, >=3) to voxels.
Args: Args:
...@@ -30,6 +31,16 @@ class _Voxelization(Function): ...@@ -30,6 +31,16 @@ class _Voxelization(Function):
max_voxels: int. indicate maximum voxels this function create. max_voxels: int. indicate maximum voxels this function create.
for second, 20000 is a good choice. Users should shuffle points for second, 20000 is a good choice. Users should shuffle points
before call this function because max_voxels may drop points. before call this function because max_voxels may drop points.
deterministic: bool. whether to invoke the deterministic
version of hard-voxelization implementations. setting it to False
selects the non-deterministic version, which is considerably
faster but is not deterministic. only affects hard voxelization.
default True. for more information
of this argument and the implementation insights, please refer
to the following links:
https://github.com/open-mmlab/mmdetection3d/issues/894
https://github.com/open-mmlab/mmdetection3d/pull/904
the non-deterministic version is an experimental feature and we
will appreciate it if you could share with us the failing cases.
Returns: Returns:
voxels: [M, max_points, ndim] float tensor. only contain points voxels: [M, max_points, ndim] float tensor. only contain points
...@@ -50,7 +61,8 @@ class _Voxelization(Function): ...@@ -50,7 +61,8 @@ class _Voxelization(Function):
size=(max_voxels, ), dtype=torch.int) size=(max_voxels, ), dtype=torch.int)
voxel_num = hard_voxelize(points, voxels, coors, voxel_num = hard_voxelize(points, voxels, coors,
num_points_per_voxel, voxel_size, num_points_per_voxel, voxel_size,
coors_range, max_points, max_voxels, 3) coors_range, max_points, max_voxels, 3,
deterministic)
# select the valid voxels # select the valid voxels
voxels_out = voxels[:voxel_num] voxels_out = voxels[:voxel_num]
coors_out = coors[:voxel_num] coors_out = coors[:voxel_num]
...@@ -67,7 +79,8 @@ class Voxelization(nn.Module): ...@@ -67,7 +79,8 @@ class Voxelization(nn.Module):
voxel_size, voxel_size,
point_cloud_range, point_cloud_range,
max_num_points, max_num_points,
max_voxels=20000): max_voxels=20000,
deterministic=True):
super(Voxelization, self).__init__() super(Voxelization, self).__init__()
""" """
Args: Args:
...@@ -77,6 +90,16 @@ class Voxelization(nn.Module): ...@@ -77,6 +90,16 @@ class Voxelization(nn.Module):
max_num_points (int): max number of points per voxel max_num_points (int): max number of points per voxel
max_voxels (tuple or int): max number of voxels in max_voxels (tuple or int): max number of voxels in
(training, testing) time (training, testing) time
deterministic: bool. whether to invoke the deterministic
version of hard-voxelization implementations. setting it to False
selects the non-deterministic version, which is considerably
faster but is not deterministic. only affects hard voxelization.
default True. for more information
of this argument and the implementation insights, please refer
to the following links:
https://github.com/open-mmlab/mmdetection3d/issues/894
https://github.com/open-mmlab/mmdetection3d/pull/904
the non-deterministic version is an experimental feature and we
will appreciate it if you could share with us the failing cases.
""" """
self.voxel_size = voxel_size self.voxel_size = voxel_size
self.point_cloud_range = point_cloud_range self.point_cloud_range = point_cloud_range
...@@ -85,6 +108,7 @@ class Voxelization(nn.Module): ...@@ -85,6 +108,7 @@ class Voxelization(nn.Module):
self.max_voxels = max_voxels self.max_voxels = max_voxels
else: else:
self.max_voxels = _pair(max_voxels) self.max_voxels = _pair(max_voxels)
self.deterministic = deterministic
point_cloud_range = torch.tensor( point_cloud_range = torch.tensor(
point_cloud_range, dtype=torch.float32) point_cloud_range, dtype=torch.float32)
...@@ -110,7 +134,8 @@ class Voxelization(nn.Module): ...@@ -110,7 +134,8 @@ class Voxelization(nn.Module):
max_voxels = self.max_voxels[1] max_voxels = self.max_voxels[1]
return voxelization(input, self.voxel_size, self.point_cloud_range, return voxelization(input, self.voxel_size, self.point_cloud_range,
self.max_num_points, max_voxels) self.max_num_points, max_voxels,
self.deterministic)
def __repr__(self): def __repr__(self):
tmpstr = self.__class__.__name__ + '(' tmpstr = self.__class__.__name__ + '('
...@@ -118,5 +143,6 @@ class Voxelization(nn.Module): ...@@ -118,5 +143,6 @@ class Voxelization(nn.Module):
tmpstr += ', point_cloud_range=' + str(self.point_cloud_range) tmpstr += ', point_cloud_range=' + str(self.point_cloud_range)
tmpstr += ', max_num_points=' + str(self.max_num_points) tmpstr += ', max_num_points=' + str(self.max_num_points)
tmpstr += ', max_voxels=' + str(self.max_voxels) tmpstr += ', max_voxels=' + str(self.max_voxels)
tmpstr += ', deterministic=' + str(self.deterministic)
tmpstr += ')' tmpstr += ')'
return tmpstr return tmpstr
...@@ -82,3 +82,84 @@ def test_voxelization(): ...@@ -82,3 +82,84 @@ def test_voxelization():
assert np.all( assert np.all(
points[indices] == expected_coors[i][:num_points_current_voxel]) points[indices] == expected_coors[i][:num_points_current_voxel])
assert num_points_current_voxel == expected_num_points_per_voxel[i] assert num_points_current_voxel == expected_num_points_per_voxel[i]
def test_voxelization_nondeterministic():
    """Check the non-deterministic GPU hard voxelization against dynamic
    voxelization.

    Because voxel order and per-voxel point order are not fixed, the test only
    checks order-independent properties: voxel coordinates are unique and
    belong to the dynamically computed coordinates, stored points belong to
    their voxel, padding rows are zero, and - when every point is in range and
    there are no more points than ``max_voxels`` - every voxel is returned.
    """
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')

    voxel_size = [0.5, 0.5, 0.5]
    point_cloud_range = [0, -40, -3, 70.4, 40, 1]
    data_path = './tests/data/kitti/training/velodyne_reduced/000000.bin'
    load_points_from_file = LoadPointsFromFile(
        coord_type='LIDAR', load_dim=4, use_dim=4)
    results = dict()
    results['pts_filename'] = data_path
    results = load_points_from_file(results)
    points = torch.tensor(results['points'].tensor.numpy())

    # max_num_points == -1 is the setting used for the dynamic reference
    max_num_points = -1
    dynamic_voxelization = Voxelization(voxel_size, point_cloud_range,
                                        max_num_points)

    max_num_points = 10
    max_voxels = 50
    hard_voxelization = Voxelization(
        voxel_size,
        point_cloud_range,
        max_num_points,
        max_voxels,
        deterministic=False)

    # test hard_voxelization (non-deterministic version) on gpu
    # `points` is already a tensor here; wrapping it in torch.tensor() again
    # would only copy it and trigger PyTorch's copy-construct warning.
    points = points.contiguous().to(device='cuda:0')
    voxels, coors, num_points_per_voxel = hard_voxelization.forward(points)
    coors = coors.cpu().detach().numpy().tolist()
    voxels = voxels.cpu().detach().numpy().tolist()
    num_points_per_voxel = num_points_per_voxel.cpu().detach().numpy().tolist()

    coors_all = dynamic_voxelization.forward(points)
    coors_all = coors_all.cpu().detach().numpy().tolist()

    coors_set = {tuple(c) for c in coors}
    coors_all_set = {tuple(c) for c in coors_all}

    # every returned voxel is unique and exists in the dynamic result
    assert len(coors_set) == len(coors)
    assert len(coors_set - coors_all_set) == 0

    points = points.cpu().detach().numpy().tolist()

    # reference: voxel coordinate -> set of points that fall into it
    coors_points_dict = {}
    for c, ps in zip(coors_all, points):
        coors_points_dict.setdefault(tuple(c), set()).add(tuple(ps))

    for c, ps, n in zip(coors, voxels, num_points_per_voxel):
        ideal_voxel_points_set = coors_points_dict[tuple(c)]
        voxel_points_set = {tuple(p) for p in ps[:n]}
        assert len(voxel_points_set) == n
        if n < max_num_points:
            # voxel not full: it must hold exactly its points, zero padded
            assert voxel_points_set == ideal_voxel_points_set
            for p in ps[n:]:
                assert max(p) == min(p) == 0
        else:
            # voxel full: any subset of its points is acceptable
            assert len(voxel_points_set - ideal_voxel_points_set) == 0

    # test hard_voxelization (non-deterministic version) on gpu
    # with all input point in range
    points = torch.tensor(points).contiguous().to(device='cuda:0')[:max_voxels]
    coors_all = dynamic_voxelization.forward(points)
    valid_mask = coors_all.ge(0).all(-1)
    points = points[valid_mask]
    coors_all = coors_all[valid_mask]
    coors_all = coors_all.cpu().detach().numpy().tolist()

    voxels, coors, num_points_per_voxel = hard_voxelization.forward(points)
    coors = coors.cpu().detach().numpy().tolist()

    coors_set = {tuple(c) for c in coors}
    coors_all_set = {tuple(c) for c in coors_all}

    assert len(coors_set) == len(coors) == len(coors_all_set)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment