#include "vision.h"

#include <Python.h>
#include <torch/script.h>

#ifdef WITH_CUDA
#include <cuda.h>
#endif
#ifdef WITH_HIP
#include <hip/hip_runtime.h>
#endif

#include "deform_conv2d.h"
#include "new_empty_tensor_op.h"
#include "nms.h"
#include "ps_roi_align.h"
#include "ps_roi_pool.h"
#include "roi_align.h"
#include "roi_pool.h"

// If we are in a Windows environment, we need to define
// initialization functions for the _custom_ops extension
#ifdef _WIN32
PyMODINIT_FUNC PyInit__C(void) {
  // No need to do anything.
  return NULL;
}
#endif

namespace vision {
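// Reports the CUDA toolkit version this extension was built against, or -1
// for CPU-only builds; registered below as the _cuda_version op.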
int64_t cuda_version() {
#ifdef WITH_CUDA
  return CUDA_VERSION;
#else
  return -1;
#endif
}
} // namespace vision

using namespace vision::ops;

TORCH_LIBRARY(torchvision, m) {
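  // Declare the operator schemas. Kernels for each dispatch key (CPU, CUDA,
  // Autocast, Autograd) are registered in the TORCH_LIBRARY_IMPL blocks below.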
  m.def(
      "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor");
  m.def(
      "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
  m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
  m.def(
      "ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
  m.def(
      "_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor");
  m.def(
      "ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
  m.def(
      "_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
  m.def(
      "roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
  m.def(
      "_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor");
  m.def(
      "roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
  m.def(
      "_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
  m.def("_cuda_version", &vision::cuda_version);
  m.def("_new_empty_tensor_op", &new_empty_tensor);
}

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
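  // CPU kernels for the operators declared above.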
  m.impl("deform_conv2d", deform_conv2d_forward_cpu);
  m.impl("_deform_conv2d_backward", deform_conv2d_backward_cpu);
  m.impl("nms", nms_cpu);
  m.impl("ps_roi_align", ps_roi_align_forward_cpu);
  m.impl("_ps_roi_align_backward", ps_roi_align_backward_cpu);
  m.impl("ps_roi_pool", ps_roi_pool_forward_cpu);
  m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_cpu);
  m.impl("roi_align", roi_align_forward_cpu);
  m.impl("_roi_align_backward", roi_align_backward_cpu);
  m.impl("roi_pool", roi_pool_forward_cpu);
  m.impl("_roi_pool_backward", roi_pool_backward_cpu);
}

// TODO: Place this in a hypothetical separate torchvision_cuda library
#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
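  // CUDA kernels; per the WITH_HIP guard above, the same registrations are
  // compiled for HIP/ROCm builds as well.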
  m.impl("deform_conv2d", deform_conv2d_forward_cuda);
  m.impl("_deform_conv2d_backward", deform_conv2d_backward_cuda);
  m.impl("nms", nms_cuda);
  m.impl("ps_roi_align", ps_roi_align_forward_cuda);
  m.impl("_ps_roi_align_backward", ps_roi_align_backward_cuda);
  m.impl("ps_roi_pool", ps_roi_pool_forward_cuda);
  m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_cuda);
  m.impl("roi_align", roi_align_forward_cuda);
  m.impl("_roi_align_backward", roi_align_backward_cuda);
  m.impl("roi_pool", roi_pool_forward_cuda);
  m.impl("_roi_pool_backward", roi_pool_backward_cuda);
}
#endif

// Autocast only needs to wrap forward pass ops.
#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
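  // Each autocast wrapper casts its inputs to a supported dtype and then
  // redispatches to the kernels registered above.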
  m.impl("deform_conv2d", deform_conv2d_autocast);
  m.impl("nms", nms_autocast);
  m.impl("ps_roi_align", ps_roi_align_autocast);
  m.impl("ps_roi_pool", ps_roi_pool_autocast);
  m.impl("roi_align", roi_align_autocast);
  m.impl("roi_pool", roi_pool_autocast);
}
#endif

TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
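  // Autograd wrappers connect each forward op to its _backward counterpart so
  // gradients can be computed.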
  m.impl("deform_conv2d", deform_conv2d_autograd);
  m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
  m.impl("ps_roi_align", ps_roi_align_autograd);
  m.impl("_ps_roi_align_backward", ps_roi_align_backward_autograd);
  m.impl("ps_roi_pool", ps_roi_pool_autograd);
  m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_autograd);
  m.impl("roi_align", roi_align_autograd);
  m.impl("_roi_align_backward", roi_align_backward_autograd);
  m.impl("roi_pool", roi_pool_autograd);
  m.impl("_roi_pool_backward", roi_pool_backward_autograd);
}