Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
9f30496c
Unverified
Commit
9f30496c
authored
Jun 11, 2022
by
pc
Committed by
GitHub
Jun 11, 2022
Browse files
[Fix] Fix iou3d in parrots (#2054)
parent
9807c2d2
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
32 additions
and
30 deletions
+32
-30
mmcv/ops/csrc/parrots/cudabind.cpp
mmcv/ops/csrc/parrots/cudabind.cpp
+15
-15
mmcv/ops/csrc/parrots/iou3d.cpp
mmcv/ops/csrc/parrots/iou3d.cpp
+9
-7
mmcv/ops/csrc/parrots/iou3d_parrots.cpp
mmcv/ops/csrc/parrots/iou3d_parrots.cpp
+6
-7
mmcv/ops/csrc/parrots/iou3d_pytorch.h
mmcv/ops/csrc/parrots/iou3d_pytorch.h
+2
-1
No files found.
mmcv/ops/csrc/parrots/cudabind.cpp
View file @
9f30496c
...
...
@@ -564,11 +564,11 @@ REGISTER_DEVICE_IMPL(group_points_forward_impl, CUDA,
REGISTER_DEVICE_IMPL
(
group_points_backward_impl
,
CUDA
,
group_points_backward_cuda
);
void
IoU3DBoxes
IoU3D
ForwardCUDAKernelLauncher
(
const
int
num_a
,
const
Tensor
boxes_a
,
const
int
num_b
,
const
Tensor
boxes_b
,
Tensor
ans_
iou
);
void
IoU3DBoxes
OverlapBev
ForwardCUDAKernelLauncher
(
const
int
num_a
,
const
Tensor
boxes_a
,
const
int
num_b
,
const
Tensor
boxes_b
,
Tensor
ans_
overlap
);
void
IoU3DNMS3DForwardCUDAKernelLauncher
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
...
...
@@ -580,11 +580,11 @@ void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes,
int
boxes_num
,
float
nms_overlap_thresh
);
void
iou3d_boxes_
iou3d
_forward_cuda
(
const
int
num_a
,
const
Tensor
boxes_a
,
const
int
num_b
,
const
Tensor
boxes_b
,
Tensor
ans_
iou
)
{
IoU3DBoxes
IoU3D
ForwardCUDAKernelLauncher
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_iou
);
void
iou3d_boxes_
overlap_bev
_forward_cuda
(
const
int
num_a
,
const
Tensor
boxes_a
,
const
int
num_b
,
const
Tensor
boxes_b
,
Tensor
ans_
overlap
)
{
IoU3DBoxes
OverlapBev
ForwardCUDAKernelLauncher
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_overlap
);
};
void
iou3d_nms3d_forward_cuda
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
...
...
@@ -600,9 +600,9 @@ void iou3d_nms3d_normal_forward_cuda(const Tensor boxes,
nms_overlap_thresh
);
};
void
iou3d_boxes_
iou3d
_forward_impl
(
const
int
num_a
,
const
Tensor
boxes_a
,
const
int
num_b
,
const
Tensor
boxes_b
,
Tensor
ans_
iou
);
void
iou3d_boxes_
overlap_bev
_forward_impl
(
const
int
num_a
,
const
Tensor
boxes_a
,
const
int
num_b
,
const
Tensor
boxes_b
,
Tensor
ans_
overlap
);
void
iou3d_nms3d_forward_impl
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
);
...
...
@@ -611,8 +611,8 @@ void iou3d_nms3d_normal_forward_impl(const Tensor boxes,
unsigned
long
long
*
mask
,
int
boxes_num
,
float
nms_overlap_thresh
);
REGISTER_DEVICE_IMPL
(
iou3d_boxes_
iou3d
_forward_impl
,
CUDA
,
iou3d_boxes_
iou3d
_forward_cuda
);
REGISTER_DEVICE_IMPL
(
iou3d_boxes_
overlap_bev
_forward_impl
,
CUDA
,
iou3d_boxes_
overlap_bev
_forward_cuda
);
REGISTER_DEVICE_IMPL
(
iou3d_nms3d_forward_impl
,
CUDA
,
iou3d_nms3d_forward_cuda
);
REGISTER_DEVICE_IMPL
(
iou3d_nms3d_normal_forward_impl
,
CUDA
,
iou3d_nms3d_normal_forward_cuda
);
...
...
mmcv/ops/csrc/parrots/iou3d.cpp
View file @
9f30496c
...
...
@@ -12,11 +12,11 @@ All Rights Reserved 2019-2020.
const
int
THREADS_PER_BLOCK_NMS
=
sizeof
(
unsigned
long
long
)
*
8
;
void
iou3d_boxes_
iou3d
_forward_impl
(
const
int
num_a
,
const
Tensor
boxes_a
,
const
int
num_b
,
const
Tensor
boxes_b
,
Tensor
ans_
iou
)
{
DISPATCH_DEVICE_IMPL
(
iou3d_boxes_
iou3d
_forward_impl
,
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_
iou
);
void
iou3d_boxes_
overlap_bev
_forward_impl
(
const
int
num_a
,
const
Tensor
boxes_a
,
const
int
num_b
,
const
Tensor
boxes_b
,
Tensor
ans_
overlap
)
{
DISPATCH_DEVICE_IMPL
(
iou3d_boxes_
overlap_bev
_forward_impl
,
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_
overlap
);
}
void
iou3d_nms3d_forward_impl
(
const
Tensor
boxes
,
unsigned
long
long
*
mask
,
...
...
@@ -32,14 +32,16 @@ void iou3d_nms3d_normal_forward_impl(const Tensor boxes,
nms_overlap_thresh
);
}
void
iou3d_boxes_iou3d_forward
(
Tensor
boxes_a
,
Tensor
boxes_b
,
Tensor
ans_iou
)
{
void
iou3d_boxes_overlap_bev_forward
(
Tensor
boxes_a
,
Tensor
boxes_b
,
Tensor
ans_overlap
)
{
// params boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
// params boxes_b: (M, 5)
// params ans_overlap: (N, M)
int
num_a
=
boxes_a
.
size
(
0
);
int
num_b
=
boxes_b
.
size
(
0
);
iou3d_boxes_iou3d_forward_impl
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_iou
);
iou3d_boxes_overlap_bev_forward_impl
(
num_a
,
boxes_a
,
num_b
,
boxes_b
,
ans_overlap
);
}
void
iou3d_nms3d_forward
(
Tensor
boxes
,
Tensor
keep
,
Tensor
keep_num
,
...
...
mmcv/ops/csrc/parrots/iou3d_parrots.cpp
View file @
9f30496c
...
...
@@ -8,16 +8,15 @@
using
namespace
parrots
;
#ifdef MMCV_WITH_CUDA
void
iou3d_boxes_iou3d_forward_cuda_parrots
(
CudaContext
&
ctx
,
const
SSElement
&
attr
,
const
OperatorBase
::
in_list_t
&
ins
,
OperatorBase
::
out_list_t
&
outs
)
{
void
iou3d_boxes_overlap_bev_forward_cuda_parrots
(
CudaContext
&
ctx
,
const
SSElement
&
attr
,
const
OperatorBase
::
in_list_t
&
ins
,
OperatorBase
::
out_list_t
&
outs
)
{
auto
boxes_a
=
buildATensor
(
ctx
,
ins
[
0
]);
auto
boxes_b
=
buildATensor
(
ctx
,
ins
[
1
]);
auto
ans_iou
=
buildATensor
(
ctx
,
outs
[
0
]);
iou3d_boxes_
iou3d
_forward
(
boxes_a
,
boxes_b
,
ans_iou
);
iou3d_boxes_
overlap_bev
_forward
(
boxes_a
,
boxes_b
,
ans_iou
);
}
void
iou3d_nms3d_forward_cuda_parrots
(
CudaContext
&
ctx
,
const
SSElement
&
attr
,
...
...
@@ -49,10 +48,10 @@ void iou3d_nms3d_normal_forward_cuda_parrots(CudaContext& ctx,
iou3d_nms3d_normal_forward
(
boxes
,
keep
,
keep_num
,
nms_overlap_thresh
);
}
PARROTS_EXTENSION_REGISTER
(
iou3d_boxes_
iou3d
_forward
)
PARROTS_EXTENSION_REGISTER
(
iou3d_boxes_
overlap_bev
_forward
)
.
input
(
2
)
.
output
(
1
)
.
apply
(
iou3d_boxes_
iou3d
_forward_cuda_parrots
)
.
apply
(
iou3d_boxes_
overlap_bev
_forward_cuda_parrots
)
.
done
();
PARROTS_EXTENSION_REGISTER
(
iou3d_nms3d_forward
)
...
...
mmcv/ops/csrc/parrots/iou3d_pytorch.h
View file @
9f30496c
...
...
@@ -4,7 +4,8 @@
#include <torch/extension.h>
using
namespace
at
;
void
iou3d_boxes_iou3d_forward
(
Tensor
boxes_a
,
Tensor
boxes_b
,
Tensor
ans_iou
);
void
iou3d_boxes_overlap_bev_forward
(
Tensor
boxes_a
,
Tensor
boxes_b
,
Tensor
ans_overlap
);
void
iou3d_nms3d_forward
(
Tensor
boxes
,
Tensor
keep
,
Tensor
keep_num
,
float
nms_overlap_thresh
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment