// vision.cpp: pybind11 module exposing the operators declared in ROIAlign.h,
// ROIPool.h and nms.h to Python as a PyTorch C++ extension.
#include "ROIAlign.h"
#include "ROIPool.h"
#include "nms.h"

#ifdef WITH_CUDA
#include <cuda.h>
#endif

9
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
10
11
  // TODO: remove nms from here since it is now registered
  //       and used as a PyTorch custom op
12
13
14
15
16
  m.def("nms", &nms, "non-maximum suppression");
  m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign_forward");
  m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward");
  m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward");
  m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward");
17
18
19
#ifdef WITH_CUDA
  m.attr("CUDA_VERSION") = CUDA_VERSION;
#endif
20
}