vision.cpp 3.9 KB
Newer Older
1
2
3
4
5
6
#include <Python.h>

#include <cstdint>

#include <torch/script.h>

#ifdef WITH_CUDA
#include <cuda.h>
#endif

#ifdef WITH_HIP
#include <hip/hip_runtime.h>
#endif

#include "DeformConv.h"
#include "PSROIAlign.h"
#include "PSROIPool.h"
#include "ROIAlign.h"
#include "ROIPool.h"
#include "empty_tensor_op.h"
#include "nms.h"

// If we are in a Windows environment, we need to define
// initialization functions for the _C extension.
// These are stubs: no Python module is built here, and per the original
// comments extension.py does the real work when the library is loaded.
#ifdef _WIN32
#if PY_MAJOR_VERSION < 3
// Python 2: PyMODINIT_FUNC declares a `void` return type, so the stub must not
// return a value (the former `return NULL;` is ill-formed C++ in this branch).
PyMODINIT_FUNC init_C(void) {
  // No need to do anything.
  // extension.py will run on load
}
#else
// Python 3: PyMODINIT_FUNC declares a `PyObject*` return type; a null pointer
// is returned exactly as before.
// NOTE(review): CPython treats a NULL return without an exception set as an
// import error, so this presumably relies on the function never being invoked
// as a real module initializer (the library is loaded by extension.py
// instead) — confirm.
PyMODINIT_FUNC PyInit__C(void) {
  // No need to do anything.
  // extension.py will run on load
  return nullptr;
}
#endif
#endif

namespace vision {

/// @brief Reports the CUDA toolkit version this extension was compiled with.
///
/// @return The `CUDA_VERSION` macro value (from <cuda.h>) when the build
///         defines `WITH_CUDA`; otherwise -1.
/// Never throws.
std::int64_t cuda_version() noexcept {
#ifdef WITH_CUDA
  return CUDA_VERSION;
#else
  // Sentinel for "not built against CUDA".
  return -1;
#endif
}

} // namespace vision

47
TORCH_LIBRARY(torchvision, m) {
48
  m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
49
50
51
52
53
54
  m.def(
      "roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
  m.def(
      "_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor");
  m.def("roi_pool", &roi_pool);
  m.def("_new_empty_tensor_op", &new_empty_tensor);
55
56
57
58
  m.def(
      "ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
  m.def(
      "_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor");
59
  m.def("ps_roi_pool", &ps_roi_pool);
60
61
62
63
  m.def(
      "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor");
  m.def(
      "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> (Tensor, Tensor, Tensor, Tensor)");
64
  m.def("_cuda_version", &vision::cuda_version);
65
66
67
68
69
}

// CPU kernels for the schema-declared torchvision operators (forward and
// backward). roi_pool / ps_roi_pool are not listed: they are registered
// directly from a C++ function in TORCH_LIBRARY.
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
  m.impl("roi_align", ROIAlign_forward_cpu);
  m.impl("_roi_align_backward", ROIAlign_backward_cpu);

  m.impl("deform_conv2d", DeformConv2d_forward_cpu);
  m.impl("_deform_conv2d_backward", DeformConv2d_backward_cpu);

  m.impl("nms", nms_cpu);

  m.impl("ps_roi_align", PSROIAlign_forward_cpu);
  m.impl("_ps_roi_align_backward", PSROIAlign_backward_cpu);
}

// TODO: Place this in a hypothetical separate torchvision_cuda library
// CUDA kernels (also used for HIP builds) mirroring the CPU registrations.
#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
  m.impl("roi_align", ROIAlign_forward_cuda);
  m.impl("_roi_align_backward", ROIAlign_backward_cuda);

  m.impl("deform_conv2d", DeformConv2d_forward_cuda);
  m.impl("_deform_conv2d_backward", DeformConv2d_backward_cuda);

  m.impl("nms", nms_cuda);

  m.impl("ps_roi_align", PSROIAlign_forward_cuda);
  m.impl("_ps_roi_align_backward", PSROIAlign_backward_cuda);
}
#endif

// Autocast only needs to wrap forward pass ops.
// Hence no `_*_backward` operator is registered under this key. Compiled only
// for CUDA/HIP builds.
#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
  m.impl("roi_align", ROIAlign_autocast);
  m.impl("deform_conv2d", DeformConv2d_autocast);
  m.impl("nms", nms_autocast);
  m.impl("ps_roi_align", PSROIAlign_autocast);
}
#endif

// Autograd kernels for the differentiable operators and their backward ops.
// nms has no entry here.
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
  m.impl("roi_align", ROIAlign_autograd);
  m.impl("_roi_align_backward", ROIAlign_backward_autograd);

  m.impl("deform_conv2d", DeformConv2d_autograd);
  m.impl("_deform_conv2d_backward", DeformConv2d_backward_autograd);

  m.impl("ps_roi_align", PSROIAlign_autograd);
  m.impl("_ps_roi_align_backward", PSROIAlign_backward_autograd);
}