Commit b9d4e636 authored by wangchao1's avatar wangchao1
Browse files

优化voexliztion算子

parent a7eb7dbe
...@@ -53,45 +53,44 @@ __global__ void dynamic_voxelize_kernel( ...@@ -53,45 +53,44 @@ __global__ void dynamic_voxelize_kernel(
template <typename T, typename T_int> template <typename T, typename T_int>
__global__ void __launch_bounds__(1024) dynamic_voxelize_kernel_fast( __global__ void __launch_bounds__(1024) dynamic_voxelize_kernel_fast(
const T* points, T_int* coors,int64_t* coors64, const float voxel_x, const float voxel_y, const T* points, T_int* coors,T_int* coors64, const float voxel_x, const float voxel_y,
const float voxel_z, const float coors_x_min, const float coors_y_min, const float voxel_z, const float coors_x_min, const float coors_y_min,
const float coors_z_min, const float coors_x_max, const float coors_y_max, const float coors_z_min, const float coors_x_max, const float coors_y_max,
const float coors_z_max, const int grid_x, const int grid_y, const float coors_z_max, const int grid_x, const int grid_y,
const int grid_z, const int num_points, const int num_features, const int grid_z, const int num_points, const int num_features,
const int NDim) { const int NDim) {
CUDA_1D_KERNEL_LOOP(index, num_points) { CUDA_1D_KERNEL_LOOP(index, num_points) {
// To save some computation auto points_offset = points + index * num_features;
auto points_offset = points + index * num_features; auto coors_offset = coors + index * NDim;
auto coors_offset = coors + index * NDim; int c_x = floor((points_offset[0] - coors_x_min) / voxel_x);
int c_x = floorf((points_offset[0] - coors_x_min) / voxel_x); if (c_x < 0 || c_x >= grid_x) {
if (c_x < 0 || c_x >= grid_x) { coors_offset[0] = -1;
coors_offset[0] = -1; coors64[index] = -1;
coors64[index] = -1; return;
continue; }
}
int c_y = floorf((points_offset[1] - coors_y_min) / voxel_y); int c_y = floor((points_offset[1] - coors_y_min) / voxel_y);
if (c_y < 0 || c_y >= grid_y) { if (c_y < 0 || c_y >= grid_y) {
coors_offset[0] = -1; coors_offset[0] = -1;
coors_offset[1] = -1; coors_offset[1] = -1;
coors64[index] = -1; coors64[index] = -1;
continue; return;
} }
int c_z = floorf((points_offset[2] - coors_z_min) / voxel_z); int c_z = floor((points_offset[2] - coors_z_min) / voxel_z);
if (c_z < 0 || c_z >= grid_z) { if (c_z < 0 || c_z >= grid_z) {
coors_offset[0] = -1; coors_offset[0] = -1;
coors_offset[1] = -1; coors_offset[1] = -1;
coors_offset[2] = -1; coors_offset[2] = -1;
coors64[index] = -1; coors64[index] = -1;
} else { } else {
coors_offset[0] = c_z; coors_offset[0] = c_x;
coors_offset[1] = c_y; coors_offset[1] = c_y;
coors_offset[2] = c_x; coors_offset[2] = c_z;
coors64[index] = ((int64_t)c_x)*(grid_z*grid_y) +((int64_t)c_y)*(grid_z)+c_z; coors64[index] = (c_x)*(grid_z*grid_y) +(c_y)*(grid_z)+c_z;
}
} }
} }
}
...@@ -186,13 +185,13 @@ __global__ void point_to_voxelidx_kernel(const T_int* coor, ...@@ -186,13 +185,13 @@ __global__ void point_to_voxelidx_kernel(const T_int* coor,
template <typename T_int,int VEC> template <typename T_int,int VEC>
__global__ void __launch_bounds__(1024) point_to_voxelidx_kernel_fast(const int64_t* coor, __global__ void __launch_bounds__(1024) point_to_voxelidx_kernel_fast(const int* coor,
T_int* point_to_voxelidx, T_int* point_to_voxelidx,
T_int* point_to_pointidx, T_int* point_to_pointidx,
const int max_points, const int max_points,
const int max_voxels, const int max_voxels,
const int num_points) { const int num_points) {
using Int64VEC = __attribute__( (__vector_size__(VEC * sizeof(int64_t)) )) int64_t; using Int64VEC = __attribute__( (__vector_size__(VEC * sizeof(int)) )) int;
int tid=threadIdx.x; int tid=threadIdx.x;
int index=(blockIdx.x * blockDim.x + tid)*VEC; int index=(blockIdx.x * blockDim.x + tid)*VEC;
auto i_coor = *reinterpret_cast<const Int64VEC*>(coor+index); auto i_coor = *reinterpret_cast<const Int64VEC*>(coor+index);
......
...@@ -37,7 +37,7 @@ int HardVoxelizeForwardCUDAKernelLauncher( ...@@ -37,7 +37,7 @@ int HardVoxelizeForwardCUDAKernelLauncher(
at::Tensor temp_coors = at::Tensor temp_coors =
at::zeros({num_points, NDim}, points.options().dtype(at::kInt)); at::zeros({num_points, NDim}, points.options().dtype(at::kInt));
at::Tensor temp_coors64 = at::Tensor temp_coors64 =
at::zeros({num_points}, points.options().dtype(at::kLong)); at::zeros({num_points}, points.options().dtype(at::kInt));
dim3 grid(std::min(at::cuda::ATenCeilDiv(num_points, 512), 4096)); dim3 grid(std::min(at::cuda::ATenCeilDiv(num_points, 512), 4096));
dim3 block(512); dim3 block(512);
...@@ -48,7 +48,7 @@ int HardVoxelizeForwardCUDAKernelLauncher( ...@@ -48,7 +48,7 @@ int HardVoxelizeForwardCUDAKernelLauncher(
<<<grid, block, 0, stream>>>( <<<grid, block, 0, stream>>>(
points.contiguous().data_ptr<scalar_t>(), points.contiguous().data_ptr<scalar_t>(),
temp_coors.contiguous().data_ptr<int>(), temp_coors.contiguous().data_ptr<int>(),
temp_coors64.data_ptr<int64_t>(), voxel_x, voxel_y, temp_coors64.data_ptr<int>(), voxel_x, voxel_y,
voxel_z, coors_x_min, coors_y_min, coors_z_min, coors_x_max, voxel_z, coors_x_min, coors_y_min, coors_z_min, coors_x_max,
coors_y_max, coors_z_max, grid_x, grid_y, grid_z, num_points, coors_y_max, coors_z_max, grid_x, grid_y, grid_z, num_points,
num_features, NDim); num_features, NDim);
...@@ -70,14 +70,14 @@ int HardVoxelizeForwardCUDAKernelLauncher( ...@@ -70,14 +70,14 @@ int HardVoxelizeForwardCUDAKernelLauncher(
points.options().dtype(at::kInt)); points.options().dtype(at::kInt));
int blocksize=256; int blocksize=256;
constexpr int VEC=2; constexpr int VEC=4;
dim3 map_grid(at::cuda::ATenCeilDiv(num_points, blocksize*VEC)); dim3 map_grid(at::cuda::ATenCeilDiv(num_points, blocksize*VEC));
dim3 map_block(blocksize); dim3 map_block(blocksize);
AT_DISPATCH_ALL_TYPES( AT_DISPATCH_ALL_TYPES(
temp_coors.scalar_type(), "determin_duplicate", ([&] { temp_coors.scalar_type(), "determin_duplicate", ([&] {
point_to_voxelidx_kernel_fast<int,VEC> point_to_voxelidx_kernel_fast<int,VEC>
<<<map_grid, map_block, 0, stream>>>( <<<map_grid, map_block, 0, stream>>>(
temp_coors64.contiguous().data_ptr<int64_t>(), temp_coors64.contiguous().data_ptr<int>(),
point_to_voxelidx.contiguous().data_ptr<int>(), point_to_voxelidx.contiguous().data_ptr<int>(),
point_to_pointidx.contiguous().data_ptr<int>(), max_points, point_to_pointidx.contiguous().data_ptr<int>(), max_points,
max_voxels, num_points); max_voxels, num_points);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment