Commit 06cbdb5b authored by Yuxin Wu's avatar Yuxin Wu Committed by Francisco Massa
Browse files

Speed up nms_cuda (#1704)

1. Let the IOU function compare with threshold. This avoid a division. Similar strategy is also used in https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/non_max_suppression_op.cu.cc
2. Only compute the upper triangle of the mask.

This speeds up the kernel about 20% (tested on GTX 1080Ti, with 20 input cases dumped from a Mask R-CNN inference job).
parent 2a174229
......@@ -11,14 +11,14 @@
int const threadsPerBlock = sizeof(unsigned long long) * 8;
template <typename T>
__device__ inline float devIoU(T const* const a, T const* const b) {
__device__ inline bool devIoU(T const* const a, T const* const b, const float threshold) {
T left = max(a[0], b[0]), right = min(a[2], b[2]);
T top = max(a[1], b[1]), bottom = min(a[3], b[3]);
T width = max(right - left, (T)0), height = max(bottom - top, (T)0);
T interS = width * height;
T Sa = (a[2] - a[0]) * (a[3] - a[1]);
T Sb = (b[2] - b[0]) * (b[3] - b[1]);
return interS / (Sa + Sb - interS);
return interS > threshold * (Sa + Sb - interS);
}
template <typename T>
......@@ -30,7 +30,7 @@ __global__ void nms_kernel(
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
if (row_start > col_start) return;
const int row_size =
min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
......@@ -60,7 +60,7 @@ __global__ void nms_kernel(
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (devIoU<T>(cur_box, block_boxes + i * 4) > iou_threshold) {
if (devIoU<T>(cur_box, block_boxes + i * 4, iou_threshold)) {
t |= 1ULL << i;
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment