"vscode:/vscode.git/clone" did not exist on "3c363d5709143f1e3a210f9e54bd80a7356d9e14"
nms.h 600 Bytes
Newer Older
1
#pragma once
2
#include "cpu/vision_cpu.h"
3
4

#ifdef WITH_CUDA
5
#include "cuda/vision_cuda.h"
6
7
8
9
10
#endif

/// Device dispatcher for NMS: forwards to the CUDA or CPU implementation
/// depending on where `dets` lives.
///
/// Declared `inline` because the definition lives in a header; without it,
/// including this header from more than one translation unit would produce
/// multiple-definition link errors.
///
/// @param dets           Input tensor whose device selects the backend
///                       (presumably the candidate boxes - confirm against
///                       nms_cpu / nms_cuda).
/// @param scores         Forwarded unchanged to the selected backend
///                       (presumably one score per box - confirm).
/// @param iou_threshold  Forwarded unchanged to the selected backend.
/// @return Whatever the selected backend returns. For a CUDA `dets` with zero
///         elements, an empty 1-D int64 tensor created with `dets`' options
///         (so on the same device) is returned without calling the backend.
/// @throws Raises through AT_ERROR when `dets` is a CUDA tensor but this
///         header was compiled without WITH_CUDA.
inline at::Tensor nms(
    const at::Tensor& dets,
    const at::Tensor& scores,
    const float iou_threshold) {
  if (dets.device().is_cuda()) {
#ifdef WITH_CUDA
    if (dets.numel() == 0) {
      // Make `dets`' device current while allocating the empty result; the
      // guard restores the previous device when it goes out of scope.
      at::cuda::CUDAGuard device_guard(dets.device());
      return at::empty({0}, dets.options().dtype(at::kLong));
    }
    return nms_cuda(dets, scores, iou_threshold);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }

  // Any non-CUDA tensor falls through to the CPU implementation.
  return nms_cpu(dets, scores, iou_threshold);
}