Unverified Commit 594ff3c0 authored by Yue Zhou's avatar Yue Zhou Committed by GitHub
Browse files

[Feature] Add convex IoU CUDA op for rotated detection. (#1610)



* add convex iou

* fix lint

* add convex_iou

* fix convex_iou

* add convex_giou

* fix bug

* fix lint

* fix bug

* Update

* update

* add kernel loop

* fix bug.

* fix polygen typo

* simplify reverse

* Update convex_iou_cuda_kernel.cuh

* Update mmcv/ops/convex_iou.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/ops/convex_iou.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* add   AT_DISPATCH_FLOATING_TYPES_AND_HALF

* fix lint

* fix lint

* Resolving conflicts
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 9acc892a
......@@ -9,6 +9,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- CARAFE
- CrissCrossAttention
- ContextBlock
- ConvexIoU
- CornerPool
- Deformable Convolution v1/v2
- Deformable RoIPool
......
......@@ -9,6 +9,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- CARAFE
- CrissCrossAttention
- ContextBlock
- ConvexIoU
- CornerPool
- Deformable Convolution v1/v2
- Deformable RoIPool
......
......@@ -8,6 +8,7 @@ from .box_iou_rotated import box_iou_rotated
from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
from .cc_attention import CrissCrossAttention
from .contour_expand import contour_expand
from .convex_iou import convex_giou, convex_iou
from .corner_pool import CornerPool
from .correlation import Correlation
from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d
......@@ -84,5 +85,6 @@ __all__ = [
'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all',
'points_in_polygons', 'min_area_polygons', 'active_rotated_filter'
'points_in_polygons', 'min_area_polygons', 'active_rotated_filter',
'convex_iou', 'convex_giou'
]
# Copyright (c) OpenMMLab. All rights reserved.
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['convex_iou', 'convex_giou'])
def convex_giou(pointsets, polygons):
"""Return generalized intersection-over-union (Jaccard index) between point
sets and polygons.
Args:
pointsets (torch.Tensor): It has shape (N, 18),
indicating (x1, y1, x2, y2, ..., x9, y9) for each row.
polygons (torch.Tensor): It has shape (N, 8),
indicating (x1, y1, x2, y2, x3, y3, x4, y4) for each row.
Returns:
tuple[torch.Tensor, torch.Tensor]: The first element is the gious
between point sets and polygons with the shape (N,). The second
element is the gradient of point sets with the shape (N, 18).
"""
output = pointsets.new_zeros((pointsets.size(0), 19))
ext_module.convex_giou(pointsets, polygons, output)
convex_giou = output[:, -1]
points_grad = output[:, 0:-1]
return convex_giou, points_grad
def convex_iou(pointsets, polygons):
"""Return intersection-over-union (Jaccard index) between point sets and
polygons.
Args:
pointsets (torch.Tensor): It has shape (N, 18),
indicating (x1, y1, x2, y2, ..., x9, y9) for each row.
polygons (torch.Tensor): It has shape (K, 8),
indicating (x1, y1, x2, y2, x3, y3, x4, y4) for each row.
Returns:
torch.Tensor: Return the ious between point sets and polygons with the
shape (N, K).
"""
N, K = pointsets.size(0), polygons.size(0)
ious = pointsets.new_zeros((N, K))
ext_module.convex_iou(pointsets, polygons, ious)
return ious
This diff is collapsed.
// Copyright (c) OpenMMLab. All rights reserved
// modified from
// https://github.com/SDL-GuoZonghao/BeyondBoundingBox/tree/main/mmdet/ops/iou/src
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
void convex_iou_impl(const Tensor pointsets, const Tensor polygons,
Tensor ious) {
DISPATCH_DEVICE_IMPL(convex_iou_impl, pointsets, polygons, ious);
}
void convex_iou(const Tensor pointsets, const Tensor polygons, Tensor ious) {
convex_iou_impl(pointsets, polygons, ious);
}
void convex_giou_impl(const Tensor pointsets, const Tensor polygons,
Tensor output) {
DISPATCH_DEVICE_IMPL(convex_giou_impl, pointsets, polygons, output);
}
void convex_giou(const Tensor pointsets, const Tensor polygons, Tensor output) {
convex_giou_impl(pointsets, polygons, output);
}
// Copyright (c) OpenMMLab. All rights reserved
// modified from
// https://github.com/SDL-GuoZonghao/BeyondBoundingBox/blob/main/mmdet/ops/iou/src/convex_iou_kernel.cu
#include "convex_iou_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void ConvexIoUCUDAKernelLauncher(const Tensor pointsets, const Tensor polygons,
Tensor ious) {
int output_size = ious.numel();
int num_pointsets = pointsets.size(0);
int num_polygons = polygons.size(0);
at::cuda::CUDAGuard device_guard(pointsets.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
pointsets.scalar_type(), "convex_iou_cuda_kernel", ([&] {
convex_iou_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
num_pointsets, num_polygons, pointsets.data_ptr<scalar_t>(),
polygons.data_ptr<scalar_t>(), ious.data_ptr<scalar_t>());
}));
AT_CUDA_CHECK(cudaGetLastError());
}
void ConvexGIoUCUDAKernelLauncher(const Tensor pointsets, const Tensor polygons,
Tensor output) {
int output_size = output.numel();
int num_pointsets = pointsets.size(0);
int num_polygons = polygons.size(0);
at::cuda::CUDAGuard device_guard(pointsets.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
pointsets.scalar_type(), "convex_giou_cuda_kernel", ([&] {
convex_giou_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK / 2, 0, stream>>>(
num_pointsets, num_polygons, pointsets.data_ptr<scalar_t>(),
polygons.data_ptr<scalar_t>(), output.data_ptr<scalar_t>());
}));
AT_CUDA_CHECK(cudaGetLastError());
}
......@@ -1539,3 +1539,28 @@ REGISTER_DEVICE_IMPL(active_rotated_filter_forward_impl, CUDA,
active_rotated_filter_forward_cuda);
REGISTER_DEVICE_IMPL(active_rotated_filter_backward_impl, CUDA,
active_rotated_filter_backward_cuda);
void ConvexIoUCUDAKernelLauncher(const Tensor pointsets, const Tensor polygons,
Tensor ious);
void ConvexGIoUCUDAKernelLauncher(const Tensor pointsets, const Tensor polygons,
Tensor output);
void convex_iou_cuda(const Tensor pointsets, const Tensor polygons,
Tensor ious) {
ConvexIoUCUDAKernelLauncher(pointsets, polygons, ious);
}
void convex_giou_cuda(const Tensor pointsets, const Tensor polygons,
Tensor output) {
ConvexGIoUCUDAKernelLauncher(pointsets, polygons, output);
}
void convex_iou_impl(const Tensor pointsets, const Tensor polygons,
Tensor ious);
void convex_giou_impl(const Tensor pointsets, const Tensor polygons,
Tensor output);
REGISTER_DEVICE_IMPL(convex_iou_impl, CUDA, convex_iou_cuda);
REGISTER_DEVICE_IMPL(convex_giou_impl, CUDA, convex_giou_cuda);
......@@ -371,6 +371,10 @@ void active_rotated_filter_forward(const Tensor input, const Tensor indices,
void active_rotated_filter_backward(const Tensor grad_out, const Tensor indices,
Tensor grad_in);
void convex_iou(const Tensor pointsets, const Tensor polygons, Tensor ious);
void convex_giou(const Tensor pointsets, const Tensor polygons, Tensor output);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
......@@ -747,4 +751,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("active_rotated_filter_backward", &active_rotated_filter_backward,
"active_rotated_filter_backward", py::arg("grad_out"),
py::arg("indices"), py::arg("grad_in"));
m.def("convex_iou", &convex_iou, "convex_iou", py::arg("pointsets"),
py::arg("polygons"), py::arg("ious"));
m.def("convex_giou", &convex_giou, "convex_giou", py::arg("pointsets"),
py::arg("polygons"), py::arg("output"));
}
import numpy as np
import pytest
import torch
from mmcv.ops import convex_giou, convex_iou
np_pointsets = np.asarray([[
1.0, 1.0, 2.0, 2.0, 1.0, 2.0, 2.0, 1.0, 1.0, 3.0, 3.0, 1.0, 2.0, 3.0, 3.0,
2.0, 1.5, 1.5
],
[
1.5, 1.5, 2.5, 2.5, 1.5, 2.5, 2.5, 1.5, 1.5,
3.5, 3.5, 1.5, 2.5, 3.5, 3.5, 2.5, 2.0, 2.0
]])
np_polygons = np.asarray([[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 1.0],
[1.0, 1.0, 1.0, 3.0, 3.0, 3.0, 3.0, 1.0]])
np_expected_iou = np.asarray([[0.2857, 0.8750], [0.0588, 0.4286]])
np_expected_giou = np.asarray([0.2857, 0.3831])
np_expected_grad = np.asarray([[
0.0204, 0.0408, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0612,
-0.0408, -0.0408, 0.0816, -0.0408, -0.0816, -0.0816, -0.0408, 0.0000,
0.0000
],
[
-0.1848, -0.1848, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, -0.1076, -0.0801,
-0.0801, -0.1076, -0.0367, -0.0734, -0.0734,
-0.0367, 0.0000, 0.0000
]])
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_convex_iou():
pointsets = torch.from_numpy(np_pointsets).cuda().float()
polygons = torch.from_numpy(np_polygons).cuda().float()
expected_iou = torch.from_numpy(np_expected_iou).cuda().float()
assert torch.allclose(
convex_iou(pointsets, polygons), expected_iou, atol=1e-3)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_convex_giou():
pointsets = torch.from_numpy(np_pointsets).cuda().float()
polygons = torch.from_numpy(np_polygons).cuda().float()
expected_giou = torch.from_numpy(np_expected_giou).cuda().float()
expected_grad = torch.from_numpy(np_expected_grad).cuda().float()
giou, grad = convex_giou(pointsets, polygons)
assert torch.allclose(giou, expected_giou, atol=1e-3)
assert torch.allclose(grad, expected_grad, atol=1e-3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment