roi_align.cpp 3.87 KB
Newer Older
1
#include "roi_align.h"
2

3
4
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
5
#include <torch/types.h>
6

7
8
namespace vision {
namespace ops {
9
10

// Public entry point for the `torchvision::roi_align` operator.
//
// Looks the operator schema up in the c10 dispatcher (empty overload name)
// and forwards every argument, unchanged and in order, to the resolved
// operator. The typed handle is a function-local static, so the schema lookup
// is performed on the first call only.
//
// Parameters:
//   input          - Input feature map.
//   rois           - List of ROIs to pool over.
//   spatial_scale  - The scale of the image features. ROIs will be scaled to
//                    this.
//   pooled_height  - The height of the pooled feature map.
//   pooled_width   - The width of the pooled feature map.
//   sampling_ratio - The number of points to sample in each bin.
//   aligned        - The flag for pixel shift.
//
// NOTE(review): the original inline comments ended with a detached
// "along each axis." line placed after `aligned`; it most likely completes
// the `sampling_ratio` description -- confirm.
//
// Returns whatever tensor the dispatched operator returns.
at::Tensor roi_align(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio,
    bool aligned) {
  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align");

  // The handle is typed with this function's own signature.
  using Signature = decltype(roi_align);
  static auto handle = c10::Dispatcher::singleton()
                           .findSchemaOrThrow("torchvision::roi_align", "")
                           .typed<Signature>();

  return handle.call(
      input,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      aligned);
}

Edward Z. Yang's avatar
Edward Z. Yang committed
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
// `c10::SymInt` flavour of the `torchvision::roi_align` entry point: the
// pooled output dimensions are passed as `c10::SymInt` instead of `int64_t`.
//
// Looks up the same "torchvision::roi_align" schema in the c10 dispatcher and
// forwards every argument, unchanged and in order, to the resolved operator.
// The typed handle is a function-local static, so the schema lookup is
// performed on the first call only.
//
// Parameters:
//   input          - Input feature map.
//   rois           - List of ROIs to pool over.
//   spatial_scale  - The scale of the image features. ROIs will be scaled to
//                    this.
//   pooled_height  - The height of the pooled feature map.
//   pooled_width   - The width of the pooled feature map.
//   sampling_ratio - The number of points to sample in each bin.
//   aligned        - The flag for pixel shift.
//
// NOTE(review): the original inline comments ended with a detached
// "along each axis." line placed after `aligned`; it most likely completes
// the `sampling_ratio` description -- confirm.
//
// Returns whatever tensor the dispatched operator returns.
at::Tensor roi_align_symint(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    c10::SymInt pooled_height,
    c10::SymInt pooled_width,
    int64_t sampling_ratio,
    bool aligned) {
  // Same usage-logging key as the non-symbolic entry point.
  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align");

  // The handle is typed with this function's own (SymInt) signature.
  using Signature = decltype(roi_align_symint);
  static auto handle = c10::Dispatcher::singleton()
                           .findSchemaOrThrow("torchvision::roi_align", "")
                           .typed<Signature>();

  return handle.call(
      input,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      aligned);
}

60
61
namespace detail {

62
// Entry point for the `torchvision::_roi_align_backward` operator.
//
// Looks the operator schema up in the c10 dispatcher (empty overload name)
// and forwards every argument, unchanged and in order, to the resolved
// operator. The typed handle is a function-local static, so the schema lookup
// is performed on the first call only.
//
// The parameter list mirrors the argument list of the registered
// `torchvision::_roi_align_backward` schema: `grad`, `rois`, `spatial_scale`,
// the pooled output size (`pooled_height`, `pooled_width`), four integer
// dimensions (`batch_size`, `channels`, `height`, `width`), `sampling_ratio`
// and `aligned`.
// NOTE(review): presumably `grad` is the gradient w.r.t. the forward output
// and the four dimensions describe the forward input's shape -- confirm
// against the kernel implementations.
//
// Returns whatever tensor the dispatched operator returns.
at::Tensor _roi_align_backward(
    const at::Tensor& grad,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width,
    int64_t sampling_ratio,
    bool aligned) {
  // The handle is typed with this function's own signature.
  using Signature = decltype(_roi_align_backward);
  static auto handle =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_roi_align_backward", "")
          .typed<Signature>();

  return handle.call(
      grad,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width,
      sampling_ratio,
      aligned);
}
91

Edward Z. Yang's avatar
Edward Z. Yang committed
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
// `c10::SymInt` flavour of the `torchvision::_roi_align_backward` entry
// point: the pooled size and the four dimension arguments are passed as
// `c10::SymInt` instead of `int64_t`.
//
// Looks up the same "torchvision::_roi_align_backward" schema in the c10
// dispatcher and forwards every argument, unchanged and in order, to the
// resolved operator. The typed handle is a function-local static, so the
// schema lookup is performed on the first call only.
//
// NOTE(review): presumably `grad` is the gradient w.r.t. the forward output
// and `batch_size`/`channels`/`height`/`width` describe the forward input's
// shape -- confirm against the kernel implementations.
//
// Returns whatever tensor the dispatched operator returns.
at::Tensor _roi_align_backward_symint(
    const at::Tensor& grad,
    const at::Tensor& rois,
    double spatial_scale,
    c10::SymInt pooled_height,
    c10::SymInt pooled_width,
    c10::SymInt batch_size,
    c10::SymInt channels,
    c10::SymInt height,
    c10::SymInt width,
    int64_t sampling_ratio,
    bool aligned) {
  // The handle is typed with this function's own (SymInt) signature.
  using Signature = decltype(_roi_align_backward_symint);
  static auto handle =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_roi_align_backward", "")
          .typed<Signature>();

  return handle.call(
      grad,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width,
      sampling_ratio,
      aligned);
}

122
123
} // namespace detail

124
// Defines the schemas of the two ROI Align operators in the `torchvision`
// library fragment: `roi_align` and `_roi_align_backward`. Only schema
// definitions appear in this block; no kernel implementations are attached
// here. The size-like arguments are declared as `SymInt` in both schemas.
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
  // Forward operator.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, bool aligned) -> Tensor"));
  // Backward operator.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width, int sampling_ratio, bool aligned) -> Tensor"));
}

131
132
} // namespace ops
} // namespace vision