Unverified Commit 9f30496c authored by pc's avatar pc Committed by GitHub
Browse files

[Fix] Fix iou3d in parrots (#2054)

parent 9807c2d2
...@@ -564,11 +564,11 @@ REGISTER_DEVICE_IMPL(group_points_forward_impl, CUDA, ...@@ -564,11 +564,11 @@ REGISTER_DEVICE_IMPL(group_points_forward_impl, CUDA,
REGISTER_DEVICE_IMPL(group_points_backward_impl, CUDA, REGISTER_DEVICE_IMPL(group_points_backward_impl, CUDA,
group_points_backward_cuda); group_points_backward_cuda);
void IoU3DBoxesIoU3DForwardCUDAKernelLauncher(const int num_a, void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
const Tensor boxes_a, const Tensor boxes_a,
const int num_b, const int num_b,
const Tensor boxes_b, const Tensor boxes_b,
Tensor ans_iou); Tensor ans_overlap);
void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes, void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes,
unsigned long long* mask, unsigned long long* mask,
...@@ -580,11 +580,11 @@ void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes, ...@@ -580,11 +580,11 @@ void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes,
int boxes_num, int boxes_num,
float nms_overlap_thresh); float nms_overlap_thresh);
void iou3d_boxes_iou3d_forward_cuda(const int num_a, const Tensor boxes_a, void iou3d_boxes_overlap_bev_forward_cuda(const int num_a, const Tensor boxes_a,
const int num_b, const Tensor boxes_b, const int num_b, const Tensor boxes_b,
Tensor ans_iou) { Tensor ans_overlap) {
IoU3DBoxesIoU3DForwardCUDAKernelLauncher(num_a, boxes_a, num_b, boxes_b, IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(num_a, boxes_a, num_b, boxes_b,
ans_iou); ans_overlap);
}; };
void iou3d_nms3d_forward_cuda(const Tensor boxes, unsigned long long* mask, void iou3d_nms3d_forward_cuda(const Tensor boxes, unsigned long long* mask,
...@@ -600,9 +600,9 @@ void iou3d_nms3d_normal_forward_cuda(const Tensor boxes, ...@@ -600,9 +600,9 @@ void iou3d_nms3d_normal_forward_cuda(const Tensor boxes,
nms_overlap_thresh); nms_overlap_thresh);
}; };
void iou3d_boxes_iou3d_forward_impl(const int num_a, const Tensor boxes_a, void iou3d_boxes_overlap_bev_forward_impl(const int num_a, const Tensor boxes_a,
const int num_b, const Tensor boxes_b, const int num_b, const Tensor boxes_b,
Tensor ans_iou); Tensor ans_overlap);
void iou3d_nms3d_forward_impl(const Tensor boxes, unsigned long long* mask, void iou3d_nms3d_forward_impl(const Tensor boxes, unsigned long long* mask,
int boxes_num, float nms_overlap_thresh); int boxes_num, float nms_overlap_thresh);
...@@ -611,8 +611,8 @@ void iou3d_nms3d_normal_forward_impl(const Tensor boxes, ...@@ -611,8 +611,8 @@ void iou3d_nms3d_normal_forward_impl(const Tensor boxes,
unsigned long long* mask, int boxes_num, unsigned long long* mask, int boxes_num,
float nms_overlap_thresh); float nms_overlap_thresh);
REGISTER_DEVICE_IMPL(iou3d_boxes_iou3d_forward_impl, CUDA, REGISTER_DEVICE_IMPL(iou3d_boxes_overlap_bev_forward_impl, CUDA,
iou3d_boxes_iou3d_forward_cuda); iou3d_boxes_overlap_bev_forward_cuda);
REGISTER_DEVICE_IMPL(iou3d_nms3d_forward_impl, CUDA, iou3d_nms3d_forward_cuda); REGISTER_DEVICE_IMPL(iou3d_nms3d_forward_impl, CUDA, iou3d_nms3d_forward_cuda);
REGISTER_DEVICE_IMPL(iou3d_nms3d_normal_forward_impl, CUDA, REGISTER_DEVICE_IMPL(iou3d_nms3d_normal_forward_impl, CUDA,
iou3d_nms3d_normal_forward_cuda); iou3d_nms3d_normal_forward_cuda);
......
...@@ -12,11 +12,11 @@ All Rights Reserved 2019-2020. ...@@ -12,11 +12,11 @@ All Rights Reserved 2019-2020.
const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8; const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;
void iou3d_boxes_iou3d_forward_impl(const int num_a, const Tensor boxes_a, void iou3d_boxes_overlap_bev_forward_impl(const int num_a, const Tensor boxes_a,
const int num_b, const Tensor boxes_b, const int num_b, const Tensor boxes_b,
Tensor ans_iou) { Tensor ans_overlap) {
DISPATCH_DEVICE_IMPL(iou3d_boxes_iou3d_forward_impl, num_a, boxes_a, num_b, DISPATCH_DEVICE_IMPL(iou3d_boxes_overlap_bev_forward_impl, num_a, boxes_a,
boxes_b, ans_iou); num_b, boxes_b, ans_overlap);
} }
void iou3d_nms3d_forward_impl(const Tensor boxes, unsigned long long *mask, void iou3d_nms3d_forward_impl(const Tensor boxes, unsigned long long *mask,
...@@ -32,14 +32,16 @@ void iou3d_nms3d_normal_forward_impl(const Tensor boxes, ...@@ -32,14 +32,16 @@ void iou3d_nms3d_normal_forward_impl(const Tensor boxes,
nms_overlap_thresh); nms_overlap_thresh);
} }
void iou3d_boxes_iou3d_forward(Tensor boxes_a, Tensor boxes_b, Tensor ans_iou) { void iou3d_boxes_overlap_bev_forward(Tensor boxes_a, Tensor boxes_b,
Tensor ans_overlap) {
// params boxes: (N, 7) [x, y, z, dx, dy, dz, heading] // params boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
// params boxes_b: (M, 5) // params boxes_b: (M, 5)
// params ans_overlap: (N, M) // params ans_overlap: (N, M)
int num_a = boxes_a.size(0); int num_a = boxes_a.size(0);
int num_b = boxes_b.size(0); int num_b = boxes_b.size(0);
iou3d_boxes_iou3d_forward_impl(num_a, boxes_a, num_b, boxes_b, ans_iou); iou3d_boxes_overlap_bev_forward_impl(num_a, boxes_a, num_b, boxes_b,
ans_overlap);
} }
void iou3d_nms3d_forward(Tensor boxes, Tensor keep, Tensor keep_num, void iou3d_nms3d_forward(Tensor boxes, Tensor keep, Tensor keep_num,
......
...@@ -8,16 +8,15 @@ ...@@ -8,16 +8,15 @@
using namespace parrots; using namespace parrots;
#ifdef MMCV_WITH_CUDA #ifdef MMCV_WITH_CUDA
void iou3d_boxes_iou3d_forward_cuda_parrots(CudaContext& ctx, void iou3d_boxes_overlap_bev_forward_cuda_parrots(
const SSElement& attr, CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) { OperatorBase::out_list_t& outs) {
auto boxes_a = buildATensor(ctx, ins[0]); auto boxes_a = buildATensor(ctx, ins[0]);
auto boxes_b = buildATensor(ctx, ins[1]); auto boxes_b = buildATensor(ctx, ins[1]);
auto ans_iou = buildATensor(ctx, outs[0]); auto ans_iou = buildATensor(ctx, outs[0]);
iou3d_boxes_iou3d_forward(boxes_a, boxes_b, ans_iou); iou3d_boxes_overlap_bev_forward(boxes_a, boxes_b, ans_iou);
} }
void iou3d_nms3d_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr, void iou3d_nms3d_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
...@@ -49,10 +48,10 @@ void iou3d_nms3d_normal_forward_cuda_parrots(CudaContext& ctx, ...@@ -49,10 +48,10 @@ void iou3d_nms3d_normal_forward_cuda_parrots(CudaContext& ctx,
iou3d_nms3d_normal_forward(boxes, keep, keep_num, nms_overlap_thresh); iou3d_nms3d_normal_forward(boxes, keep, keep_num, nms_overlap_thresh);
} }
PARROTS_EXTENSION_REGISTER(iou3d_boxes_iou3d_forward) PARROTS_EXTENSION_REGISTER(iou3d_boxes_overlap_bev_forward)
.input(2) .input(2)
.output(1) .output(1)
.apply(iou3d_boxes_iou3d_forward_cuda_parrots) .apply(iou3d_boxes_overlap_bev_forward_cuda_parrots)
.done(); .done();
PARROTS_EXTENSION_REGISTER(iou3d_nms3d_forward) PARROTS_EXTENSION_REGISTER(iou3d_nms3d_forward)
......
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
#include <torch/extension.h> #include <torch/extension.h>
using namespace at; using namespace at;
void iou3d_boxes_iou3d_forward(Tensor boxes_a, Tensor boxes_b, Tensor ans_iou); void iou3d_boxes_overlap_bev_forward(Tensor boxes_a, Tensor boxes_b,
Tensor ans_overlap);
void iou3d_nms3d_forward(Tensor boxes, Tensor keep, Tensor keep_num, void iou3d_nms3d_forward(Tensor boxes, Tensor keep, Tensor keep_num,
float nms_overlap_thresh); float nms_overlap_thresh);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment