Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
0230fc3b
Unverified
Commit
0230fc3b
authored
Jun 07, 2022
by
pc
Committed by
GitHub
Jun 07, 2022
Browse files
[Feature] Add rotated_feature_align_cpu and update iod3d in parrots (#2027)
parent
4061fcdc
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
132 additions
and
128 deletions
+132
-128
mmcv/ops/csrc/parrots/cudabind.cpp
mmcv/ops/csrc/parrots/cudabind.cpp
+43
-58
mmcv/ops/csrc/parrots/iou3d.cpp
mmcv/ops/csrc/parrots/iou3d.cpp
+23
-43
mmcv/ops/csrc/parrots/iou3d_parrots.cpp
mmcv/ops/csrc/parrots/iou3d_parrots.cpp
+20
-19
mmcv/ops/csrc/parrots/iou3d_pytorch.h
mmcv/ops/csrc/parrots/iou3d_pytorch.h
+5
-6
mmcv/ops/csrc/parrots/rotated_feature_align_parrots.cpp
mmcv/ops/csrc/parrots/rotated_feature_align_parrots.cpp
+41
-2
No files found.
mmcv/ops/csrc/parrots/cudabind.cpp
View file @
0230fc3b
...
...
@@ -564,73 +564,58 @@ REGISTER_DEVICE_IMPL(group_points_forward_impl, CUDA,
REGISTER_DEVICE_IMPL
(
group_points_backward_impl
,
CUDA
,
group_points_backward_cuda
);
void
IoU3DBoxesOverlapBevForwardCUDAKernelLauncher
(
const
int
num_a
,
const
Tensor
boxes_a
,
const
int
num_b
,
const
Tensor
boxes_b
,
Tensor
ans_overlap
);
void
IoU3DBoxesIoUBevForwardCUDAKernelLauncher
(
const
int
num_a
,
void
IoU3DBoxesIoU3DForwardCUDAKernelLauncher
(
const
int
num_a
,
const
Tensor
boxes_a
,
const
int
num_b
,
const
Tensor
boxes_b
,
Tensor
ans_iou
);
void
IoU3DNMSForwardCUDAKernelLauncher
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
void
IoU3DNMS3DForwardCUDAKernelLauncher
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
);
void
IoU3DNMSNormalForwardCUDAKernelLauncher
(
const
Tensor
boxes
,
void
IoU3DNMS
3D
NormalForwardCUDAKernelLauncher
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
);
void
iou3d_boxes_overlap_bev_forward_cuda
(
const
int
num_a
,
const
Tensor
boxes_a
,
const
int
num_b
,
const
Tensor
boxes_b
,
Tensor
ans_overlap
)
{
IoU3DBoxesOverlapBevForwardCUDAKernelLauncher
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_overlap
);
};
void
iou3d_boxes_iou_bev_forward_cuda
(
const
int
num_a
,
const
Tensor
boxes_a
,
void
iou3d_boxes_iou3d_forward_cuda
(
const
int
num_a
,
const
Tensor
boxes_a
,
const
int
num_b
,
const
Tensor
boxes_b
,
Tensor
ans_iou
)
{
IoU3DBoxesIoU
Bev
ForwardCUDAKernelLauncher
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
IoU3DBoxesIoU
3D
ForwardCUDAKernelLauncher
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_iou
);
};
void
iou3d_nms_forward_cuda
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
void
iou3d_nms
3d
_forward_cuda
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
)
{
IoU3DNMSForwardCUDAKernelLauncher
(
boxes
,
mask
,
boxes_num
,
nms_overlap_thresh
);
IoU3DNMS3DForwardCUDAKernelLauncher
(
boxes
,
mask
,
boxes_num
,
nms_overlap_thresh
);
};
void
iou3d_nms_normal_forward_cuda
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
)
{
IoU3DNMSNormalForwardCUDAKernelLauncher
(
boxes
,
mask
,
boxes_num
,
void
iou3d_nms3d_normal_forward_cuda
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
)
{
IoU3DNMS3DNormalForwardCUDAKernelLauncher
(
boxes
,
mask
,
boxes_num
,
nms_overlap_thresh
);
};
void
iou3d_boxes_overlap_bev_forward_impl
(
const
int
num_a
,
const
Tensor
boxes_a
,
const
int
num_b
,
const
Tensor
boxes_b
,
Tensor
ans_overlap
);
void
iou3d_boxes_iou_bev_forward_impl
(
const
int
num_a
,
const
Tensor
boxes_a
,
void
iou3d_boxes_iou3d_forward_impl
(
const
int
num_a
,
const
Tensor
boxes_a
,
const
int
num_b
,
const
Tensor
boxes_b
,
Tensor
ans_iou
);
void
iou3d_nms_forward_impl
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
void
iou3d_nms
3d
_forward_impl
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
);
void
iou3d_nms_normal_forward_impl
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
);
void
iou3d_nms3d_normal_forward_impl
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
);
REGISTER_DEVICE_IMPL
(
iou3d_boxes_overlap_bev_forward_impl
,
CUDA
,
iou3d_boxes_overlap_bev_forward_cuda
);
REGISTER_DEVICE_IMPL
(
iou3d_boxes_iou_bev_forward_impl
,
CUDA
,
iou3d_boxes_iou_bev_forward_cuda
);
REGISTER_DEVICE_IMPL
(
iou3d_nms_forward_impl
,
CUDA
,
iou3d_nms_forward_cuda
);
REGISTER_DEVICE_IMPL
(
iou3d_nms_normal_forward_impl
,
CUDA
,
iou3d_nms_normal_forward_cuda
);
REGISTER_DEVICE_IMPL
(
iou3d_boxes_iou3d_forward_impl
,
CUDA
,
iou3d_boxes_iou3d_forward_cuda
);
REGISTER_DEVICE_IMPL
(
iou3d_nms3d_forward_impl
,
CUDA
,
iou3d_nms3d_forward_cuda
);
REGISTER_DEVICE_IMPL
(
iou3d_nms3d_normal_forward_impl
,
CUDA
,
iou3d_nms3d_normal_forward_cuda
);
void
KNNForwardCUDAKernelLauncher
(
int
b
,
int
n
,
int
m
,
int
nsample
,
const
Tensor
xyz
,
const
Tensor
new_xyz
,
...
...
mmcv/ops/csrc/parrots/iou3d.cpp
View file @
0230fc3b
...
...
@@ -12,59 +12,39 @@ All Rights Reserved 2019-2020.
const
int
THREADS_PER_BLOCK_NMS
=
sizeof
(
unsigned
long
long
)
*
8
;
void
iou3d_boxes_overlap_bev_forward_impl
(
const
int
num_a
,
const
Tensor
boxes_a
,
const
int
num_b
,
const
Tensor
boxes_b
,
Tensor
ans_overlap
)
{
DISPATCH_DEVICE_IMPL
(
iou3d_boxes_overlap_bev_forward_impl
,
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_overlap
);
}
void
iou3d_boxes_iou_bev_forward_impl
(
const
int
num_a
,
const
Tensor
boxes_a
,
void
iou3d_boxes_iou3d_forward_impl
(
const
int
num_a
,
const
Tensor
boxes_a
,
const
int
num_b
,
const
Tensor
boxes_b
,
Tensor
ans_iou
)
{
DISPATCH_DEVICE_IMPL
(
iou3d_boxes_iou
_bev
_forward_impl
,
num_a
,
boxes_a
,
num_b
,
DISPATCH_DEVICE_IMPL
(
iou3d_boxes_iou
3d
_forward_impl
,
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_iou
);
}
void
iou3d_nms_forward_impl
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
void
iou3d_nms
3d
_forward_impl
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
)
{
DISPATCH_DEVICE_IMPL
(
iou3d_nms_forward_impl
,
boxes
,
mask
,
boxes_num
,
DISPATCH_DEVICE_IMPL
(
iou3d_nms
3d
_forward_impl
,
boxes
,
mask
,
boxes_num
,
nms_overlap_thresh
);
}
void
iou3d_nms_normal_forward_impl
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
)
{
DISPATCH_DEVICE_IMPL
(
iou3d_nms_normal_forward_impl
,
boxes
,
mask
,
boxes_num
,
void
iou3d_nms3d_normal_forward_impl
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
)
{
DISPATCH_DEVICE_IMPL
(
iou3d_nms3d_normal_forward_impl
,
boxes
,
mask
,
boxes_num
,
nms_overlap_thresh
);
}
void
iou3d_boxes_overlap_bev_forward
(
Tensor
boxes_a
,
Tensor
boxes_b
,
Tensor
ans_overlap
)
{
// params boxes_a: (N, 5) [x1, y1, x2, y2, ry]
// params boxes_b: (M, 5)
// params ans_overlap: (N, M)
int
num_a
=
boxes_a
.
size
(
0
);
int
num_b
=
boxes_b
.
size
(
0
);
iou3d_boxes_overlap_bev_forward_impl
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_overlap
);
}
void
iou3d_boxes_iou_bev_forward
(
Tensor
boxes_a
,
Tensor
boxes_b
,
Tensor
ans_iou
)
{
// params boxes_a: (N, 5) [x1, y1, x2, y2, ry]
void
iou3d_boxes_iou3d_forward
(
Tensor
boxes_a
,
Tensor
boxes_b
,
Tensor
ans_iou
)
{
// params boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
// params boxes_b: (M, 5)
// params ans_overlap: (N, M)
int
num_a
=
boxes_a
.
size
(
0
);
int
num_b
=
boxes_b
.
size
(
0
);
iou3d_boxes_iou
_bev
_forward_impl
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_iou
);
iou3d_boxes_iou
3d
_forward_impl
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_iou
);
}
void
iou3d_nms_forward
(
Tensor
boxes
,
Tensor
keep
,
Tensor
keep_num
,
void
iou3d_nms
3d
_forward
(
Tensor
boxes
,
Tensor
keep
,
Tensor
keep_num
,
float
nms_overlap_thresh
)
{
// params boxes: (N,
5
) [x
1
, y
1
,
x2, y2, ry
]
// params boxes: (N,
7
) [x, y,
z, dx, dy, dz, heading
]
// params keep: (N)
CHECK_CONTIGUOUS
(
boxes
);
CHECK_CONTIGUOUS
(
keep
);
...
...
@@ -80,7 +60,7 @@ void iou3d_nms_forward(Tensor boxes, Tensor keep, Tensor keep_num,
at
::
empty
({
boxes_num
,
col_blocks
},
boxes
.
options
().
dtype
(
at
::
kLong
));
unsigned
long
long
*
mask_data
=
(
unsigned
long
long
*
)
mask
.
data_ptr
<
int64_t
>
();
iou3d_nms_forward_impl
(
boxes
,
mask_data
,
boxes_num
,
nms_overlap_thresh
);
iou3d_nms
3d
_forward_impl
(
boxes
,
mask_data
,
boxes_num
,
nms_overlap_thresh
);
at
::
Tensor
mask_cpu
=
mask
.
to
(
at
::
kCPU
);
unsigned
long
long
*
mask_host
=
...
...
@@ -106,9 +86,9 @@ void iou3d_nms_forward(Tensor boxes, Tensor keep, Tensor keep_num,
}
}
void
iou3d_nms_normal_forward
(
Tensor
boxes
,
Tensor
keep
,
Tensor
keep_num
,
void
iou3d_nms
3d
_normal_forward
(
Tensor
boxes
,
Tensor
keep
,
Tensor
keep_num
,
float
nms_overlap_thresh
)
{
// params boxes: (N,
5
) [x
1
, y
1
,
x2, y2, ry
]
// params boxes: (N,
7
) [x, y,
z, dx, dy, dz, heading
]
// params keep: (N)
CHECK_CONTIGUOUS
(
boxes
);
...
...
@@ -125,7 +105,7 @@ void iou3d_nms_normal_forward(Tensor boxes, Tensor keep, Tensor keep_num,
at
::
empty
({
boxes_num
,
col_blocks
},
boxes
.
options
().
dtype
(
at
::
kLong
));
unsigned
long
long
*
mask_data
=
(
unsigned
long
long
*
)
mask
.
data_ptr
<
int64_t
>
();
iou3d_nms_normal_forward_impl
(
boxes
,
mask_data
,
boxes_num
,
iou3d_nms
3d
_normal_forward_impl
(
boxes
,
mask_data
,
boxes_num
,
nms_overlap_thresh
);
at
::
Tensor
mask_cpu
=
mask
.
to
(
at
::
kCPU
);
...
...
mmcv/ops/csrc/parrots/iou3d_parrots.cpp
View file @
0230fc3b
...
...
@@ -8,18 +8,19 @@
using
namespace
parrots
;
#ifdef MMCV_WITH_CUDA
void
iou3d_boxes_iou_bev_forward_cuda_parrots
(
CudaContext
&
ctx
,
const
SSElement
&
attr
,
const
OperatorBase
::
in_list_t
&
ins
,
void
iou3d_boxes_iou3d_forward_cuda_parrots
(
CudaContext
&
ctx
,
const
SSElement
&
attr
,
const
OperatorBase
::
in_list_t
&
ins
,
OperatorBase
::
out_list_t
&
outs
)
{
auto
boxes_a
=
buildATensor
(
ctx
,
ins
[
0
]);
auto
boxes_b
=
buildATensor
(
ctx
,
ins
[
1
]);
auto
ans_iou
=
buildATensor
(
ctx
,
outs
[
0
]);
iou3d_boxes_iou
_bev
_forward
(
boxes_a
,
boxes_b
,
ans_iou
);
iou3d_boxes_iou
3d
_forward
(
boxes_a
,
boxes_b
,
ans_iou
);
}
void
iou3d_nms_forward_cuda_parrots
(
CudaContext
&
ctx
,
const
SSElement
&
attr
,
void
iou3d_nms
3d
_forward_cuda_parrots
(
CudaContext
&
ctx
,
const
SSElement
&
attr
,
const
OperatorBase
::
in_list_t
&
ins
,
OperatorBase
::
out_list_t
&
outs
)
{
float
nms_overlap_thresh
;
...
...
@@ -30,10 +31,10 @@ void iou3d_nms_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
auto
keep
=
buildATensor
(
ctx
,
outs
[
0
]);
auto
keep_num
=
buildATensor
(
ctx
,
outs
[
1
]);
iou3d_nms_forward
(
boxes
,
keep
,
keep_num
,
nms_overlap_thresh
);
iou3d_nms
3d
_forward
(
boxes
,
keep
,
keep_num
,
nms_overlap_thresh
);
}
void
iou3d_nms_normal_forward_cuda_parrots
(
CudaContext
&
ctx
,
void
iou3d_nms
3d
_normal_forward_cuda_parrots
(
CudaContext
&
ctx
,
const
SSElement
&
attr
,
const
OperatorBase
::
in_list_t
&
ins
,
OperatorBase
::
out_list_t
&
outs
)
{
...
...
@@ -45,26 +46,26 @@ void iou3d_nms_normal_forward_cuda_parrots(CudaContext& ctx,
auto
keep
=
buildATensor
(
ctx
,
outs
[
0
]);
auto
keep_num
=
buildATensor
(
ctx
,
outs
[
1
]);
iou3d_nms_normal_forward
(
boxes
,
keep
,
keep_num
,
nms_overlap_thresh
);
iou3d_nms
3d
_normal_forward
(
boxes
,
keep
,
keep_num
,
nms_overlap_thresh
);
}
PARROTS_EXTENSION_REGISTER
(
iou3d_boxes_iou
_bev
_forward
)
PARROTS_EXTENSION_REGISTER
(
iou3d_boxes_iou
3d
_forward
)
.
input
(
2
)
.
output
(
1
)
.
apply
(
iou3d_boxes_iou
_bev
_forward_cuda_parrots
)
.
apply
(
iou3d_boxes_iou
3d
_forward_cuda_parrots
)
.
done
();
PARROTS_EXTENSION_REGISTER
(
iou3d_nms_forward
)
PARROTS_EXTENSION_REGISTER
(
iou3d_nms
3d
_forward
)
.
attr
(
"nms_overlap_thresh"
)
.
input
(
1
)
.
output
(
2
)
.
apply
(
iou3d_nms_forward_cuda_parrots
)
.
apply
(
iou3d_nms
3d
_forward_cuda_parrots
)
.
done
();
PARROTS_EXTENSION_REGISTER
(
iou3d_nms_normal_forward
)
PARROTS_EXTENSION_REGISTER
(
iou3d_nms
3d
_normal_forward
)
.
attr
(
"nms_overlap_thresh"
)
.
input
(
1
)
.
output
(
2
)
.
apply
(
iou3d_nms_normal_forward_cuda_parrots
)
.
apply
(
iou3d_nms
3d
_normal_forward_cuda_parrots
)
.
done
();
#endif
mmcv/ops/csrc/parrots/iou3d_pytorch.h
View file @
0230fc3b
...
...
@@ -4,13 +4,12 @@
#include <torch/extension.h>
using
namespace
at
;
void
iou3d_boxes_iou_bev_forward
(
Tensor
boxes_a
,
Tensor
boxes_b
,
Tensor
ans_iou
);
void
iou3d_boxes_iou3d_forward
(
Tensor
boxes_a
,
Tensor
boxes_b
,
Tensor
ans_iou
);
void
iou3d_nms_forward
(
Tensor
boxes
,
Tensor
keep
,
Tensor
keep_num
,
void
iou3d_nms
3d
_forward
(
Tensor
boxes
,
Tensor
keep
,
Tensor
keep_num
,
float
nms_overlap_thresh
);
void
iou3d_nms_normal_forward
(
Tensor
boxes
,
Tensor
keep
,
Tensor
keep_num
,
void
iou3d_nms
3d
_normal_forward
(
Tensor
boxes
,
Tensor
keep
,
Tensor
keep_num
,
float
nms_overlap_thresh
);
#endif // IOU_3D_PYTORCH_H
mmcv/ops/csrc/parrots/rotated_feature_align_parrots.cpp
View file @
0230fc3b
...
...
@@ -41,12 +41,50 @@ void rotated_feature_align_backward_cuda_parrots(
spatial_scale
,
points
);
}
void
rotated_feature_align_forward_cpu_parrots
(
HostContext
&
ctx
,
const
SSElement
&
attr
,
const
OperatorBase
::
in_list_t
&
ins
,
OperatorBase
::
out_list_t
&
outs
)
{
float
spatial_scale
;
int
points
;
SSAttrs
(
attr
)
.
get
<
float
>
(
"spatial_scale"
,
spatial_scale
)
.
get
<
int
>
(
"points"
,
points
)
.
done
();
auto
features
=
buildATensor
(
ctx
,
ins
[
0
]);
auto
best_bboxes
=
buildATensor
(
ctx
,
ins
[
1
]);
auto
output
=
buildATensor
(
ctx
,
outs
[
0
]);
rotated_feature_align_forward
(
features
,
best_bboxes
,
output
,
spatial_scale
,
points
);
}
#endif
void
rotated_feature_align_backward_cpu_parrots
(
HostContext
&
ctx
,
const
SSElement
&
attr
,
const
OperatorBase
::
in_list_t
&
ins
,
OperatorBase
::
out_list_t
&
outs
)
{
float
spatial_scale
;
int
points
;
SSAttrs
(
attr
)
.
get
<
float
>
(
"spatial_scale"
,
spatial_scale
)
.
get
<
int
>
(
"points"
,
points
)
.
done
();
auto
grad_output
=
buildATensor
(
ctx
,
ins
[
0
]);
auto
best_bboxes
=
buildATensor
(
ctx
,
ins
[
1
]);
auto
grad_input
=
buildATensor
(
ctx
,
outs
[
0
]);
rotated_feature_align_backward
(
grad_output
,
best_bboxes
,
grad_input
,
spatial_scale
,
points
);
}
PARROTS_EXTENSION_REGISTER
(
rotated_feature_align_forward
)
.
attr
(
"spatial_scale"
)
.
attr
(
"points"
)
.
input
(
2
)
.
output
(
1
)
.
apply
(
rotated_feature_align_forward_cpu_parrots
)
#ifdef MMCV_WITH_CUDA
.
apply
(
rotated_feature_align_forward_cuda_parrots
)
#endif
.
done
();
PARROTS_EXTENSION_REGISTER
(
rotated_feature_align_backward
)
...
...
@@ -54,7 +92,8 @@ PARROTS_EXTENSION_REGISTER(rotated_feature_align_backward)
.
attr
(
"points"
)
.
input
(
2
)
.
output
(
1
)
.
apply
(
rotated_feature_align_forward_cpu_parrots
)
#ifdef MMCV_WITH_CUDA
.
apply
(
rotated_feature_align_backward_cuda_parrots
)
.
done
();
#endif
.
done
();
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment