nms.h 1.38 KB
Newer Older
1
#pragma once
2
#include "cpu/vision_cpu.h"
3
4

#ifdef WITH_CUDA
5
#include "cuda/vision_cuda.h"
6
#endif
7
8
9
#ifdef WITH_HIP
#include "hip/vision_cuda.h"
#endif
10
11
12
13

/// Non-maximum-suppression entry point: validates inputs, then dispatches to
/// the GPU backend (`nms_cuda`) or the CPU backend (`nms_cpu`).
///
/// Declared `inline` because the definition lives in a header: `#pragma once`
/// only guards against repeated inclusion within one translation unit, so a
/// non-inline definition included from several .cpp files would be defined
/// more than once and fail to link (ODR violation).
///
/// @param dets          2-D tensor whose dimension 1 has size 4 (one row per
///                      box; the coordinate layout is defined by the backends
///                      — NOTE(review): presumably (x1, y1, x2, y2), confirm).
/// @param scores        1-D tensor with exactly `dets.size(0)` elements.
/// @param iou_threshold Forwarded unchanged to the selected backend.
/// @return Whatever the backend returns. For an empty `dets` on the GPU path
///         this is an empty int64 (`at::kLong`) tensor with `dets`' options;
///         presumably the indices of the kept boxes in general — confirm
///         against `nms_cpu` / `nms_cuda`.
/// @throws Via TORCH_CHECK on any shape mismatch, and via AT_ERROR when `dets`
///         is a GPU tensor but the build has neither WITH_CUDA nor WITH_HIP.
inline at::Tensor nms(
    const at::Tensor& dets,
    const at::Tensor& scores,
    const double iou_threshold) {
  TORCH_CHECK(
      dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
  TORCH_CHECK(
      dets.size(1) == 4,
      "boxes should have 4 elements in dimension 1, got ",
      dets.size(1));
  TORCH_CHECK(
      scores.dim() == 1,
      "scores should be a 1d tensor, got ",
      scores.dim(),
      "D");
  TORCH_CHECK(
      dets.size(0) == scores.size(0),
      "boxes and scores should have same number of elements in ",
      "dimension 0, got ",
      dets.size(0),
      " and ",
      scores.size(0));

  // Dispatch is decided by the device of `dets` only.
  if (dets.is_cuda()) {
#if defined(WITH_CUDA)
    // Short-circuit empty input without launching the backend. The guard makes
    // dets' device current while the empty result tensor is created.
    if (dets.numel() == 0) {
      at::cuda::CUDAGuard device_guard(dets.device());
      return at::empty({0}, dets.options().dtype(at::kLong));
    }
    return nms_cuda(dets, scores, iou_threshold);
#elif defined(WITH_HIP)
    // ROCm build: same empty-input short-circuit as the CUDA branch.
    if (dets.numel() == 0) {
      at::cuda::HIPGuard device_guard(dets.device());
      return at::empty({0}, dets.options().dtype(at::kLong));
    }
    return nms_cuda(dets, scores, iou_threshold);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }

  return nms_cpu(dets, scores, iou_threshold);
}