// nms_cuda.cu: CUDA implementation of non-maximum suppression (NMS).
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>

#include "cuda_helpers.h"

#include <iostream>
#include <vector>

// Boxes handled per CUDA block / per mask word: one box per bit of an
// unsigned long long, so a whole column tile fits in a single mask word.
constexpr int threadsPerBlock = static_cast<int>(8 * sizeof(unsigned long long));

// Intersection-over-union of two axis-aligned boxes `a` and `b`, each given
// as four consecutive values (x1, y1, x2, y2); x1 <= x2 and y1 <= y2 are
// assumed, not checked.
// All intermediate arithmetic is carried out in T; only the final ratio is
// converted to float. When the union is zero (both boxes degenerate) the
// division yields NaN for floating T, and NaN never compares greater than a
// threshold.
template <typename T>
__device__ inline float devIoU(T const* const a, T const* const b) {
  // Corners of the overlap rectangle (may be "inverted" when disjoint).
  const T ix1 = max(a[0], b[0]);
  const T iy1 = max(a[1], b[1]);
  const T ix2 = min(a[2], b[2]);
  const T iy2 = min(a[3], b[3]);

  // Clamp negative extents to zero so disjoint boxes have no intersection.
  const T overlap_w = max(ix2 - ix1, (T)0);
  const T overlap_h = max(iy2 - iy1, (T)0);
  const T inter = overlap_w * overlap_h;

  const T area_a = (a[2] - a[0]) * (a[3] - a[1]);
  const T area_b = (b[2] - b[0]) * (b[3] - b[1]);
  return inter / (area_a + area_b - inter);
}

// Builds the pairwise suppression bitmask used by greedy NMS.
//
// Launch layout (as configured by nms_cuda): a 2-D grid of
// (col_blocks, col_blocks) blocks, where col_blocks =
// ceil(n_boxes / threadsPerBlock), and threadsPerBlock threads per block.
// blockIdx.y selects a tile of "row" boxes (one box per thread) and
// blockIdx.x selects a tile of "column" boxes.
// Static shared memory: threadsPerBlock * 4 elements of T (the column tile).
//
// dev_boxes: n_boxes boxes, four consecutive values per box, in the
//   (x1, y1, x2, y2) layout that devIoU reads. The caller passes them sorted
//   by descending score.
// dev_mask: n_boxes * col_blocks words. Word [i * col_blocks + c] receives a
//   bit k for every box c * threadsPerBlock + k whose IoU with box i exceeds
//   iou_threshold. On the diagonal tile (row tile == column tile) only lanes
//   after box i's own lane are considered.
template <typename T>
__global__ void nms_kernel(
    const int n_boxes,
    const float iou_threshold,
    const T* dev_boxes,
    unsigned long long* dev_mask) {
  const int row_start = blockIdx.y;
  const int col_start = blockIdx.x;

  // if (row_start > col_start) return;

  // The last tile in each direction may be partial.
  const int row_size =
      min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
  const int col_size =
      min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);

  // Stage the column tile in shared memory; every row thread in the block
  // compares its box against all of these.
  __shared__ T block_boxes[threadsPerBlock * 4];
  if (threadIdx.x < col_size) {
    const T* src = dev_boxes + (threadsPerBlock * col_start + threadIdx.x) * 4;
    block_boxes[threadIdx.x * 4 + 0] = src[0];
    block_boxes[threadIdx.x * 4 + 1] = src[1];
    block_boxes[threadIdx.x * 4 + 2] = src[2];
    block_boxes[threadIdx.x * 4 + 3] = src[3];
  }
  // Barrier sits outside the branch so every thread in the block reaches it.
  __syncthreads();

  if (threadIdx.x < row_size) {
    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
    const T* cur_box = dev_boxes + cur_box_idx * 4;

    unsigned long long t = 0;
    // On the diagonal tile skip the box itself and earlier boxes of the tile.
    const int start = (row_start == col_start) ? threadIdx.x + 1 : 0;
    for (int i = start; i < col_size; i++) {
      if (devIoU<T>(cur_box, block_boxes + i * 4) > iou_threshold) {
        t |= 1ULL << i;
      }
    }

    const int col_blocks = at::cuda::ATenCeilDiv(n_boxes, threadsPerBlock);
    // 64-bit offset: n_boxes * col_blocks can exceed INT_MAX for very large
    // inputs.
    dev_mask[static_cast<int64_t>(cur_box_idx) * col_blocks + col_start] = t;
  }
}

