Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
a7eb7dbe
Commit
a7eb7dbe
authored
Aug 27, 2025
by
wangchao1
Browse files
update voxelization
parent
7374ed8c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
188 additions
and
39 deletions
+188
-39
mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh
mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh
+135
-2
mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu
mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu
+53
-37
No files found.
mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh
View file @
a7eb7dbe
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
typedef
enum
{
SUM
=
0
,
MEAN
=
1
,
MAX
=
2
}
reduce_t
;
typedef
enum
{
SUM
=
0
,
MEAN
=
1
,
MAX
=
2
}
reduce_t
;
template
<
typename
T
,
typename
T_int
>
template
<
typename
T
,
typename
T_int
>
__global__
void
dynamic_voxelize_kernel
(
__global__
void
dynamic_voxelize_kernel
(
const
T
*
points
,
T_int
*
coors
,
const
float
voxel_x
,
const
float
voxel_y
,
const
T
*
points
,
T_int
*
coors
,
const
float
voxel_x
,
const
float
voxel_y
,
...
@@ -49,8 +50,55 @@ __global__ void dynamic_voxelize_kernel(
...
@@ -49,8 +50,55 @@ __global__ void dynamic_voxelize_kernel(
}
}
}
}
/**
 * Maps every point to its integer voxel coordinate and, in addition, to a
 * single 64-bit linear key that identifies the voxel.
 *
 * For each point `index` in [0, num_points):
 *   - the first three features of the point are read as x, y, z;
 *   - each axis is quantised with floorf((p - coors_*_min) / voxel_*);
 *   - if any axis falls outside [0, grid_*), the point is marked invalid:
 *     coors64[index] = -1 and the leading component(s) of its coors row are
 *     set to -1 (components that were not reached are left untouched);
 *   - otherwise the coors row is written in (z, y, x) order and
 *     coors64[index] = c_x * (grid_z * grid_y) + c_y * grid_z + c_z.
 *
 * @param points        input points, `num_features` values per point
 * @param coors         output, `NDim` values per point
 * @param coors64       output, one linear voxel key (or -1) per point
 * @param voxel_x/y/z   voxel size along each axis
 * @param coors_*_min   lower bound of the range along each axis
 * @param coors_*_max   accepted for signature parity; not read by this kernel
 * @param grid_x/y/z    number of voxels along each axis
 * @param num_points    number of points to process
 * @param num_features  stride (in elements) between consecutive points
 * @param NDim          stride (in elements) between consecutive coors rows
 */
template <typename T, typename T_int>
__global__ void __launch_bounds__(1024) dynamic_voxelize_kernel_fast(
    const T* points, T_int* coors, int64_t* coors64, const float voxel_x,
    const float voxel_y, const float voxel_z, const float coors_x_min,
    const float coors_y_min, const float coors_z_min, const float coors_x_max,
    const float coors_y_max, const float coors_z_max, const int grid_x,
    const int grid_y, const int grid_z, const int num_points,
    const int num_features, const int NDim) {
  CUDA_1D_KERNEL_LOOP(index, num_points) {
    // To save some computation
    auto points_offset = points + index * num_features;
    auto coors_offset = coors + index * NDim;

    int c_x = floorf((points_offset[0] - coors_x_min) / voxel_x);
    if (c_x < 0 || c_x >= grid_x) {
      coors_offset[0] = -1;
      coors64[index] = -1;
      continue;
    }

    int c_y = floorf((points_offset[1] - coors_y_min) / voxel_y);
    if (c_y < 0 || c_y >= grid_y) {
      coors_offset[0] = -1;
      coors_offset[1] = -1;
      coors64[index] = -1;
      continue;
    }

    int c_z = floorf((points_offset[2] - coors_z_min) / voxel_z);
    if (c_z < 0 || c_z >= grid_z) {
      coors_offset[0] = -1;
      coors_offset[1] = -1;
      coors_offset[2] = -1;
      coors64[index] = -1;
    } else {
      coors_offset[0] = c_z;
      coors_offset[1] = c_y;
      coors_offset[2] = c_x;
      // Widen BEFORE multiplying: grid_z * grid_y evaluated in 32-bit int can
      // overflow for large grids, which would corrupt the 64-bit key.
      const int64_t plane = static_cast<int64_t>(grid_z) * grid_y;
      coors64[index] = static_cast<int64_t>(c_x) * plane +
                       static_cast<int64_t>(c_y) * grid_z + c_z;
    }
  }
}
template
<
typename
T
,
typename
T_int
>
template
<
typename
T
,
typename
T_int
>
__global__
void
assign_point_to_voxel
(
const
int
nthreads
,
const
T
*
points
,
__global__
void
__launch_bounds__
(
1024
)
assign_point_to_voxel
(
const
int
nthreads
,
const
T
*
points
,
T_int
*
point_to_voxelidx
,
T_int
*
point_to_voxelidx
,
T_int
*
coor_to_voxelidx
,
T
*
voxels
,
T_int
*
coor_to_voxelidx
,
T
*
voxels
,
const
int
max_points
,
const
int
max_points
,
...
@@ -73,7 +121,7 @@ __global__ void assign_point_to_voxel(const int nthreads, const T* points,
...
@@ -73,7 +121,7 @@ __global__ void assign_point_to_voxel(const int nthreads, const T* points,
}
}
template
<
typename
T
,
typename
T_int
>
template
<
typename
T
,
typename
T_int
>
__global__
void
assign_voxel_coors
(
const
int
nthreads
,
T_int
*
coor
,
__global__
void
__launch_bounds__
(
1024
)
assign_voxel_coors
(
const
int
nthreads
,
T_int
*
coor
,
T_int
*
point_to_voxelidx
,
T_int
*
point_to_voxelidx
,
T_int
*
coor_to_voxelidx
,
T_int
*
voxel_coors
,
T_int
*
coor_to_voxelidx
,
T_int
*
voxel_coors
,
const
int
num_points
,
const
int
NDim
)
{
const
int
num_points
,
const
int
NDim
)
{
...
@@ -135,6 +183,66 @@ __global__ void point_to_voxelidx_kernel(const T_int* coor,
...
@@ -135,6 +183,66 @@ __global__ void point_to_voxelidx_kernel(const T_int* coor,
}
}
}
}
/**
 * For every valid point, counts how many earlier points share its voxel key
 * and records the index of the first such point.
 *
 * Each thread handles VEC consecutive points starting at
 * index = (blockIdx.x * blockDim.x + threadIdx.x) * VEC. Keys equal to -1
 * mark invalid points and are ignored both as "current" and as "previous"
 * entries.
 *
 * Outputs, for each valid point p handled by the thread:
 *   - point_to_pointidx[p]: index of the first earlier point with the same
 *     key, or p itself when there is none;
 *   - point_to_voxelidx[p]: number of earlier points with the same key,
 *     written only when that number is < max_points (otherwise the slot keeps
 *     whatever value the caller initialised it with).
 *
 * @param coor               one 64-bit voxel key per point (-1 = invalid)
 * @param point_to_voxelidx  output, one entry per point
 * @param point_to_pointidx  output, one entry per point
 * @param max_points         cap on points per voxel
 * @param max_voxels         accepted for signature parity; not read here
 * @param num_points         number of entries in `coor`
 *
 * NOTE(review): the vector loads assume `coor` is aligned to
 * VEC * sizeof(int64_t) bytes - confirm against the allocator used by callers.
 */
template <typename T_int, int VEC>
__global__ void __launch_bounds__(1024) point_to_voxelidx_kernel_fast(
    const int64_t* coor, T_int* point_to_voxelidx, T_int* point_to_pointidx,
    const int max_points, const int max_voxels, const int num_points) {
  using Int64VEC =
      __attribute__((__vector_size__(VEC * sizeof(int64_t)))) int64_t;

  int tid = threadIdx.x;
  int index = (blockIdx.x * blockDim.x + tid) * VEC;

  // Bounds check must come before touching `coor`: the grid is rounded up, so
  // surplus threads would otherwise read past the end of the buffer.
  if (index >= num_points) return;

  Int64VEC i_coor;
  if (index + VEC <= num_points) {
    // Full vector lies inside the buffer: single wide load.
    i_coor = *reinterpret_cast<const Int64VEC*>(coor + index);
  } else {
    // Tail when num_points is not a multiple of VEC: load lane by lane and
    // pad with -1. Padded lanes behave as invalid points, so nothing below
    // reads or writes beyond num_points for them.
    for (int k = 0; k < VEC; k++) {
      i_coor[k] =
          (index + k < num_points) ? coor[index + k] : static_cast<int64_t>(-1);
    }
  }

  // num[m]: how many earlier points share the key of lane m.
  int num[VEC];
  for (int i = 0; i < VEC; i++) {
    num[i] = 0;
  }

  // Pass 1: compare against all points owned by lower-indexed threads.
  // i + VEC <= index <= num_points here, so the wide load stays in bounds.
  for (int i = 0; i < index; i += VEC) {
    auto prev_coor = *reinterpret_cast<const Int64VEC*>(coor + i);
    for (int k = 0; k < VEC; k++) {
      if (prev_coor[k] == -1) continue;
      for (int m = 0; m < VEC; m++) {
        if (prev_coor[k] == i_coor[m]) {
          num[m]++;
          if (num[m] == 1) {
            // first earlier point with the same key
            point_to_pointidx[index + m] = i + k;
          }
        }
      }
    }
  }

  // Pass 2: compare lanes inside this thread's own vector (lane k precedes
  // lane m whenever k < m).
  {
    for (int k = 0; k < VEC - 1; k++) {
      if (i_coor[k] == -1) continue;
      for (int m = k + 1; m < VEC; m++) {
        if (i_coor[k] == i_coor[m]) {
          num[m]++;
          if (num[m] == 1) {
            point_to_pointidx[index + m] = index + k;
          }
        }
      }
    }
  }

  // Pass 3: publish results for valid lanes only.
  for (int k = 0; k < VEC; k++) {
    if (i_coor[k] == -1) continue;
    if (num[k] == 0) {
      // no earlier point shares this key: the point refers to itself
      point_to_pointidx[index + k] = index + k;
    }
    if (num[k] < max_points) {
      point_to_voxelidx[index + k] = num[k];
    }
  }
}
template
<
typename
T_int
>
template
<
typename
T_int
>
__global__
void
determin_voxel_num
(
__global__
void
determin_voxel_num
(
// const T_int* coor,
// const T_int* coor,
...
@@ -166,6 +274,31 @@ __global__ void determin_voxel_num(
...
@@ -166,6 +274,31 @@ __global__ void determin_voxel_num(
}
}
}
}
/**
 * Host-side sequential pass that assigns a voxel index to every point.
 *
 * Points are visited in order. For point i, point_to_voxelidx[i] is read:
 *   - -1 : the point is skipped;
 *   -  0 : the point opens a new voxel, unless voxel_num[0] has already
 *          reached max_voxels, in which case it is skipped. On success the
 *          voxel counter is incremented, coor_to_voxelidx[i] receives the new
 *          voxel index and that voxel's point count is set to 1;
 *   - otherwise: the point inherits the voxel of point_to_pointidx[i], if that
 *          point has one (value != -1), and the voxel's point count grows by 1.
 *
 * @param num_points_per_voxel  in/out, point count per voxel index
 * @param point_to_voxelidx     per-point position inside its voxel (-1 = skip)
 * @param point_to_pointidx     per-point index of the point it refers to
 * @param coor_to_voxelidx      in/out, per-point voxel index (-1 = none)
 * @param voxel_num             in/out, voxel_num[0] is the running voxel count
 * @param max_points            not read by this function
 * @param max_voxels            upper bound on voxel_num[0]
 * @param num_points            number of points to visit
 */
template <typename T_int>
void determin_voxel_num_cpu(T_int* num_points_per_voxel,
                            T_int* point_to_voxelidx, T_int* point_to_pointidx,
                            T_int* coor_to_voxelidx, T_int* voxel_num,
                            const int max_points, const int max_voxels,
                            const int num_points) {
  for (int pt = 0; pt != num_points; pt++) {
    const int slot = point_to_voxelidx[pt];

    if (slot == 0) {
      // This point is the first one seen for its voxel.
      const int fresh_voxel = voxel_num[0];
      if (fresh_voxel < max_voxels) {
        voxel_num[0] += 1;
        coor_to_voxelidx[pt] = fresh_voxel;
        num_points_per_voxel[fresh_voxel] = 1;
      }
    } else if (slot != -1) {
      // Follow the link to the point that opened the voxel.
      const int owner_point = point_to_pointidx[pt];
      const int owner_voxel = coor_to_voxelidx[owner_point];
      if (owner_voxel != -1) {
        coor_to_voxelidx[pt] = owner_voxel;
        num_points_per_voxel[owner_voxel] += 1;
      }
    }
    // slot == -1: nothing to do for this point.
  }
}
__global__
void
nondeterministic_get_assign_pos
(
__global__
void
nondeterministic_get_assign_pos
(
const
int
nthreads
,
const
int32_t
*
coors_map
,
int32_t
*
pts_id
,
const
int
nthreads
,
const
int32_t
*
coors_map
,
int32_t
*
pts_id
,
int32_t
*
coors_count
,
int32_t
*
reduce_count
,
int32_t
*
coors_order
)
{
int32_t
*
coors_count
,
int32_t
*
reduce_count
,
int32_t
*
coors_order
)
{
...
...
mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu
View file @
a7eb7dbe
...
@@ -12,7 +12,7 @@ int HardVoxelizeForwardCUDAKernelLauncher(
...
@@ -12,7 +12,7 @@ int HardVoxelizeForwardCUDAKernelLauncher(
const
int
max_voxels
,
const
int
NDim
=
3
)
{
const
int
max_voxels
,
const
int
NDim
=
3
)
{
// current version takes about 0.04s for one frame on cpu
// current version takes about 0.04s for one frame on cpu
// check device
// check device
at
::
cuda
::
CUDAGuard
device_guard
(
points
.
device
());
at
::
cuda
::
CUDAGuard
device_guard
(
points
.
device
());
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
...
@@ -36,19 +36,22 @@ int HardVoxelizeForwardCUDAKernelLauncher(
...
@@ -36,19 +36,22 @@ int HardVoxelizeForwardCUDAKernelLauncher(
// map points to voxel coors
// map points to voxel coors
at
::
Tensor
temp_coors
=
at
::
Tensor
temp_coors
=
at
::
zeros
({
num_points
,
NDim
},
points
.
options
().
dtype
(
at
::
kInt
));
at
::
zeros
({
num_points
,
NDim
},
points
.
options
().
dtype
(
at
::
kInt
));
at
::
Tensor
temp_coors64
=
at
::
zeros
({
num_points
},
points
.
options
().
dtype
(
at
::
kLong
));
dim3
grid
(
std
::
min
(
at
::
cuda
::
ATenCeilDiv
(
num_points
,
512
),
4096
));
dim3
grid
(
std
::
min
(
at
::
cuda
::
ATenCeilDiv
(
num_points
,
512
),
4096
));
dim3
block
(
512
);
dim3
block
(
512
);
// 1. link point to corresponding voxel coors
// 1. link point to corresponding voxel coors
AT_DISPATCH_ALL_TYPES
(
AT_DISPATCH_ALL_TYPES
(
points
.
scalar_type
(),
"hard_voxelize_kernel"
,
([
&
]
{
points
.
scalar_type
(),
"hard_voxelize_kernel"
,
([
&
]
{
dynamic_voxelize_kernel
<
scalar_t
,
int
><<<
grid
,
block
,
0
,
stream
>>>
(
dynamic_voxelize_kernel_fast
<
scalar_t
,
int
>
points
.
contiguous
().
data_ptr
<
scalar_t
>
(),
<<<
grid
,
block
,
0
,
stream
>>>
(
temp_coors
.
contiguous
().
data_ptr
<
int
>
(),
voxel_x
,
voxel_y
,
voxel_z
,
points
.
contiguous
().
data_ptr
<
scalar_t
>
(),
coors_x_min
,
coors_y_min
,
coors_z_min
,
coors_x_max
,
coors_y_max
,
temp_coors
.
contiguous
().
data_ptr
<
int
>
(),
coors_z_max
,
grid_x
,
grid_y
,
grid_z
,
num_points
,
num_features
,
temp_coors64
.
data_ptr
<
int64_t
>
(),
voxel_x
,
voxel_y
,
NDim
);
voxel_z
,
coors_x_min
,
coors_y_min
,
coors_z_min
,
coors_x_max
,
coors_y_max
,
coors_z_max
,
grid_x
,
grid_y
,
grid_z
,
num_points
,
num_features
,
NDim
);
}));
}));
AT_CUDA_CHECK
(
cudaGetLastError
());
AT_CUDA_CHECK
(
cudaGetLastError
());
...
@@ -66,45 +69,58 @@ int HardVoxelizeForwardCUDAKernelLauncher(
...
@@ -66,45 +69,58 @@ int HardVoxelizeForwardCUDAKernelLauncher(
},
},
points
.
options
().
dtype
(
at
::
kInt
));
points
.
options
().
dtype
(
at
::
kInt
));
dim3
map_grid
(
std
::
min
(
at
::
cuda
::
ATenCeilDiv
(
num_points
,
512
),
4096
));
int
blocksize
=
256
;
dim3
map_block
(
512
);
constexpr
int
VEC
=
2
;
dim3
map_grid
(
at
::
cuda
::
ATenCeilDiv
(
num_points
,
blocksize
*
VEC
));
dim3
map_block
(
blocksize
);
AT_DISPATCH_ALL_TYPES
(
AT_DISPATCH_ALL_TYPES
(
temp_coors
.
scalar_type
(),
"determin_duplicate"
,
([
&
]
{
temp_coors
.
scalar_type
(),
"determin_duplicate"
,
([
&
]
{
point_to_voxelidx_kernel
<
int
><<<
map_grid
,
map_block
,
0
,
stream
>>>
(
point_to_voxelidx_kernel_fast
<
int
,
VEC
>
temp_coors
.
contiguous
().
data_ptr
<
int
>
(),
<<<
map_grid
,
map_block
,
0
,
stream
>>>
(
point_to_voxelidx
.
contiguous
().
data_ptr
<
int
>
(),
temp_coors64
.
contiguous
().
data_ptr
<
int64_t
>
(),
point_to_pointidx
.
contiguous
().
data_ptr
<
int
>
(),
max_points
,
point_to_voxelidx
.
contiguous
().
data_ptr
<
int
>
(),
max_voxels
,
num_points
,
NDim
);
point_to_pointidx
.
contiguous
().
data_ptr
<
int
>
(),
max_points
,
max_voxels
,
num_points
);
}));
}));
cudaDeviceSynchronize
();
AT_CUDA_CHECK
(
cudaGetLastError
());
AT_CUDA_CHECK
(
cudaGetLastError
());
// 3. determine voxel num and voxel's coor index
// 3. determine voxel num and voxel's coor index
// make the logic in the CUDA device could accelerate about 10 times
// make the logic in the CUDA device could accelerate about 10 times
auto
coor_to_voxelidx
=
-
at
::
ones
(
{
auto
coor_to_voxelidx_cpu
=
-
at
::
ones
(
num_points
,
{
},
num_points
,
points
.
options
().
dtype
(
at
::
kInt
));
},
auto
voxel_num
=
at
::
zeros
(
points
.
options
().
device
(
at
::
kCPU
).
dtype
(
at
::
kInt
));
{
auto
voxel_num_cpu
=
at
::
zeros
(
{
1
,
1
,
},
},
points
.
options
().
dtype
(
at
::
kInt
));
// must be zero from the beginning
points
.
options
().
device
(
at
::
kCPU
).
dtype
(
at
::
kInt
));
AT_DISPATCH_ALL_TYPES
(
temp_coors
.
scalar_type
(),
"determin_duplicate"
,
([
&
]
{
auto
point_to_voxelidx_cpu
=
point_to_voxelidx
.
to
(
at
::
kCPU
);
determin_voxel_num
<
int
><<<
1
,
1
,
0
,
stream
>>>
(
auto
point_to_pointidx_cpu
=
point_to_pointidx
.
to
(
at
::
kCPU
);
num_points_per_voxel
.
contiguous
().
data_ptr
<
int
>
(),
auto
num_points_per_voxel_cpu
=
num_points_per_voxel
.
to
(
at
::
kCPU
);
point_to_voxelidx
.
contiguous
().
data_ptr
<
int
>
(),
point_to_pointidx
.
contiguous
().
data_ptr
<
int
>
(),
coor_to_voxelidx
.
contiguous
().
data_ptr
<
int
>
(),
voxel_num
.
contiguous
().
data_ptr
<
int
>
(),
max_points
,
max_voxels
,
num_points
);
}));
AT_DISPATCH_ALL_TYPES
(
temp_coors
.
scalar_type
(),
"determin_duplicate"
,
([
&
]
{
determin_voxel_num_cpu
<
int
>
(
num_points_per_voxel_cpu
.
contiguous
().
data_ptr
<
int
>
(),
point_to_voxelidx_cpu
.
contiguous
().
data_ptr
<
int
>
(),
point_to_pointidx_cpu
.
contiguous
().
data_ptr
<
int
>
(),
coor_to_voxelidx_cpu
.
contiguous
().
data_ptr
<
int
>
(),
voxel_num_cpu
.
contiguous
().
data_ptr
<
int
>
(),
max_points
,
max_voxels
,
num_points
);
}));
cudaDeviceSynchronize
();
AT_CUDA_CHECK
(
cudaGetLastError
());
AT_CUDA_CHECK
(
cudaGetLastError
());
auto
coor_to_voxelidx
=
coor_to_voxelidx_cpu
.
to
(
at
::
kCUDA
);
num_points_per_voxel
.
copy_
(
num_points_per_voxel_cpu
);
// 4. copy point features to voxels
// 4. copy point features to voxels
// Step 4 & 5 could be parallel
// Step 4 & 5 could be parallel
auto
pts_output_size
=
num_points
*
num_features
;
auto
pts_output_size
=
num_points
*
num_features
;
...
@@ -139,7 +155,7 @@ int HardVoxelizeForwardCUDAKernelLauncher(
...
@@ -139,7 +155,7 @@ int HardVoxelizeForwardCUDAKernelLauncher(
AT_CUDA_CHECK
(
cudaGetLastError
());
AT_CUDA_CHECK
(
cudaGetLastError
());
auto
voxel_num_cpu
=
voxel_num
.
to
(
at
::
kCPU
);
//
auto voxel_num_cpu = voxel_num.to(at::kCPU);
int
voxel_num_int
=
voxel_num_cpu
.
data_ptr
<
int
>
()[
0
];
int
voxel_num_int
=
voxel_num_cpu
.
data_ptr
<
int
>
()[
0
];
return
voxel_num_int
;
return
voxel_num_int
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment