// ROIAlign.h
#pragma once

#include "cpu/vision_cpu.h"

#ifdef WITH_CUDA
#include "cuda/vision_cuda.h"
#endif

// Interface for Python
//
// Forward pass of ROI Align. Dispatches to ROIAlign_forward_cuda when `input`
// is a CUDA tensor (only available when built WITH_CUDA), and to
// ROIAlign_forward_cpu otherwise.
//
// Throws (via AT_ERROR) if `input` is a CUDA tensor but the extension was
// compiled without GPU support.
at::Tensor ROIAlign_forward(
    const at::Tensor& input, // Input feature map.
    const at::Tensor& rois, // List of ROIs to pool over.
    const double spatial_scale, // The scale of the image features. ROIs will
                                // be scaled to this.
    const int64_t pooled_height, // The height of the pooled feature map.
    const int64_t pooled_width, // The width of the pooled feature map.
    const int64_t sampling_ratio, // The number of points to sample in each
                                  // bin along each axis.
    const bool aligned) // The flag for pixel shift.
{
  // Tensor::type() is deprecated; ask the tensor for its device directly.
  if (input.is_cuda()) {
#ifdef WITH_CUDA
    return ROIAlign_forward_cuda(
        input,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio,
        aligned);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }
  return ROIAlign_forward_cpu(
      input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned);
}

// Backward pass of ROI Align: computes the gradient with respect to the
// forward `input` from the gradient of the pooled output.
//
// `batch_size`, `channels`, `height` and `width` describe the shape of the
// forward input, so the backend can allocate a gradient of matching size.
// Dispatches to ROIAlign_backward_cuda when `grad` is a CUDA tensor (only
// available when built WITH_CUDA), and to ROIAlign_backward_cpu otherwise.
//
// Throws (via AT_ERROR) if `grad` is a CUDA tensor but the extension was
// compiled without GPU support.
at::Tensor ROIAlign_backward(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width,
    const int batch_size,
    const int channels,
    const int height,
    const int width,
    const int sampling_ratio,
    const bool aligned) {
  // Tensor::type() is deprecated; ask the tensor for its device directly.
  if (grad.is_cuda()) {
#ifdef WITH_CUDA
    return ROIAlign_backward_cuda(
        grad,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        batch_size,
        channels,
        height,
        width,
        sampling_ratio,
        aligned);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }
  return ROIAlign_backward_cpu(
      grad,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width,
      sampling_ratio,
      aligned);
}

using namespace at;
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

// Autograd wrapper around ROIAlign_forward / ROIAlign_backward.
//
// forward() stores the scalar pooling parameters and the input shape in
// ctx->saved_data and saves `rois` as a tensor for backward(). backward()
// produces a gradient only for `input`; the other six forward() arguments
// receive undefined Variables.
class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
 public:
  static variable_list forward(
      AutogradContext* ctx,
      Variable input,
      Variable rois,
      const double spatial_scale,
      const int64_t pooled_height,
      const int64_t pooled_width,
      const int64_t sampling_ratio,
      const bool aligned) {
    ctx->saved_data["spatial_scale"] = spatial_scale;
    ctx->saved_data["pooled_height"] = pooled_height;
    ctx->saved_data["pooled_width"] = pooled_width;
    ctx->saved_data["sampling_ratio"] = sampling_ratio;
    ctx->saved_data["aligned"] = aligned;
    // backward() only needs the input's shape (passed on as
    // batch_size/channels/height/width), so the input tensor itself is not
    // saved.
    ctx->saved_data["input_shape"] = input.sizes();
    ctx->save_for_backward({rois});
    auto result = ROIAlign_forward(
        input,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio,
        aligned);
    return {result};
  }

  static variable_list backward(
      AutogradContext* ctx,
      variable_list grad_output) {
    // Use data saved in forward
    auto saved = ctx->get_saved_variables();
    auto rois = saved[0];
    auto input_shape = ctx->saved_data["input_shape"].toIntList();
    // NOTE(review): ROIAlign_backward takes float/int parameters, so the
    // double/int64_t values saved in forward() are narrowed at this call.
    auto grad_in = ROIAlign_backward(
        grad_output[0],
        rois,
        ctx->saved_data["spatial_scale"].toDouble(),
        ctx->saved_data["pooled_height"].toInt(),
        ctx->saved_data["pooled_width"].toInt(),
        input_shape[0], // batch_size
        input_shape[1], // channels
        input_shape[2], // height
        input_shape[3], // width
        ctx->saved_data["sampling_ratio"].toInt(),
        ctx->saved_data["aligned"].toBool());
    // One entry per forward() argument (excluding ctx): only `input` gets a
    // gradient.
    return {
        grad_in, Variable(), Variable(), Variable(), Variable(), Variable(), Variable()};
  }
};

// Differentiable ROI Align entry point.
//
// Runs ROIAlignFunction through the autograd machinery, so gradients flow back
// to `input`. Returns the single pooled output tensor. Parameters have the same
// meaning as in ROIAlign_forward.
Tensor roi_align(
    const Tensor& input,
    const Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t sampling_ratio,
    const bool aligned) {
  return ROIAlignFunction::apply(
      input,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      aligned)[0];
}