// nms_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>

#include "cuda_helpers.h"

namespace vision {
namespace ops {

namespace {
// Number of boxes handled per thread block. Suppression results for one
// (row box, column block) pair are packed into a single unsigned long long,
// one bit per column box, so a block covers exactly as many boxes as that
// word has bits (typically 64).
constexpr int threadsPerBlock = sizeof(unsigned long long) * 8;

// Returns true when the intersection-over-union of boxes `a` and `b` is
// strictly greater than `threshold`.
//
// Each box is 4 consecutive values: indices 0/2 are treated as the
// left/right edges and 1/3 as the top/bottom edges (presumably
// (x1, y1, x2, y2) with x2 >= x1 and y2 >= y1 -- TODO confirm with callers).
//
// The intersection and the two areas are computed in
// at::acc_type<T, /*is_cuda=*/true> (float for at::Half). Multiplying
// half-precision coordinates directly can exceed half's ~65504 range. The
// resulting inf areas make the ratio 0 or NaN, so overlapping boxes would
// silently fail to suppress each other.
template <typename T>
__device__ inline bool devIoU(
    T const* const a,
    T const* const b,
    const float threshold) {
  using acc_T = at::acc_type<T, /*is_cuda=*/true>;

  T left = max(a[0], b[0]), right = min(a[2], b[2]);
  T top = max(a[1], b[1]), bottom = min(a[3], b[3]);
  // Clamp to zero so disjoint boxes get an empty intersection.
  T width = max(right - left, (T)0), height = max(bottom - top, (T)0);

  acc_T interS = (acc_T)width * height;
  acc_T Sa = (acc_T)(a[2] - a[0]) * (a[3] - a[1]);
  acc_T Sb = (acc_T)(b[2] - b[0]) * (b[3] - b[1]);

  // If the union is zero (e.g. two zero-area boxes) the quotient is NaN,
  // which compares false, so such a pair is not considered overlapping.
  return (interS / (Sa + Sb - interS)) > threshold;
}

// Builds the pairwise suppression bitmask for boxes that are already sorted
// by descending score.
//
// Launch layout (as configured by the host wrapper): a 2-D grid of
// (col_blocks, col_blocks) blocks with threadsPerBlock threads each, where
// col_blocks = ceil_div(n_boxes, threadsPerBlock). Block
// (x = col_start, y = row_start):
//   - stages the boxes of column block `col_start` in static shared memory
//     (threadsPerBlock * 4 elements of T);
//   - has each thread take one box of row block `row_start` and test it
//     against every staged column box.
//
// dev_boxes: n_boxes boxes, 4 contiguous values each.
// dev_mask:  n_boxes * col_blocks words. Bit i of
//            dev_mask[box * col_blocks + col_start] is set when `box` and
//            box (col_start * threadsPerBlock + i) have IoU > iou_threshold.
//            Only blocks on or above the diagonal write. On the diagonal,
//            only boxes with a higher index than the row box are tested.
//            Words for blocks below the diagonal are never written.
template <typename T>
__global__ void nms_kernel_impl(
    int n_boxes,
    double iou_threshold,
    const T* dev_boxes,
    unsigned long long* dev_mask) {
  const int row_start = blockIdx.y;
  const int col_start = blockIdx.x;

  // The host only consumes words at or right of the diagonal. This condition
  // is uniform across the block, so returning before __syncthreads() is safe.
  if (row_start > col_start)
    return;

  // The last row/column block may be partially filled.
  const int row_size =
      min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
  const int col_size =
      min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);

  __shared__ T block_boxes[threadsPerBlock * 4];
  if (threadIdx.x < col_size) {
    // 64-bit offset: box_index * 4 must not overflow int for large inputs.
    const T* src = dev_boxes +
        (static_cast<int64_t>(threadsPerBlock) * col_start + threadIdx.x) * 4;
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      block_boxes[threadIdx.x * 4 + k] = src[k];
    }
  }
  __syncthreads();

  if (threadIdx.x < row_size) {
    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
    const T* cur_box = dev_boxes + static_cast<int64_t>(cur_box_idx) * 4;

    // On the diagonal, only compare against later (lower-score) boxes; this
    // also skips comparing a box with itself.
    const int start = (row_start == col_start) ? threadIdx.x + 1 : 0;

    unsigned long long t = 0;
    for (int i = start; i < col_size; i++) {
      if (devIoU<T>(cur_box, block_boxes + i * 4, iou_threshold)) {
        t |= 1ULL << i;
      }
    }

    const int col_blocks = ceil_div(n_boxes, threadsPerBlock);
    // 64-bit row offset: cur_box_idx * col_blocks can exceed INT_MAX.
    dev_mask[static_cast<int64_t>(cur_box_idx) * col_blocks + col_start] = t;
  }
}

// CUDA implementation of torchvision::nms.
//
// Sorts boxes by descending score and launches nms_kernel_impl to build the
// pairwise suppression bitmask. It then copies the mask to the host and
// walks boxes greedily in score order, keeping each box that no
// previously-kept box has suppressed.
//
// dets:          2-D tensor with 4 values per row (box coordinates).
// scores:        1-D tensor with one score per row of `dets`.
// iou_threshold: boxes whose IoU with a kept box is strictly greater than
//                this are suppressed.
// Returns an int64 tensor on the scores' device holding indices into `dets`
// of the kept boxes, in descending score order. It is empty when `dets` is
// empty. TORCH_CHECK raises if the inputs are not CUDA tensors or their
// shapes are inconsistent.
at::Tensor nms_kernel(
    const at::Tensor& dets,
    const at::Tensor& scores,
    double iou_threshold) {
  TORCH_CHECK(dets.is_cuda(), "dets must be a CUDA tensor");
  TORCH_CHECK(scores.is_cuda(), "scores must be a CUDA tensor");

  TORCH_CHECK(
      dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
  TORCH_CHECK(
      dets.size(1) == 4,
      "boxes should have 4 elements in dimension 1, got ",
      dets.size(1));
  TORCH_CHECK(
      scores.dim() == 1,
      "scores should be a 1d tensor, got ",
      scores.dim(),
      "D");
  TORCH_CHECK(
      dets.size(0) == scores.size(0),
      "boxes and scores should have same number of elements in ",
      "dimension 0, got ",
      dets.size(0),
      " and ",
      scores.size(0));

#if defined(WITH_CUDA) || defined(WITH_HIP)
  at::cuda::CUDAGuard device_guard(dets.device());
#else
  TORCH_CHECK(false, "Not compiled with GPU support");
#endif

  if (dets.numel() == 0) {
    return at::empty({0}, dets.options().dtype(at::kLong));
  }

  // order_t[k] is the original index of the k-th highest-scoring box.
  auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
  auto dets_sorted = dets.index_select(0, order_t).contiguous();

  const int dets_num = dets.size(0);
  const int col_blocks = ceil_div(dets_num, threadsPerBlock);

  // One 64-bit word per (box, column block). The total can exceed INT_MAX
  // for large inputs, so compute it in 64 bits.
  at::Tensor mask = at::empty(
      {static_cast<int64_t>(dets_num) * col_blocks},
      dets.options().dtype(at::kLong));

  dim3 blocks(col_blocks, col_blocks);
  dim3 threads(threadsPerBlock);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      dets_sorted.scalar_type(), "nms_kernel", [&] {
        nms_kernel_impl<scalar_t><<<blocks, threads, 0, stream>>>(
            dets_num,
            iou_threshold,
            dets_sorted.data_ptr<scalar_t>(),
            // The int64 storage is reinterpreted as raw 64-bit bit words.
            (unsigned long long*)mask.data_ptr<int64_t>());
      });
  // Report launch errors here, at the launch site, rather than after the
  // blocking copy and host loop below.
  AT_CUDA_CHECK(cudaGetLastError());

  at::Tensor mask_cpu = mask.to(at::kCPU);
  const unsigned long long* mask_host =
      (const unsigned long long*)mask_cpu.data_ptr<int64_t>();

  // remv[j] accumulates, for column block j, which boxes have been
  // suppressed by some kept box. std::vector value-initializes it to zero.
  std::vector<unsigned long long> remv(col_blocks);

  at::Tensor keep =
      at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
  int64_t* keep_out = keep.data_ptr<int64_t>();

  int num_to_keep = 0;
  for (int i = 0; i < dets_num; i++) {
    const int nblock = i / threadsPerBlock;
    const int inblock = i % threadsPerBlock;

    if (!(remv[nblock] & (1ULL << inblock))) {
      keep_out[num_to_keep++] = i;
      // Only words at or right of the diagonal were written by the kernel;
      // those are the only ones read here.
      const unsigned long long* p =
          mask_host + static_cast<int64_t>(i) * col_blocks;
      for (int j = nblock; j < col_blocks; j++) {
        remv[j] |= p[j];
      }
    }
  }

  // Map positions in the sorted order back to original box indices.
  return order_t.index(
      {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
           .to(order_t.device(), keep.scalar_type())});
}

} // namespace

// Registers nms_kernel as the CUDA-dispatch-key implementation of the
// torchvision::nms operator.
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
  m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}

} // namespace ops
} // namespace vision