// Greedy non-maximum suppression on CUDA tensors.
//
// dets:   CUDA tensor of shape (N, 4); each row is a box in the
//         (x1, y1, x2, y2) layout that devIoU reads.
// scores: CUDA tensor of shape (N,) with one score per box.
// iou_threshold: a box is discarded when its IoU with an already-kept,
//         higher-scoring box is strictly greater than this value.
// Returns an int64 tensor on the same device as `scores` holding the indices
// (into `dets`) of the kept boxes, in descending-score order. Returns an
// empty int64 tensor when there are no boxes.
//
// Pairwise overlaps are computed on the GPU by nms_kernel as a bitmask. The
// mask is copied to the host and reduced greedily on the CPU.
at::Tensor nms_cuda(const at::Tensor& dets,
    const at::Tensor& scores,
    float iou_threshold) {
  AT_ASSERTM(dets.is_cuda(), "dets must be a CUDA tensor");
  AT_ASSERTM(scores.is_cuda(), "scores must be a CUDA tensor");
  AT_ASSERTM(dets.dim() == 2 && dets.size(1) == 4,
             "dets must be a 2-D tensor of shape (N, 4)");
  AT_ASSERTM(scores.dim() == 1, "scores must be a 1-D tensor");
  AT_ASSERTM(dets.size(0) == scores.size(0),
             "dets and scores must have the same number of boxes");
  at::cuda::CUDAGuard device_guard(dets.device());

  // Nothing to suppress. Returning early also avoids launching a kernel with
  // an empty grid, which CUDA rejects as an invalid configuration.
  if (dets.numel() == 0) {
    return at::empty({0}, dets.options().dtype(at::kLong));
  }

  // Visit boxes by descending score so the greedy pass keeps the best box of
  // every overlapping group.
  auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
  auto dets_sorted = dets.index_select(0, order_t);

  const int dets_num = dets.size(0);
  const int col_blocks = at::cuda::ATenCeilDiv(dets_num, threadsPerBlock);

  // One 64-bit word per (box, column tile); nms_kernel writes every word.
  at::Tensor mask = at::empty(
      {static_cast<int64_t>(dets_num) * col_blocks},
      dets.options().dtype(at::kLong));

  dim3 blocks(col_blocks, col_blocks);
  dim3 threads(threadsPerBlock);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      dets_sorted.scalar_type(), "nms_kernel_cuda", [&] {
        nms_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            dets_num,
            iou_threshold,
            dets_sorted.data_ptr<scalar_t>(),
            (unsigned long long*)mask.data_ptr<int64_t>());
      });
  // Report launch-configuration errors at the launch site.
  AT_CUDA_CHECK(cudaGetLastError());

  // Copy to the host (non_blocking is not requested) for the sequential
  // greedy reduction.
  at::Tensor mask_cpu = mask.to(at::kCPU);
  const unsigned long long* mask_host =
      (const unsigned long long*)mask_cpu.data_ptr<int64_t>();

  // Bit k of remv[j] set  =>  box j * threadsPerBlock + k is suppressed.
  std::vector<unsigned long long> remv(col_blocks, 0ULL);

  at::Tensor keep =
      at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
  int64_t* keep_out = keep.data_ptr<int64_t>();

  int num_to_keep = 0;
  for (int i = 0; i < dets_num; i++) {
    const int nblock = i / threadsPerBlock;
    const int inblock = i % threadsPerBlock;

    if (!(remv[nblock] & (1ULL << inblock))) {
      keep_out[num_to_keep++] = i;
      // Suppress every later box that box i overlaps. Words before nblock only
      // describe earlier boxes, which have already been decided.
      const unsigned long long* p =
          mask_host + static_cast<int64_t>(i) * col_blocks;
      for (int j = nblock; j < col_blocks; j++) {
        remv[j] |= p[j];
      }
    }
  }

  // Map kept positions in the sorted order back to indices into dets.
  return order_t.index(
      {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
           .to(order_t.device(), keep.scalar_type())});
}