Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
f3dfc413
Unverified
Commit
f3dfc413
authored
Sep 23, 2021
by
dingchang
Committed by
GitHub
Sep 23, 2021
Browse files
[Feature] Add ballquery op from mmdet3d (#1332)
parent
4e101e0b
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
252 additions
and
7 deletions
+252
-7
docs/understand_mmcv/ops.md
docs/understand_mmcv/ops.md
+1
-0
docs_zh_CN/understand_mmcv/ops.md
docs_zh_CN/understand_mmcv/ops.md
+1
-0
mmcv/ops/__init__.py
mmcv/ops/__init__.py
+2
-1
mmcv/ops/ball_query.py
mmcv/ops/ball_query.py
+46
-0
mmcv/ops/csrc/common/cuda/ball_query_cuda_kernel.cuh
mmcv/ops/csrc/common/cuda/ball_query_cuda_kernel.cuh
+59
-0
mmcv/ops/csrc/pytorch/ball_query.cpp
mmcv/ops/csrc/pytorch/ball_query.cpp
+37
-0
mmcv/ops/csrc/pytorch/cuda/ball_query_cuda.cu
mmcv/ops/csrc/pytorch/cuda/ball_query_cuda.cu
+38
-0
mmcv/ops/csrc/pytorch/pybind.cpp
mmcv/ops/csrc/pytorch/pybind.cpp
+14
-6
tests/test_ops/test_ball_query.py
tests/test_ops/test_ball_query.py
+54
-0
No files found.
docs/understand_mmcv/ops.md
View file @
f3dfc413
...
...
@@ -2,6 +2,7 @@
We implement common CUDA ops used in detection, segmentation, etc.
-
BallQuery
-
BBoxOverlaps
-
CARAFE
-
CrissCrossAttention
...
...
docs_zh_CN/understand_mmcv/ops.md
View file @
f3dfc413
...
...
@@ -2,6 +2,7 @@
MMCV 提供了检测、分割等任务中常用的 CUDA 算子
-
BallQuery
-
BBoxOverlaps
-
CARAFE
-
CrissCrossAttention
...
...
mmcv/ops/__init__.py
View file @
f3dfc413
# Copyright (c) OpenMMLab. All rights reserved.
from
.ball_query
import
ball_query
from
.bbox
import
bbox_overlaps
from
.border_align
import
BorderAlign
,
border_align
from
.box_iou_rotated
import
box_iou_rotated
...
...
@@ -50,7 +51,7 @@ __all__ = [
'ConvTranspose2d'
,
'Linear'
,
'MaxPool2d'
,
'CrissCrossAttention'
,
'PSAMask'
,
'point_sample'
,
'rel_roi_point_to_rel_img_point'
,
'SimpleRoIAlign'
,
'SAConv2d'
,
'TINShift'
,
'tin_shift'
,
'box_iou_rotated'
,
'nms_rotated'
,
'upfirdn2d'
,
'FusedBiasLeakyReLU'
,
'fused_bias_leakyrelu'
,
'ball_query'
,
'upfirdn2d'
,
'FusedBiasLeakyReLU'
,
'fused_bias_leakyrelu'
,
'RoIAlignRotated'
,
'roi_align_rotated'
,
'pixel_group'
,
'contour_expand'
,
'MultiScaleDeformableAttention'
,
'BorderAlign'
,
'border_align'
]
mmcv/ops/ball_query.py
0 → 100644
View file @
f3dfc413
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
from
torch.autograd
import
Function
from
..utils
import
ext_loader
ext_module
=
ext_loader
.
load_ext
(
'_ext'
,
[
'ball_query_forward'
])
class
BallQuery
(
Function
):
"""Find nearby points in spherical space."""
@
staticmethod
def
forward
(
ctx
,
min_radius
:
float
,
max_radius
:
float
,
sample_num
:
int
,
xyz
:
torch
.
Tensor
,
center_xyz
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Args:
min_radius (float): minimum radius of the balls.
max_radius (float): maximum radius of the balls.
sample_num (int): maximum number of features in the balls.
xyz (Tensor): (B, N, 3) xyz coordinates of the features.
center_xyz (Tensor): (B, npoint, 3) centers of the ball query.
Returns:
Tensor: (B, npoint, nsample) tensor with the indicies of
the features that form the query balls.
"""
assert
center_xyz
.
is_contiguous
()
assert
xyz
.
is_contiguous
()
assert
min_radius
<
max_radius
B
,
N
,
_
=
xyz
.
size
()
npoint
=
center_xyz
.
size
(
1
)
idx
=
xyz
.
new_zeros
(
B
,
npoint
,
sample_num
,
dtype
=
torch
.
int
)
ext_module
.
ball_query_forward
(
B
,
N
,
npoint
,
min_radius
,
max_radius
,
sample_num
,
center_xyz
,
xyz
,
idx
)
ctx
.
mark_non_differentiable
(
idx
)
return
idx
@
staticmethod
def
backward
(
ctx
,
a
=
None
):
return
None
,
None
,
None
,
None
ball_query
=
BallQuery
.
apply
mmcv/ops/csrc/common/cuda/ball_query_cuda_kernel.cuh
0 → 100644
View file @
f3dfc413
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu
#ifndef BALL_QUERY_CUDA_KERNEL_CUH
#define BALL_QUERY_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
template
<
typename
T
>
__global__
void
ball_query_forward_cuda_kernel
(
int
b
,
int
n
,
int
m
,
float
min_radius
,
float
max_radius
,
int
nsample
,
const
T
*
new_xyz
,
const
T
*
xyz
,
int
*
idx
)
{
// new_xyz: (B, M, 3)
// xyz: (B, N, 3)
// output:
// idx: (B, M, nsample)
int
bs_idx
=
blockIdx
.
y
;
int
pt_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
bs_idx
>=
b
||
pt_idx
>=
m
)
return
;
new_xyz
+=
bs_idx
*
m
*
3
+
pt_idx
*
3
;
xyz
+=
bs_idx
*
n
*
3
;
idx
+=
bs_idx
*
m
*
nsample
+
pt_idx
*
nsample
;
float
max_radius2
=
max_radius
*
max_radius
;
float
min_radius2
=
min_radius
*
min_radius
;
T
new_x
=
new_xyz
[
0
];
T
new_y
=
new_xyz
[
1
];
T
new_z
=
new_xyz
[
2
];
int
cnt
=
0
;
for
(
int
k
=
0
;
k
<
n
;
++
k
)
{
T
x
=
xyz
[
k
*
3
+
0
];
T
y
=
xyz
[
k
*
3
+
1
];
T
z
=
xyz
[
k
*
3
+
2
];
T
d2
=
(
new_x
-
x
)
*
(
new_x
-
x
)
+
(
new_y
-
y
)
*
(
new_y
-
y
)
+
(
new_z
-
z
)
*
(
new_z
-
z
);
if
(
d2
==
0
||
(
d2
>=
min_radius2
&&
d2
<
max_radius2
))
{
if
(
cnt
==
0
)
{
for
(
int
l
=
0
;
l
<
nsample
;
++
l
)
{
idx
[
l
]
=
k
;
}
}
idx
[
cnt
]
=
k
;
++
cnt
;
if
(
cnt
>=
nsample
)
break
;
}
}
}
#endif // BALL_QUERY_CUDA_KERNEL_CUH
mmcv/ops/csrc/pytorch/ball_query.cpp
0 → 100644
View file @
f3dfc413
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query.cpp
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
void
BallQueryForwardCUDAKernelLauncher
(
int
b
,
int
n
,
int
m
,
float
min_radius
,
float
max_radius
,
int
nsample
,
const
Tensor
new_xyz
,
const
Tensor
xyz
,
Tensor
idx
);
void
ball_query_forward_cuda
(
int
b
,
int
n
,
int
m
,
float
min_radius
,
float
max_radius
,
int
nsample
,
const
Tensor
new_xyz
,
const
Tensor
xyz
,
Tensor
idx
)
{
BallQueryForwardCUDAKernelLauncher
(
b
,
n
,
m
,
min_radius
,
max_radius
,
nsample
,
new_xyz
,
xyz
,
idx
);
};
#endif
void
ball_query_forward
(
int
b
,
int
n
,
int
m
,
float
min_radius
,
float
max_radius
,
int
nsample
,
Tensor
new_xyz_tensor
,
Tensor
xyz_tensor
,
Tensor
idx_tensor
)
{
if
(
new_xyz_tensor
.
device
().
is_cuda
())
{
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT
(
new_xyz_tensor
);
CHECK_CUDA_INPUT
(
xyz_tensor
);
ball_query_forward_cuda
(
b
,
n
,
m
,
min_radius
,
max_radius
,
nsample
,
new_xyz_tensor
,
xyz_tensor
,
idx_tensor
);
#else
AT_ERROR
(
"ball_query is not compiled with GPU support"
);
#endif
}
else
{
AT_ERROR
(
"ball_query is not implemented on CPU"
);
}
}
mmcv/ops/csrc/pytorch/cuda/ball_query_cuda.cu
0 → 100644
View file @
f3dfc413
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "ball_query_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void
BallQueryForwardCUDAKernelLauncher
(
int
b
,
int
n
,
int
m
,
float
min_radius
,
float
max_radius
,
int
nsample
,
const
Tensor
new_xyz
,
const
Tensor
xyz
,
Tensor
idx
)
{
// new_xyz: (B, M, 3)
// xyz: (B, N, 3)
// output:
// idx: (B, M, nsample)
at
::
cuda
::
CUDAGuard
device_guard
(
new_xyz
.
device
());
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
dim3
blocks
(
DIVUP
(
m
,
THREADS_PER_BLOCK
),
b
);
// blockIdx.x(col), blockIdx.y(row)
dim3
threads
(
THREADS_PER_BLOCK
);
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
new_xyz
.
scalar_type
(),
"ball_query_forward_cuda_kernel"
,
[
&
]
{
ball_query_forward_cuda_kernel
<
scalar_t
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
b
,
n
,
m
,
min_radius
,
max_radius
,
nsample
,
new_xyz
.
data_ptr
<
scalar_t
>
(),
xyz
.
data_ptr
<
scalar_t
>
(),
idx
.
data_ptr
<
int
>
());
});
AT_CUDA_CHECK
(
cudaGetLastError
());
}
mmcv/ops/csrc/pytorch/pybind.cpp
View file @
f3dfc413
...
...
@@ -111,16 +111,16 @@ Tensor nms(Tensor boxes, Tensor scores, float iou_threshold, int offset);
Tensor
softnms
(
Tensor
boxes
,
Tensor
scores
,
Tensor
dets
,
float
iou_threshold
,
float
sigma
,
float
min_score
,
int
method
,
int
offset
);
std
::
vector
<
std
::
vector
<
int
>
>
nms_match
(
Tensor
dets
,
float
iou_threshold
);
std
::
vector
<
std
::
vector
<
int
>>
nms_match
(
Tensor
dets
,
float
iou_threshold
);
std
::
vector
<
std
::
vector
<
float
>
>
pixel_group
(
std
::
vector
<
std
::
vector
<
float
>>
pixel_group
(
Tensor
score
,
Tensor
mask
,
Tensor
embedding
,
Tensor
kernel_label
,
Tensor
kernel_contour
,
int
kernel_region_num
,
float
distance_threshold
);
std
::
vector
<
std
::
vector
<
int
>
>
contour_expand
(
Tensor
kernel_mask
,
Tensor
internal_kernel_label
,
int
min_kernel_area
,
int
kernel_num
);
std
::
vector
<
std
::
vector
<
int
>>
contour_expand
(
Tensor
kernel_mask
,
Tensor
internal_kernel_label
,
int
min_kernel_area
,
int
kernel_num
);
void
roi_align_forward
(
Tensor
input
,
Tensor
rois
,
Tensor
output
,
Tensor
argmax_y
,
Tensor
argmax_x
,
int
aligned_height
,
...
...
@@ -172,6 +172,10 @@ void tin_shift_forward(Tensor input, Tensor shift, Tensor output);
void
tin_shift_backward
(
Tensor
grad_output
,
Tensor
shift
,
Tensor
grad_input
);
void
ball_query_forward
(
int
b
,
int
n
,
int
m
,
float
min_radius
,
float
max_radius
,
int
nsample
,
Tensor
new_xyz_tensor
,
Tensor
xyz_tensor
,
Tensor
idx_tensor
);
Tensor
bottom_pool_forward
(
Tensor
input
);
Tensor
bottom_pool_backward
(
Tensor
input
,
Tensor
grad_output
);
...
...
@@ -415,6 +419,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m
.
def
(
"nms_rotated"
,
&
nms_rotated
,
"NMS for rotated boxes"
,
py
::
arg
(
"dets"
),
py
::
arg
(
"scores"
),
py
::
arg
(
"order"
),
py
::
arg
(
"dets_sorted"
),
py
::
arg
(
"iou_threshold"
),
py
::
arg
(
"multi_label"
));
m
.
def
(
"ball_query_forward"
,
&
ball_query_forward
,
"ball_query_forward"
,
py
::
arg
(
"b"
),
py
::
arg
(
"n"
),
py
::
arg
(
"m"
),
py
::
arg
(
"min_radius"
),
py
::
arg
(
"max_radius"
),
py
::
arg
(
"nsample"
),
py
::
arg
(
"new_xyz_tensor"
),
py
::
arg
(
"xyz_tensor"
),
py
::
arg
(
"idx_tensor"
));
m
.
def
(
"roi_align_rotated_forward"
,
&
roi_align_rotated_forward
,
"roi_align_rotated forward"
,
py
::
arg
(
"input"
),
py
::
arg
(
"rois"
),
py
::
arg
(
"output"
),
py
::
arg
(
"pooled_height"
),
py
::
arg
(
"pooled_width"
),
...
...
tests/test_ops/test_ball_query.py
0 → 100644
View file @
f3dfc413
import
pytest
import
torch
from
mmcv.ops
import
ball_query
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'requires CUDA support'
)
def
test_ball_query
():
new_xyz
=
torch
.
tensor
([[[
-
0.0740
,
1.3147
,
-
1.3625
],
[
-
2.2769
,
2.7817
,
-
0.2334
],
[
-
0.4003
,
2.4666
,
-
0.5116
],
[
-
0.0740
,
1.3147
,
-
1.3625
],
[
-
0.0740
,
1.3147
,
-
1.3625
]],
[[
-
2.0289
,
2.4952
,
-
0.1708
],
[
-
2.0668
,
6.0278
,
-
0.4875
],
[
0.4066
,
1.4211
,
-
0.2947
],
[
-
2.0289
,
2.4952
,
-
0.1708
],
[
-
2.0289
,
2.4952
,
-
0.1708
]]]).
cuda
()
xyz
=
torch
.
tensor
([[[
-
0.0740
,
1.3147
,
-
1.3625
],
[
0.5555
,
1.0399
,
-
1.3634
],
[
-
0.4003
,
2.4666
,
-
0.5116
],
[
-
0.5251
,
2.4379
,
-
0.8466
],
[
-
0.9691
,
1.1418
,
-
1.3733
],
[
-
0.2232
,
0.9561
,
-
1.3626
],
[
-
2.2769
,
2.7817
,
-
0.2334
],
[
-
0.2822
,
1.3192
,
-
1.3645
],
[
0.1533
,
1.5024
,
-
1.0432
],
[
0.4917
,
1.1529
,
-
1.3496
]],
[[
-
2.0289
,
2.4952
,
-
0.1708
],
[
-
0.7188
,
0.9956
,
-
0.5096
],
[
-
2.0668
,
6.0278
,
-
0.4875
],
[
-
1.9304
,
3.3092
,
0.6610
],
[
0.0949
,
1.4332
,
0.3140
],
[
-
1.2879
,
2.0008
,
-
0.7791
],
[
-
0.7252
,
0.9611
,
-
0.6371
],
[
0.4066
,
1.4211
,
-
0.2947
],
[
0.3220
,
1.4447
,
0.3548
],
[
-
0.9744
,
2.3856
,
-
1.2000
]]]).
cuda
()
idx
=
ball_query
(
0
,
0.2
,
5
,
xyz
,
new_xyz
)
expected_idx
=
torch
.
tensor
([[[
0
,
0
,
0
,
0
,
0
],
[
6
,
6
,
6
,
6
,
6
],
[
2
,
2
,
2
,
2
,
2
],
[
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
]],
[[
0
,
0
,
0
,
0
,
0
],
[
2
,
2
,
2
,
2
,
2
],
[
7
,
7
,
7
,
7
,
7
],
[
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
]]]).
cuda
()
assert
torch
.
all
(
idx
==
expected_idx
)
# test dilated ball query
idx
=
ball_query
(
0.2
,
0.4
,
5
,
xyz
,
new_xyz
)
expected_idx
=
torch
.
tensor
([[[
0
,
5
,
7
,
0
,
0
],
[
6
,
6
,
6
,
6
,
6
],
[
2
,
3
,
2
,
2
,
2
],
[
0
,
5
,
7
,
0
,
0
],
[
0
,
5
,
7
,
0
,
0
]],
[[
0
,
0
,
0
,
0
,
0
],
[
2
,
2
,
2
,
2
,
2
],
[
7
,
7
,
7
,
7
,
7
],
[
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
]]]).
cuda
()
assert
torch
.
all
(
idx
==
expected_idx
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment