// iou3d_kernel.cu
#include <stdio.h>
#define THREADS_PER_BLOCK 16
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

//#define DEBUG
const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;
const float EPS = 1e-8;
// Minimal single-precision 2D point / vector whose members are callable from
// device code only.
struct Point {
  float x, y;

  // Leaves x and y uninitialized; use set() or assignment before reading.
  __device__ Point() {}

  // Takes float (the storage type) rather than double so that building a
  // Point from float arithmetic never round-trips through double precision.
  // Double arguments are still accepted and narrowed exactly as before.
  __device__ Point(float _x, float _y) : x(_x), y(_y) {}

  // Overwrites both coordinates.
  __device__ void set(float _x, float _y) {
    x = _x;
    y = _y;
  }

  // Component-wise sum.
  __device__ Point operator+(const Point &b) const {
    return Point(x + b.x, y + b.y);
  }

  // Component-wise difference.
  __device__ Point operator-(const Point &b) const {
    return Point(x - b.x, y - b.y);
  }
};

// 2D cross product of vectors a and b (a.x * b.y - a.y * b.x).
__device__ inline float cross(const Point &a, const Point &b) {
  const float cross_z = a.x * b.y - a.y * b.x;
  return cross_z;
}

// 2D cross product of the vectors (p1 - p0) and (p2 - p0).
__device__ inline float cross(const Point &p1, const Point &p2,
                              const Point &p0) {
  const float ux = p1.x - p0.x, uy = p1.y - p0.y;
  const float vx = p2.x - p0.x, vy = p2.y - p0.y;
  return ux * vy - vx * uy;
}

// Returns 1 when the axis-aligned bounding rectangle of segment (p1, p2)
// overlaps (or touches) that of segment (q1, q2), 0 otherwise.
__device__ int check_rect_cross(const Point &p1, const Point &p2,
                                const Point &q1, const Point &q2) {
  // Each test is written as a negated "<=" so that an unordered (NaN)
  // comparison is rejected, exactly like the conjunction it replaces.
  if (!(min(p1.x, p2.x) <= max(q1.x, q2.x))) return 0;
  if (!(min(q1.x, q2.x) <= max(p1.x, p2.x))) return 0;
  if (!(min(p1.y, p2.y) <= max(q1.y, q2.y))) return 0;
  if (!(min(q1.y, q2.y) <= max(p1.y, p2.y))) return 0;
  return 1;
}

// Returns 1 when point p lies inside the rotated box (with a small tolerance),
// 0 otherwise.
__device__ inline int check_in_box2d(const float *box, const Point &p) {
  // params: box (5) [x1, y1, x2, y2, angle]
  const float MARGIN = 1e-5;

  const float center_x = (box[0] + box[2]) / 2;
  const float center_y = (box[1] + box[3]) / 2;

  // rotate the point in the opposite direction of box, so the test below can
  // be done against the axis-aligned extents [x1, x2] x [y1, y2]
  const float angle_cos = cos(-box[4]);
  const float angle_sin = sin(-box[4]);

  const float dx = p.x - center_x;
  const float dy = p.y - center_y;
  const float rot_x = dx * angle_cos + dy * angle_sin + center_x;
  const float rot_y = -dx * angle_sin + dy * angle_cos + center_y;
#ifdef DEBUG
  printf("box: (%.3f, %.3f, %.3f, %.3f, %.3f)\n", box[0], box[1], box[2],
         box[3], box[4]);
  printf(
      "center: (%.3f, %.3f), cossin(%.3f, %.3f), src(%.3f, %.3f), rot(%.3f, "
      "%.3f)\n",
      center_x, center_y, angle_cos, angle_sin, p.x, p.y, rot_x, rot_y);
#endif
  // Strict comparisons against the extents widened by MARGIN on every side.
  const bool inside_x = rot_x > box[0] - MARGIN && rot_x < box[2] + MARGIN;
  const bool inside_y = rot_y > box[1] - MARGIN && rot_y < box[3] + MARGIN;
  return inside_x && inside_y;
}

// Tests whether segment (p0, p1) strictly crosses segment (q0, q1).
// Returns 1 and writes the crossing point to `ans` when they do; returns 0
// (leaving `ans` untouched) otherwise. Segments that merely touch or are
// collinear yield a zero product below and are reported as not crossing.
__device__ inline int intersection(const Point &p1, const Point &p0,
                                   const Point &q1, const Point &q0,
                                   Point &ans) {
  // fast exclusion: bounding rectangles of the two segments must overlap
  if (check_rect_cross(p0, p1, q0, q1) == 0) return 0;

  // check cross standing: each segment's endpoints must lie on opposite
  // sides of the other segment
  const float s1 = cross(q0, p1, p0);
  const float s2 = cross(p1, q1, p0);
  const float s3 = cross(p0, q1, q0);
  const float s4 = cross(q1, p1, q0);

  const bool strictly_crossing = s1 * s2 > 0 && s3 * s4 > 0;
  if (!strictly_crossing) return 0;

  // calculate intersection of two lines
  const float s5 = cross(q1, p1, p0);
  const float denom = s5 - s1;
  if (fabs(denom) > EPS) {
    ans.x = (s5 * q0.x - s1 * q1.x) / denom;
    ans.y = (s5 * q0.y - s1 * q1.y) / denom;
    return 1;
  }

  // Denominator too small: solve the two implicit line equations
  // a*x + b*y + c = 0 instead.
  const float a0 = p0.y - p1.y, b0 = p1.x - p0.x,
              c0 = p0.x * p1.y - p1.x * p0.y;
  const float a1 = q0.y - q1.y, b1 = q1.x - q0.x,
              c1 = q0.x * q1.y - q1.x * q0.y;
  const float D = a0 * b1 - a1 * b0;

  ans.x = (b0 * c1 - b1 * c0) / D;
  ans.y = (a1 * c0 - a0 * c1) / D;
  return 1;
}

// Rotates point p in place about `center`, using the caller-supplied cosine
// and sine of the rotation angle.
__device__ inline void rotate_around_center(const Point &center,
                                            const float angle_cos,
                                            const float angle_sin, Point &p) {
  const float dx = p.x - center.x;
  const float dy = p.y - center.y;
  // Compute both coordinates before writing either one back.
  const float new_x = dx * angle_cos + dy * angle_sin + center.x;
  const float new_y = -dx * angle_sin + dy * angle_cos + center.y;
  p.x = new_x;
  p.y = new_y;
}

// Returns 1 when the polar angle of `a` around `center` is strictly greater
// than that of `b`, 0 otherwise.
__device__ inline int point_cmp(const Point &a, const Point &b,
                                const Point &center) {
  if (atan2(a.y - center.y, a.x - center.x) >
      atan2(b.y - center.y, b.x - center.x)) {
    return 1;
  }
  return 0;
}

// Computes the overlap area of two rotated rectangles.
//
// The overlap polygon's vertices are collected from (1) strict crossings of
// every edge of A with every edge of B and (2) corners of one box that lie
// inside the other. The vertices are ordered by polar angle around their
// centroid and the area is accumulated as a triangle fan.
//
// Returns the overlap area (0 when fewer than three vertices are found).
__device__ inline float box_overlap(const float *box_a, const float *box_b) {
  // params: box_a (5) [x1, y1, x2, y2, angle]
  // params: box_b (5) [x1, y1, x2, y2, angle]

  float a_x1 = box_a[0], a_y1 = box_a[1], a_x2 = box_a[2], a_y2 = box_a[3],
        a_angle = box_a[4];
  float b_x1 = box_b[0], b_y1 = box_b[1], b_x2 = box_b[2], b_y2 = box_b[3],
        b_angle = box_b[4];

  Point center_a((a_x1 + a_x2) / 2, (a_y1 + a_y2) / 2);
  Point center_b((b_x1 + b_x2) / 2, (b_y1 + b_y2) / 2);
#ifdef DEBUG
  printf(
      "a: (%.3f, %.3f, %.3f, %.3f, %.3f), b: (%.3f, %.3f, %.3f, %.3f, %.3f)\n",
      a_x1, a_y1, a_x2, a_y2, a_angle, b_x1, b_y1, b_x2, b_y2, b_angle);
  printf("center a: (%.3f, %.3f), b: (%.3f, %.3f)\n", center_a.x, center_a.y,
         center_b.x, center_b.y);
#endif

  // Five slots each: slot 4 repeats corner 0 so edge k is (k, k + 1).
  Point box_a_corners[5];
  box_a_corners[0].set(a_x1, a_y1);
  box_a_corners[1].set(a_x2, a_y1);
  box_a_corners[2].set(a_x2, a_y2);
  box_a_corners[3].set(a_x1, a_y2);

  Point box_b_corners[5];
  box_b_corners[0].set(b_x1, b_y1);
  box_b_corners[1].set(b_x2, b_y1);
  box_b_corners[2].set(b_x2, b_y2);
  box_b_corners[3].set(b_x1, b_y2);

  // get oriented corners (explicit single-precision math functions)
  const float a_angle_cos = cosf(a_angle), a_angle_sin = sinf(a_angle);
  const float b_angle_cos = cosf(b_angle), b_angle_sin = sinf(b_angle);

  for (int k = 0; k < 4; k++) {
#ifdef DEBUG
    printf("before corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k,
           box_a_corners[k].x, box_a_corners[k].y, box_b_corners[k].x,
           box_b_corners[k].y);
#endif
    rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]);
    rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]);
#ifdef DEBUG
    printf("corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k, box_a_corners[k].x,
           box_a_corners[k].y, box_b_corners[k].x, box_b_corners[k].y);
#endif
  }

  box_a_corners[4] = box_a_corners[0];
  box_b_corners[4] = box_b_corners[0];

  // get intersection of lines
  // NOTE(review): capacity 16 relies on two convex quadrilaterals crossing in
  // at most 8 points plus at most 8 contained corners; nothing here checks it.
  Point cross_points[16];
  Point poly_center;
  int cnt = 0, flag = 0;

  poly_center.set(0, 0);
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 4; j++) {
      flag = intersection(box_a_corners[i + 1], box_a_corners[i],
                          box_b_corners[j + 1], box_b_corners[j],
                          cross_points[cnt]);
      if (flag) {
        poly_center = poly_center + cross_points[cnt];
        cnt++;
      }
    }
  }

  // check corners
  for (int k = 0; k < 4; k++) {
    if (check_in_box2d(box_a, box_b_corners[k])) {
      poly_center = poly_center + box_b_corners[k];
      cross_points[cnt] = box_b_corners[k];
      cnt++;
    }
    if (check_in_box2d(box_b, box_a_corners[k])) {
      poly_center = poly_center + box_a_corners[k];
      cross_points[cnt] = box_a_corners[k];
      cnt++;
    }
  }

  // The centroid is only needed as the pivot for sorting, and sorting only
  // does work with two or more points. Skipping it otherwise also avoids the
  // 0 / 0 division when the boxes do not overlap at all (cnt == 0).
  if (cnt > 1) {
    poly_center.x /= cnt;
    poly_center.y /= cnt;

    // sort the points of polygon (bubble sort, ascending polar angle)
    Point temp;
    for (int j = 0; j < cnt - 1; j++) {
      for (int i = 0; i < cnt - j - 1; i++) {
        if (point_cmp(cross_points[i], cross_points[i + 1], poly_center)) {
          temp = cross_points[i];
          cross_points[i] = cross_points[i + 1];
          cross_points[i + 1] = temp;
        }
      }
    }
  }

#ifdef DEBUG
  printf("cnt=%d\n", cnt);
  for (int i = 0; i < cnt; i++) {
    printf("All cross point %d: (%.3f, %.3f)\n", i, cross_points[i].x,
           cross_points[i].y);
  }
#endif

  // get the overlap areas: triangle fan anchored at cross_points[0]
  float area = 0;
  for (int k = 0; k < cnt - 1; k++) {
    area += cross(cross_points[k] - cross_points[0],
                  cross_points[k + 1] - cross_points[0]);
  }

  // Stay in single precision: "/ 2.0" would promote the division to double.
  return fabsf(area) * 0.5f;
}

// Intersection-over-union of two rotated boxes: overlap area divided by the
// union area, with the denominator clamped to at least EPS.
__device__ inline float iou_bev(const float *box_a, const float *box_b) {
  // params: box_a (5) [x1, y1, x2, y2, angle]
  // params: box_b (5) [x1, y1, x2, y2, angle]
  const float area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]);
  const float area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]);
  const float inter = box_overlap(box_a, box_b);
  const float uni = area_a + area_b - inter;
  return inter / fmaxf(uni, EPS);
}

// Writes the pairwise overlap area of every box in boxes_a with every box in
// boxes_b into ans_overlap, laid out row-major as [a_idx * num_b + b_idx].
//
// Each box occupies 5 consecutive floats. One thread handles one (a, b) pair:
// blockIdx.y / threadIdx.y select the box from boxes_a and blockIdx.x /
// threadIdx.x the box from boxes_b. The index arithmetic assumes the block is
// THREADS_PER_BLOCK x THREADS_PER_BLOCK threads.
__global__ void boxes_overlap_kernel(const int num_a, const float *boxes_a,
                                     const int num_b, const float *boxes_b,
                                     float *ans_overlap) {
  const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
  const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;

  // Threads in the grid's tail have no pair to process.
  if (a_idx >= num_a || b_idx >= num_b) {
    return;
  }

  const float *cur_box_a = boxes_a + a_idx * 5;
  const float *cur_box_b = boxes_b + b_idx * 5;

  // Widen before multiplying: a_idx * num_b can exceed INT_MAX for large
  // inputs even though each factor fits in an int.
  const size_t out_idx = static_cast<size_t>(a_idx) * num_b + b_idx;
  ans_overlap[out_idx] = box_overlap(cur_box_a, cur_box_b);
}

// Writes the pairwise rotated IoU of every box in boxes_a with every box in
// boxes_b into ans_iou, laid out row-major as [a_idx * num_b + b_idx].
//
// Each box occupies 5 consecutive floats. One thread handles one (a, b) pair:
// blockIdx.y / threadIdx.y select the box from boxes_a and blockIdx.x /
// threadIdx.x the box from boxes_b. The index arithmetic assumes the block is
// THREADS_PER_BLOCK x THREADS_PER_BLOCK threads.
__global__ void boxes_iou_bev_kernel(const int num_a, const float *boxes_a,
                                     const int num_b, const float *boxes_b,
                                     float *ans_iou) {
  const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
  const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;

  // Threads in the grid's tail have no pair to process.
  if (a_idx >= num_a || b_idx >= num_b) {
    return;
  }

  const float *cur_box_a = boxes_a + a_idx * 5;
  const float *cur_box_b = boxes_b + b_idx * 5;

  // Widen before multiplying: a_idx * num_b can exceed INT_MAX for large
  // inputs even though each factor fits in an int.
  const size_t out_idx = static_cast<size_t>(a_idx) * num_b + b_idx;
  ans_iou[out_idx] = iou_bev(cur_box_a, cur_box_b);
}

// Builds the pairwise suppression bitmask used by rotated-box NMS.
//
// Block (blockIdx.y, blockIdx.x) compares the boxes of row tile blockIdx.y
// against the boxes of column tile blockIdx.x; each tile holds up to
// THREADS_PER_BLOCK_NMS boxes. Thread t of the block owns box
// row_start * THREADS_PER_BLOCK_NMS + t and writes one 64-bit word in which
// bit i is set when its rotated IoU with column box i exceeds
// nms_overlap_thresh. On the diagonal tile only column boxes with a higher
// index than the row box are tested.
//
// Indexing assumes blockDim.x == THREADS_PER_BLOCK_NMS.
__global__ void nms_kernel(const int boxes_num, const float nms_overlap_thresh,
                           const float *boxes, unsigned long long *mask) {
  // params: boxes (N, 5) [x1, y1, x2, y2, ry]
  // params: mask (N, DIVUP(N, THREADS_PER_BLOCK_NMS))

  const int row_start = blockIdx.y;
  const int col_start = blockIdx.x;

  // if (row_start > col_start) return;

  // Number of valid boxes in this row / column tile (the last tile may be
  // partial). Integer min avoids a needless int -> float -> int round trip.
  const int row_size = min(boxes_num - row_start * THREADS_PER_BLOCK_NMS,
                           THREADS_PER_BLOCK_NMS);
  const int col_size = min(boxes_num - col_start * THREADS_PER_BLOCK_NMS,
                           THREADS_PER_BLOCK_NMS);

  __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5];

  // Stage the column tile in shared memory, one box per thread.
  if (threadIdx.x < col_size) {
    const float *src =
        boxes + (THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5;
    float *dst = block_boxes + threadIdx.x * 5;
    for (int k = 0; k < 5; k++) {
      dst[k] = src[k];
    }
  }
  // Reached by every thread of the block: it sits outside the branch above.
  __syncthreads();

  if (threadIdx.x < row_size) {
    const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
    const float *cur_box = boxes + cur_box_idx * 5;

    // On the diagonal tile skip self and lower-indexed boxes.
    const int start = (row_start == col_start) ? threadIdx.x + 1 : 0;

    unsigned long long t = 0;
    for (int i = start; i < col_size; i++) {
      if (iou_bev(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
        t |= 1ULL << i;
      }
    }

    const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
    // Widen before multiplying so the row offset cannot overflow int.
    mask[static_cast<size_t>(cur_box_idx) * col_blocks + col_start] = t;
  }
}

// Axis-aligned IoU of two boxes read as [x1, y1, x2, y2] from a[0..3] and
// b[0..3]; the union in the denominator is clamped to at least EPS.
__device__ inline float iou_normal(float const *const a, float const *const b) {
  const float left = fmaxf(a[0], b[0]);
  const float top = fmaxf(a[1], b[1]);
  const float right = fminf(a[2], b[2]);
  const float bottom = fminf(a[3], b[3]);

  // Negative extents mean the boxes are disjoint on that axis.
  const float interS = fmaxf(right - left, 0.f) * fmaxf(bottom - top, 0.f);

  const float area_a = (a[2] - a[0]) * (a[3] - a[1]);
  const float area_b = (b[2] - b[0]) * (b[3] - b[1]);
  const float uni = area_a + area_b - interS;
  return interS / fmaxf(uni, EPS);
}

// Builds the pairwise suppression bitmask used by axis-aligned NMS.
//
// Block (blockIdx.y, blockIdx.x) compares the boxes of row tile blockIdx.y
// against the boxes of column tile blockIdx.x; each tile holds up to
// THREADS_PER_BLOCK_NMS boxes. Thread t of the block owns box
// row_start * THREADS_PER_BLOCK_NMS + t and writes one 64-bit word in which
// bit i is set when iou_normal() with column box i exceeds
// nms_overlap_thresh. On the diagonal tile only column boxes with a higher
// index than the row box are tested.
//
// Indexing assumes blockDim.x == THREADS_PER_BLOCK_NMS.
__global__ void nms_normal_kernel(const int boxes_num,
                                  const float nms_overlap_thresh,
                                  const float *boxes,
                                  unsigned long long *mask) {
  // params: boxes (N, 5) [x1, y1, x2, y2, ry]
  // params: mask (N, DIVUP(N, THREADS_PER_BLOCK_NMS))

  const int row_start = blockIdx.y;
  const int col_start = blockIdx.x;

  // if (row_start > col_start) return;

  // Number of valid boxes in this row / column tile (the last tile may be
  // partial). Integer min avoids a needless int -> float -> int round trip.
  const int row_size = min(boxes_num - row_start * THREADS_PER_BLOCK_NMS,
                           THREADS_PER_BLOCK_NMS);
  const int col_size = min(boxes_num - col_start * THREADS_PER_BLOCK_NMS,
                           THREADS_PER_BLOCK_NMS);

  __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5];

  // Stage the column tile in shared memory, one box per thread.
  if (threadIdx.x < col_size) {
    const float *src =
        boxes + (THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5;
    float *dst = block_boxes + threadIdx.x * 5;
    for (int k = 0; k < 5; k++) {
      dst[k] = src[k];
    }
  }
  // Reached by every thread of the block: it sits outside the branch above.
  __syncthreads();

  if (threadIdx.x < row_size) {
    const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
    const float *cur_box = boxes + cur_box_idx * 5;

    // On the diagonal tile skip self and lower-indexed boxes.
    const int start = (row_start == col_start) ? threadIdx.x + 1 : 0;

    unsigned long long t = 0;
    for (int i = start; i < col_size; i++) {
      if (iou_normal(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
        t |= 1ULL << i;
      }
    }

    const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
    // Widen before multiplying so the row offset cannot overflow int.
    mask[static_cast<size_t>(cur_box_idx) * col_blocks + col_start] = t;
  }
}

// Launches boxes_overlap_kernel over a 2D grid with one thread per
// (box_a, box_b) pair. The launch goes to the default stream and, outside
// DEBUG builds, this function does not wait for the kernel to finish.
// The pointers are handed to the kernel unchanged.
void boxesoverlapLauncher(const int num_a, const float *boxes_a,
                          const int num_b, const float *boxes_b,
                          float *ans_overlap) {
  // With no boxes on either side DIVUP yields a zero-sized grid, which is an
  // invalid launch configuration; there is nothing to compute anyway.
  if (num_a <= 0 || num_b <= 0) return;

  dim3 blocks(
      DIVUP(num_b, THREADS_PER_BLOCK),
      DIVUP(num_a, THREADS_PER_BLOCK));  // blockIdx.x(col), blockIdx.y(row)
  dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);

  boxes_overlap_kernel<<<blocks, threads>>>(num_a, boxes_a, num_b, boxes_b,
                                            ans_overlap);

  // Kernel launches return no status; peek (without clearing, so callers can
  // still query it) and report a failed launch instead of staying silent.
  const cudaError_t err = cudaPeekAtLastError();
  if (err != cudaSuccess) {
    fprintf(stderr, "CUDA error after launching boxes_overlap_kernel: %s\n",
            cudaGetErrorString(err));
  }
#ifdef DEBUG
  cudaDeviceSynchronize();  // for using printf in kernel function
#endif
}

// Launches boxes_iou_bev_kernel over a 2D grid with one thread per
// (box_a, box_b) pair. The launch goes to the default stream and this
// function does not wait for the kernel to finish.
// The pointers are handed to the kernel unchanged.
void boxesioubevLauncher(const int num_a, const float *boxes_a, const int num_b,
                         const float *boxes_b, float *ans_iou) {
  // With no boxes on either side DIVUP yields a zero-sized grid, which is an
  // invalid launch configuration; there is nothing to compute anyway.
  if (num_a <= 0 || num_b <= 0) return;

  dim3 blocks(
      DIVUP(num_b, THREADS_PER_BLOCK),
      DIVUP(num_a, THREADS_PER_BLOCK));  // blockIdx.x(col), blockIdx.y(row)
  dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);

  boxes_iou_bev_kernel<<<blocks, threads>>>(num_a, boxes_a, num_b, boxes_b,
                                            ans_iou);

  // Kernel launches return no status; peek (without clearing, so callers can
  // still query it) and report a failed launch instead of staying silent.
  const cudaError_t err = cudaPeekAtLastError();
  if (err != cudaSuccess) {
    fprintf(stderr, "CUDA error after launching boxes_iou_bev_kernel: %s\n",
            cudaGetErrorString(err));
  }
}

// Launches nms_kernel on a square grid of DIVUP(boxes_num,
// THREADS_PER_BLOCK_NMS) tiles per side, with THREADS_PER_BLOCK_NMS threads
// per block. The launch goes to the default stream and this function does not
// wait for the kernel to finish. The pointers are handed to the kernel
// unchanged.
void nmsLauncher(const float *boxes, unsigned long long *mask, int boxes_num,
                 float nms_overlap_thresh) {
  // With no boxes DIVUP yields a zero-sized grid, which is an invalid launch
  // configuration; there is nothing to compute anyway.
  if (boxes_num <= 0) return;

  dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
              DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
  dim3 threads(THREADS_PER_BLOCK_NMS);
  nms_kernel<<<blocks, threads>>>(boxes_num, nms_overlap_thresh, boxes, mask);

  // Kernel launches return no status; peek (without clearing, so callers can
  // still query it) and report a failed launch instead of staying silent.
  const cudaError_t err = cudaPeekAtLastError();
  if (err != cudaSuccess) {
    fprintf(stderr, "CUDA error after launching nms_kernel: %s\n",
            cudaGetErrorString(err));
  }
}

// Launches nms_normal_kernel on a square grid of DIVUP(boxes_num,
// THREADS_PER_BLOCK_NMS) tiles per side, with THREADS_PER_BLOCK_NMS threads
// per block. The launch goes to the default stream and this function does not
// wait for the kernel to finish. The pointers are handed to the kernel
// unchanged.
void nmsNormalLauncher(const float *boxes, unsigned long long *mask,
                       int boxes_num, float nms_overlap_thresh) {
  // With no boxes DIVUP yields a zero-sized grid, which is an invalid launch
  // configuration; there is nothing to compute anyway.
  if (boxes_num <= 0) return;

  dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
              DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
  dim3 threads(THREADS_PER_BLOCK_NMS);
  nms_normal_kernel<<<blocks, threads>>>(boxes_num, nms_overlap_thresh, boxes,
                                         mask);

  // Kernel launches return no status; peek (without clearing, so callers can
  // still query it) and report a failed launch instead of staying silent.
  const cudaError_t err = cudaPeekAtLastError();
  if (err != cudaSuccess) {
    fprintf(stderr, "CUDA error after launching nms_normal_kernel: %s\n",
            cudaGetErrorString(err));
  }
}