// nms_cpu.cpp
#include "vision_cpu.h"
// Greedy non-maximum suppression over axis-aligned boxes, CPU implementation.
//
// Boxes are visited in order of decreasing score. Each box that has not yet
// been suppressed is kept, and every later (lower-scoring) box whose
// intersection-over-union with it is strictly greater than `iou_threshold`
// is marked as suppressed.
//
// Args:
//   dets:          2-D tensor of shape (N, 4); columns 0..3 are read as
//                  x1, y1, x2, y2.
//   scores:        1-D tensor of N scores with the same dtype as `dets`.
//   iou_threshold: a box is dropped when its IoU with an already-kept box is
//                  > this value.
// Returns:
//   int64 tensor holding the row indices (into `dets`) of the kept boxes,
//   ordered by decreasing score. Empty `dets` yields an empty int64 tensor.
template <typename scalar_t>
at::Tensor nms_cpu_kernel(
    const at::Tensor& dets,
    const at::Tensor& scores,
    const float iou_threshold) {
  AT_ASSERTM(!dets.options().device().is_cuda(), "dets must be a CPU tensor");
  AT_ASSERTM(
      !scores.options().device().is_cuda(), "scores must be a CPU tensor");
  AT_ASSERTM(
      dets.scalar_type() == scores.scalar_type(),
      "dets should have the same type as scores");

  if (dets.numel() == 0)
    return at::empty({0}, dets.options().dtype(at::kLong));

  // Shape checks. Indices taken from `order` (one per score) are used to
  // index `suppressed` and the per-box coordinate arrays (one per dets row),
  // and the loops run to dets.size(0) over `order`; any row-count mismatch
  // would read or write out of bounds.
  AT_ASSERTM(
      dets.dim() == 2, "dets should be a 2d tensor, got ", dets.dim(), "D");
  AT_ASSERTM(
      dets.size(1) == 4,
      "dets should have 4 elements in dimension 1, got ",
      dets.size(1));
  AT_ASSERTM(
      scores.dim() == 1,
      "scores should be a 1d tensor, got ",
      scores.dim(),
      "D");
  AT_ASSERTM(
      dets.size(0) == scores.size(0),
      "dets and scores should have the same number of elements in dimension 0, got ",
      dets.size(0),
      " and ",
      scores.size(0));

  // Contiguous copies so the raw pointers below can be indexed linearly.
  auto x1_t = dets.select(1, 0).contiguous();
  auto y1_t = dets.select(1, 1).contiguous();
  auto x2_t = dets.select(1, 2).contiguous();
  auto y2_t = dets.select(1, 3).contiguous();

  at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);

  // Indices of the boxes sorted by descending score.
  auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));

  auto ndets = dets.size(0);
  at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
  at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));

  auto suppressed = suppressed_t.data_ptr<uint8_t>();
  auto keep = keep_t.data_ptr<int64_t>();
  auto order = order_t.data_ptr<int64_t>();
  auto x1 = x1_t.data_ptr<scalar_t>();
  auto y1 = y1_t.data_ptr<scalar_t>();
  auto x2 = x2_t.data_ptr<scalar_t>();
  auto y2 = y2_t.data_ptr<scalar_t>();
  auto areas = areas_t.data_ptr<scalar_t>();

  int64_t num_to_keep = 0;

  for (int64_t _i = 0; _i < ndets; _i++) {
    auto i = order[_i];
    if (suppressed[i] == 1)
      continue;
    keep[num_to_keep++] = i;
    auto ix1 = x1[i];
    auto iy1 = y1[i];
    auto ix2 = x2[i];
    auto iy2 = y2[i];
    auto iarea = areas[i];

    // Only lower-scoring boxes can be suppressed by box i.
    for (int64_t _j = _i + 1; _j < ndets; _j++) {
      auto j = order[_j];
      if (suppressed[j] == 1)
        continue;
      auto xx1 = std::max(ix1, x1[j]);
      auto yy1 = std::max(iy1, y1[j]);
      auto xx2 = std::min(ix2, x2[j]);
      auto yy2 = std::min(iy2, y2[j]);

      // Clamp at zero so disjoint boxes have zero intersection.
      auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1);
      auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1);
      auto inter = w * h;
      // If both boxes have zero area this is 0/0 = NaN, which compares false
      // below, so such boxes are never suppressed.
      auto ovr = inter / (iarea + areas[j] - inter);
      if (ovr > iou_threshold)
        suppressed[j] = 1;
    }
  }
  return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
}

// CPU entry point for non-maximum suppression.
//
// Dispatches on the floating-point dtype of `dets` (float or double) and runs
// nms_cpu_kernel for that type. Returns an int64 tensor with the indices of
// the kept boxes, ordered by decreasing score.
at::Tensor nms_cpu(
    const at::Tensor& dets,
    const at::Tensor& scores,
    const float iou_threshold) {
  // Assigned by the dispatched lambda; every successful dispatch overwrites
  // it, so no placeholder allocation (with the wrong dtype) is needed.
  at::Tensor result;
  AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms", [&] {
    result = nms_cpu_kernel<scalar_t>(dets, scores, iou_threshold);
  });
  return result;
}