#include <Python.h>
#include <torch/script.h>

#ifdef WITH_CUDA
#include <cuda.h>
#endif
#ifdef WITH_HIP
#include <hip/hip_runtime.h>
#endif
#include "DeformConv.h"
#include "PSROIAlign.h"
#include "PSROIPool.h"
#include "ROIAlign.h"
#include "ROIPool.h"
#include "empty_tensor_op.h"
#include "nms.h"

// If we are in a Windows environment, we need to define an initialization
// function for the `_C` extension module (the module name is encoded in the
// symbol `PyInit__C`). NOTE(review): presumably required because the Windows
// build exports a PyInit_<name> symbol for every extension module — confirm.
#ifdef _WIN32
/// Stub module-initialization entry point; it intentionally does no work.
///
/// Under the CPython extension protocol a NULL return signals failed
/// initialization, so a plain `import` of `_C` would raise. The operators are
/// registered by the TORCH_LIBRARY / TORCH_LIBRARY_IMPL blocks in this file,
/// not by this function. NOTE(review): assumes the library is loaded as a raw
/// shared library rather than imported as a module — confirm against loader.
PyMODINIT_FUNC PyInit__C(void) {
  // No need to do anything.
  return nullptr;
}
#endif

namespace vision {

/// Reports the CUDA toolkit version this library was compiled against.
///
/// @return The `CUDA_VERSION` macro from <cuda.h> when built with WITH_CUDA;
///         otherwise -1. That includes CPU-only builds and WITH_HIP builds,
///         which do not include <cuda.h>.
/// Registered as the `torchvision::_cuda_version` operator.
int64_t cuda_version() {
#ifdef WITH_CUDA
  return CUDA_VERSION;
#else
  return -1;
#endif
}

} // namespace vision

38
TORCH_LIBRARY(torchvision, m) {
39
  m.def(
40
      "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor");
41
  m.def(
42
      "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
43
  m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
44
  m.def(
45
      "ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
46
  m.def(
47
      "_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor");
48
  m.def(
49
      "ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
50
  m.def(
51
      "_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
52
  m.def(
53
      "roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
54
  m.def(
55
      "_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor");
56
  m.def(
57
      "roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
58
  m.def(
59
      "_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
60
  m.def("_cuda_version", &vision::cuda_version);
61
  m.def("_new_empty_tensor_op", &new_empty_tensor);
62
63
64
}

// CPU kernels for every schema-declared forward and `_*_backward` operator of
// the `torchvision` library. The utility ops (_cuda_version,
// _new_empty_tensor_op) already received their kernel in `m.def`.
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
  // Deformable convolution.
  m.impl("deform_conv2d", DeformConv2d_forward_cpu);
  m.impl("_deform_conv2d_backward", DeformConv2d_backward_cpu);

  // Non-maximum suppression (forward only).
  m.impl("nms", nms_cpu);

  // Position-sensitive RoI align / pool.
  m.impl("ps_roi_align", PSROIAlign_forward_cpu);
  m.impl("_ps_roi_align_backward", PSROIAlign_backward_cpu);
  m.impl("ps_roi_pool", PSROIPool_forward_cpu);
  m.impl("_ps_roi_pool_backward", PSROIPool_backward_cpu);

  // RoI align / pool.
  m.impl("roi_align", ROIAlign_forward_cpu);
  m.impl("_roi_align_backward", ROIAlign_backward_cpu);
  m.impl("roi_pool", ROIPool_forward_cpu);
  m.impl("_roi_pool_backward", ROIPool_backward_cpu);
}

// TODO: Place this in a hypothetical separate torchvision_cuda library
#if defined(WITH_CUDA) || defined(WITH_HIP)
// GPU kernels, compiled only into CUDA or HIP builds. HIP (ROCm) builds also
// register under the CUDA dispatch key, since the key below is the same for
// both guards.
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
  // Deformable convolution.
  m.impl("deform_conv2d", DeformConv2d_forward_cuda);
  m.impl("_deform_conv2d_backward", DeformConv2d_backward_cuda);

  // Non-maximum suppression (forward only).
  m.impl("nms", nms_cuda);

  // Position-sensitive RoI align / pool.
  m.impl("ps_roi_align", PSROIAlign_forward_cuda);
  m.impl("_ps_roi_align_backward", PSROIAlign_backward_cuda);
  m.impl("ps_roi_pool", PSROIPool_forward_cuda);
  m.impl("_ps_roi_pool_backward", PSROIPool_backward_cuda);

  // RoI align / pool.
  m.impl("roi_align", ROIAlign_forward_cuda);
  m.impl("_roi_align_backward", ROIAlign_backward_cuda);
  m.impl("roi_pool", ROIPool_forward_cuda);
  m.impl("_roi_pool_backward", ROIPool_backward_cuda);
}
#endif

// Autocast only needs to wrap forward pass ops, so no `_*_backward` operator
// is registered for this key. These wrappers are compiled only into CUDA/HIP
// builds.
#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
  m.impl("deform_conv2d", DeformConv2d_autocast);
  m.impl("nms", nms_autocast);
  m.impl("ps_roi_align", PSROIAlign_autocast);
  m.impl("ps_roi_pool", PSROIPool_autocast);
  m.impl("roi_align", ROIAlign_autocast);
  m.impl("roi_pool", ROIPool_autocast);
}
#endif

// Autograd kernels. Every differentiable forward op and its `_*_backward` op
// each get an Autograd kernel. `nms` is not registered for this key.
// NOTE(review): registering the backward ops presumably enables
// double-backward — confirm in the *_backward_autograd implementations.
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
  // Deformable convolution.
  m.impl("deform_conv2d", DeformConv2d_autograd);
  m.impl("_deform_conv2d_backward", DeformConv2d_backward_autograd);

  // Position-sensitive RoI align / pool.
  m.impl("ps_roi_align", PSROIAlign_autograd);
  m.impl("_ps_roi_align_backward", PSROIAlign_backward_autograd);
  m.impl("ps_roi_pool", PSROIPool_autograd);
  m.impl("_ps_roi_pool_backward", PSROIPool_backward_autograd);

  // RoI align / pool.
  m.impl("roi_align", ROIAlign_autograd);
  m.impl("_roi_align_backward", ROIAlign_backward_autograd);
  m.impl("roi_pool", ROIPool_autograd);
  m.impl("_roi_pool_backward", ROIPool_backward_autograd);
}