Unverified Commit 0e590362 authored by q.yao's avatar q.yao Committed by GitHub
Browse files

[Enhancement] Optimize bbox overlap (#1718)

* add half support

* add cpu implementation

* fix bugs, load with inline asm

* better vector load

* add comments
parent 86f8ade9
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['bbox_overlaps'])
def _bbox_overlaps_cpu(bboxes1, bboxes2, mode='iou', aligned=False, offset=0):
assert mode in ['iou', 'iof']
if aligned:
lt = torch.max(bboxes1[:, :2], bboxes2[:, :2]) # [rows, 2]
rb = torch.min(bboxes1[:, 2:], bboxes2[:, 2:]) # [rows, 2]
wh = (rb - lt + offset).clamp(min=0) # [rows, 2]
overlap = wh[:, 0] * wh[:, 1]
area1 = (bboxes1[:, 2] - bboxes1[:, 0] + offset) * (
bboxes1[:, 3] - bboxes1[:, 1] + offset)
if mode == 'iou':
area2 = (bboxes2[:, 2] - bboxes2[:, 0] + offset) * (
bboxes2[:, 3] - bboxes2[:, 1] + offset)
ious = overlap / (area1 + area2 - overlap)
else:
ious = overlap / area1
else:
lt = torch.max(bboxes1[:, None, :2], bboxes2[:, :2]) # [rows, cols, 2]
rb = torch.min(bboxes1[:, None, 2:], bboxes2[:, 2:]) # [rows, cols, 2]
wh = (rb - lt + offset).clamp(min=0) # [rows, cols, 2]
overlap = wh[:, :, 0] * wh[:, :, 1]
area1 = (bboxes1[:, 2] - bboxes1[:, 0] + offset) * (
bboxes1[:, 3] - bboxes1[:, 1] + offset)
if mode == 'iou':
area2 = (bboxes2[:, 2] - bboxes2[:, 0] + offset) * (
bboxes2[:, 3] - bboxes2[:, 1] + offset)
ious = overlap / (area1[:, None] + area2 - overlap)
else:
ious = overlap / (area1[:, None])
return ious
def bbox_overlaps(bboxes1, bboxes2, mode='iou', aligned=False, offset=0):
"""Calculate overlap between two set of bboxes.
......@@ -65,10 +104,19 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', aligned=False, offset=0):
if rows * cols == 0:
return bboxes1.new(rows, 1) if aligned else bboxes1.new(rows, cols)
if aligned:
ious = bboxes1.new_zeros(rows)
if bboxes1.device.type == 'cpu':
return _bbox_overlaps_cpu(
bboxes1, bboxes2, mode=mode, aligned=aligned, offset=offset)
else:
ious = bboxes1.new_zeros((rows, cols))
ext_module.bbox_overlaps(
bboxes1, bboxes2, ious, mode=mode_flag, aligned=aligned, offset=offset)
return ious
if aligned:
ious = bboxes1.new_zeros(rows)
else:
ious = bboxes1.new_zeros((rows, cols))
ext_module.bbox_overlaps(
bboxes1,
bboxes2,
ious,
mode=mode_flag,
aligned=aligned,
offset=offset)
return ious
......@@ -8,6 +8,27 @@
#include "pytorch_cuda_helper.hpp"
#endif
// Reads the four consecutive values of one box, starting at bbox[base], into
// x1, y1, x2 and y2 (in that order) using one scalar read per value.
template <typename T>
__device__ __forceinline__ void load_bbox(const T* bbox, const int base, T& x1,
                                          T& y1, T& x2, T& y2) {
  const T* box = bbox + base;
  x1 = box[0];
  y1 = box[1];
  x2 = box[2];
  y2 = box[3];
}
// float specialization: fetches all four coordinates of the box at
// bbox[base] through a single float4 access instead of four scalar reads.
// NOTE(review): the float4 cast presumably requires bbox + base to be
// 16-byte aligned -- confirm that callers always pass suitably aligned
// storage.
template <>
__device__ __forceinline__ void load_bbox<float>(const float* bbox,
                                                 const int base, float& x1,
                                                 float& y1, float& x2,
                                                 float& y2) {
  const float4 packed = *reinterpret_cast<const float4*>(bbox + base);
  x1 = packed.x;
  y1 = packed.y;
  x2 = packed.z;
  y2 = packed.w;
}
template <typename T>
__global__ void bbox_overlaps_cuda_kernel(const T* bbox1, const T* bbox2,
T* ious, const int num_bbox1,
......@@ -16,69 +37,109 @@ __global__ void bbox_overlaps_cuda_kernel(const T* bbox1, const T* bbox2,
const int offset) {
if (aligned) {
CUDA_1D_KERNEL_LOOP(index, num_bbox1) {
int b1 = index;
int b2 = index;
int base1 = b1 * 4;
T b1_x1 = bbox1[base1];
T b1_y1 = bbox1[base1 + 1];
T b1_x2 = bbox1[base1 + 2];
T b1_y2 = bbox1[base1 + 3];
T b1_area = (b1_x2 - b1_x1 + offset) * (b1_y2 - b1_y1 + offset);
int base2 = b2 * 4;
T b2_x1 = bbox2[base2];
T b2_y1 = bbox2[base2 + 1];
T b2_x2 = bbox2[base2 + 2];
T b2_y2 = bbox2[base2 + 3];
T b2_area = (b2_x2 - b2_x1 + offset) * (b2_y2 - b2_y1 + offset);
T left = fmaxf(b1_x1, b2_x1), right = fminf(b1_x2, b2_x2);
T top = fmaxf(b1_y1, b2_y1), bottom = fminf(b1_y2, b2_y2);
T width = fmaxf(right - left + offset, 0.f);
T height = fmaxf(bottom - top + offset, 0.f);
T interS = width * height;
T baseS = 1.0;
if (mode == 0) {
baseS = fmaxf(b1_area + b2_area - interS, T(offset));
} else if (mode == 1) {
baseS = fmaxf(b1_area, T(offset));
}
const int b1 = index;
const int b2 = index;
const int base1 = b1 << 2; // b1 * 4
T b1_x1, b1_y1, b1_x2, b1_y2;
load_bbox<T>(bbox1, base1, b1_x1, b1_y1, b1_x2, b1_y2);
const T b1_area = (b1_x2 - b1_x1 + offset) * (b1_y2 - b1_y1 + offset);
const int base2 = b2 << 2; // b2 * 4
T b2_x1, b2_y1, b2_x2, b2_y2;
load_bbox<T>(bbox2, base2, b2_x1, b2_y1, b2_x2, b2_y2);
const T b2_area = (b2_x2 - b2_x1 + offset) * (b2_y2 - b2_y1 + offset);
const T left = fmaxf(b1_x1, b2_x1), right = fminf(b1_x2, b2_x2);
const T top = fmaxf(b1_y1, b2_y1), bottom = fminf(b1_y2, b2_y2);
const T width = fmaxf(right - left + offset, 0.f);
const T height = fmaxf(bottom - top + offset, 0.f);
const T interS = width * height;
const T baseS =
fmaxf(mode == 0 ? b1_area + b2_area - interS : b1_area, T(offset));
ious[index] = interS / baseS;
}
} else {
CUDA_1D_KERNEL_LOOP(index, num_bbox1 * num_bbox2) {
int b1 = index / num_bbox2;
int b2 = index % num_bbox2;
int base1 = b1 * 4;
T b1_x1 = bbox1[base1];
T b1_y1 = bbox1[base1 + 1];
T b1_x2 = bbox1[base1 + 2];
T b1_y2 = bbox1[base1 + 3];
T b1_area = (b1_x2 - b1_x1 + offset) * (b1_y2 - b1_y1 + offset);
int base2 = b2 * 4;
T b2_x1 = bbox2[base2];
T b2_y1 = bbox2[base2 + 1];
T b2_x2 = bbox2[base2 + 2];
T b2_y2 = bbox2[base2 + 3];
T b2_area = (b2_x2 - b2_x1 + offset) * (b2_y2 - b2_y1 + offset);
T left = fmaxf(b1_x1, b2_x1), right = fminf(b1_x2, b2_x2);
T top = fmaxf(b1_y1, b2_y1), bottom = fminf(b1_y2, b2_y2);
T width = fmaxf(right - left + offset, 0.f);
T height = fmaxf(bottom - top + offset, 0.f);
T interS = width * height;
T baseS = 1.0;
if (mode == 0) {
baseS = fmaxf(b1_area + b2_area - interS, T(offset));
} else if (mode == 1) {
baseS = fmaxf(b1_area, T(offset));
}
const int b1 = index / num_bbox2;
const int b2 = index % num_bbox2;
const int base1 = b1 << 2; // b1 * 4
T b1_x1, b1_y1, b1_x2, b1_y2;
load_bbox<T>(bbox1, base1, b1_x1, b1_y1, b1_x2, b1_y2);
const T b1_area = (b1_x2 - b1_x1 + offset) * (b1_y2 - b1_y1 + offset);
const int base2 = b2 << 2; // b2 * 4
T b2_x1, b2_y1, b2_x2, b2_y2;
load_bbox<T>(bbox2, base2, b2_x1, b2_y1, b2_x2, b2_y2);
const T b2_area = (b2_x2 - b2_x1 + offset) * (b2_y2 - b2_y1 + offset);
const T left = fmaxf(b1_x1, b2_x1), right = fminf(b1_x2, b2_x2);
const T top = fmaxf(b1_y1, b2_y1), bottom = fminf(b1_y2, b2_y2);
const T width = fmaxf(right - left + offset, 0.f);
const T height = fmaxf(bottom - top + offset, 0.f);
const T interS = width * height;
const T baseS =
fmaxf(mode == 0 ? b1_area + b2_area - interS : b1_area, T(offset));
ious[index] = interS / baseS;
}
}
}
// Area of one box in half precision:
// (x2 - x1 + offset) * (y2 - y1 + offset).
__device__ __forceinline__ __half __half_area(const __half x1, const __half y1,
                                              const __half x2, const __half y2,
                                              const __half offset) {
  return __hmul(__hadd(__hsub(x2, x1), offset),
                __hadd(__hsub(y2, y1), offset));
}
// Returns a when a >= b (as decided by __hge), otherwise b.
__device__ __forceinline__ __half __half_max(const __half a, const __half b) {
  if (__hge(a, b)) {
    return a;
  }
  return b;
}
// Returns a when a <= b (as decided by __hle), otherwise b.
__device__ __forceinline__ __half __half_min(const __half a, const __half b) {
  if (__hle(a, b)) {
    return a;
  }
  return b;
}
// Half-precision overlap computation built on the __half intrinsics.
// Per the original author's note: fp16 brings little gain when
// aligned==true; it is useful when aligned==false (~40% bonus).
//
// Writes one value per output index into ious:
//   aligned  -> num_bbox1 outputs, pairing box i of bbox1 with box i of bbox2
//   !aligned -> num_bbox1 * num_bbox2 outputs, index = b1 * num_bbox2 + b2
// mode == 0 divides the intersection by (area1 + area2 - intersection);
// any other mode divides it by area1. The denominator is clamped from below
// to offset before dividing.
__device__ void bbox_overlaps_cuda_kernel_half(
    const __half* bbox1, const __half* bbox2, __half* ious, const int num_bbox1,
    const int num_bbox2, const int mode, const bool aligned, const int offset) {
  const __half h_offset = __int2half_rn(offset);
  const __half h_zero = __float2half(0.f);
  const int total = aligned ? num_bbox1 : num_bbox1 * num_bbox2;

  CUDA_1D_KERNEL_LOOP(index, total) {
    // Map the flat output index to a (bbox1, bbox2) pair.
    int row = index;
    int col = index;
    if (!aligned) {
      row = index / num_bbox2;
      col = index % num_bbox2;
    }

    // Each box occupies four consecutive values, hence the shift by 2.
    __half ax1, ay1, ax2, ay2;
    load_bbox<__half>(bbox1, row << 2, ax1, ay1, ax2, ay2);
    const __half area_a = __half_area(ax1, ay1, ax2, ay2, h_offset);

    __half bx1, by1, bx2, by2;
    load_bbox<__half>(bbox2, col << 2, bx1, by1, bx2, by2);
    const __half area_b = __half_area(bx1, by1, bx2, by2, h_offset);

    // Intersection rectangle, with negative extents clamped to zero.
    const __half x_lo = __half_max(ax1, bx1);
    const __half x_hi = __half_min(ax2, bx2);
    const __half y_lo = __half_max(ay1, by1);
    const __half y_hi = __half_min(ay2, by2);
    const __half inter_w =
        __half_max(__hadd(__hsub(x_hi, x_lo), h_offset), h_zero);
    const __half inter_h =
        __half_max(__hadd(__hsub(y_hi, y_lo), h_offset), h_zero);
    const __half inter = __hmul(inter_w, inter_h);

    __half denom = area_a;
    if (mode == 0) {
      denom = __hsub(__hadd(area_a, area_b), inter);
    }
    denom = __half_max(denom, h_offset);

    ious[index] = __hdiv(inter, denom);
  }
}
#endif // BBOX_OVERLAPS_CUDA_KERNEL_CUH
......@@ -2,6 +2,20 @@
#include "bbox_overlaps_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
// Disable fp16 on ROCm device
#ifndef HIP_DIFF
// at::Half specialization: reinterprets the at::Half buffers as __half and
// forwards every argument unchanged to bbox_overlaps_cuda_kernel_half.
template <>
__global__ void bbox_overlaps_cuda_kernel<at::Half>(
    const at::Half* bbox1, const at::Half* bbox2, at::Half* ious,
    const int num_bbox1, const int num_bbox2, const int mode,
    const bool aligned, const int offset) {
  const __half* half_bbox1 = reinterpret_cast<const __half*>(bbox1);
  const __half* half_bbox2 = reinterpret_cast<const __half*>(bbox2);
  __half* half_ious = reinterpret_cast<__half*>(ious);
  bbox_overlaps_cuda_kernel_half(half_bbox1, half_bbox2, half_ious, num_bbox1,
                                 num_bbox2, mode, aligned, offset);
}
#endif // HIP_DIFF
void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,
Tensor ious, const int mode,
const bool aligned, const int offset) {
......
......@@ -8,7 +8,7 @@ from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
class TestBBox(object):
def _test_bbox_overlaps(self, device, dtype=torch.float):
def _test_bbox_overlaps(self, dtype=torch.float, device='cpu'):
from mmcv.ops import bbox_overlaps
b1 = torch.tensor([[1.0, 1.0, 3.0, 4.0], [2.0, 2.0, 3.0, 4.0],
[7.0, 7.0, 8.0, 8.0]]).to(device).type(dtype)
......@@ -35,6 +35,7 @@ class TestBBox(object):
assert np.allclose(out.cpu().numpy(), should_output, 1e-2)
@pytest.mark.parametrize('device', [
'cpu',
pytest.param(
'cuda',
marks=pytest.mark.skipif(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment