// ROIAlign.h — device dispatch and autograd wrapper for the ROIAlign operator.
#pragma once

#include "cpu/vision_cpu.h"

#ifdef WITH_CUDA
#include "cuda/vision_cuda.h"
#endif

#ifdef WITH_HIP
#include "hip/vision_cuda.h"
#endif

// Interface for Python
//
// Forward pass of ROIAlign. When `input` is a CUDA tensor, this calls the GPU
// kernel. The GPU path exists only when built with WITH_CUDA or WITH_HIP;
// otherwise it raises "Not compiled with GPU support". All other inputs go to
// the CPU kernel.
//
// NOTE(review): the device is chosen from `input` alone. `rois` is forwarded
// unchanged and presumably must live on the same device — confirm against
// the kernels.
// NOTE(review): this is a non-inline function defined in a header. Including
// this header from more than one translation unit would violate the ODR.
at::Tensor ROIAlign_forward(
    const at::Tensor& input, // Input feature map.
    const at::Tensor& rois, // List of ROIs to pool over.
    const double spatial_scale, // The scale of the image features. ROIs will be
    // scaled to this.
    const int64_t pooled_height, // The height of the pooled feature map.
    const int64_t pooled_width, // The width of the pooled feature map.
    const int64_t sampling_ratio, // The number of points to sample in each bin
    // along each axis.
    const bool aligned) // The flag for pixel shift
{
  // Tensor::type() is deprecated; is_cuda() queries the same device property.
  if (input.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
    return ROIAlign_forward_cuda(
        input,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio,
        aligned);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }
  return ROIAlign_forward_cpu(
      input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned);
}

// Backward pass of ROIAlign. Computes the gradient with respect to the input
// feature map from `grad`, the gradient of the pooled output.
//
// batch_size/channels/height/width give the shape of the original input, so
// the kernel can size the returned gradient. The remaining parameters must
// match those used in the forward pass.
//
// When `grad` is a CUDA tensor, this calls the GPU kernel. The GPU path exists
// only when built with WITH_CUDA or WITH_HIP; otherwise it raises
// "Not compiled with GPU support". All other inputs go to the CPU kernel.
at::Tensor ROIAlign_backward(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width,
    const int batch_size,
    const int channels,
    const int height,
    const int width,
    const int sampling_ratio,
    const bool aligned) {
  // Tensor::type() is deprecated; is_cuda() queries the same device property.
  if (grad.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
    return ROIAlign_backward_cuda(
        grad,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        batch_size,
        channels,
        height,
        width,
        sampling_ratio,
        aligned);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }
  return ROIAlign_backward_cpu(
      grad,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width,
      sampling_ratio,
      aligned);
}

using namespace at;
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

// Custom autograd Function that wires ROIAlign_forward / ROIAlign_backward
// into PyTorch's autograd graph. Only `input` receives a gradient; `rois` and
// the scalar hyperparameters are treated as non-differentiable.
class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
 public:
  // Runs the forward kernel. It stashes the state backward() needs:
  //  - the scalar hyperparameters and the input shape, in ctx->saved_data;
  //  - the rois tensor, via save_for_backward.
  static variable_list forward(
      AutogradContext* ctx,
      Variable input,
      Variable rois,
      const double spatial_scale,
      const int64_t pooled_height,
      const int64_t pooled_width,
      const int64_t sampling_ratio,
      const bool aligned) {
    ctx->saved_data["spatial_scale"] = spatial_scale;
    ctx->saved_data["pooled_height"] = pooled_height;
    ctx->saved_data["pooled_width"] = pooled_width;
    ctx->saved_data["sampling_ratio"] = sampling_ratio;
    ctx->saved_data["aligned"] = aligned;
    // Only the shape of `input` is needed for the gradient, so keep its sizes
    // instead of holding on to the whole tensor.
    ctx->saved_data["input_shape"] = input.sizes();
    ctx->save_for_backward({rois});
    auto result = ROIAlign_forward(
        input,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio,
        aligned);
    return {result};
  }

  // Computes d(loss)/d(input) from the gradient of the pooled output. It
  // returns one entry per forward() input (after ctx): the input gradient,
  // then undefined Variables for rois and the five scalar arguments.
  static variable_list backward(
      AutogradContext* ctx,
      variable_list grad_output) {
    // Use data saved in forward
    auto saved = ctx->get_saved_variables();
    auto rois = saved[0];
    // Indices 0..3 are forwarded as batch_size, channels, height, width.
    // NOTE(review): this assumes a 4-D input — confirm callers never pass
    // other ranks.
    auto input_shape = ctx->saved_data["input_shape"].toIntList();
    auto grad_in = ROIAlign_backward(
        grad_output[0],
        rois,
        ctx->saved_data["spatial_scale"].toDouble(),
        ctx->saved_data["pooled_height"].toInt(),
        ctx->saved_data["pooled_width"].toInt(),
        input_shape[0],
        input_shape[1],
        input_shape[2],
        input_shape[3],
        ctx->saved_data["sampling_ratio"].toInt(),
        ctx->saved_data["aligned"].toBool());
    return {
        grad_in, Variable(), Variable(), Variable(), Variable(), Variable(), Variable()};
  }
};

// Differentiable ROIAlign entry point. Applies ROIAlignFunction so that
// autograd records the op, and returns its single output: the pooled features.
Tensor roi_align(
    const Tensor& input,
    const Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t sampling_ratio,
    const bool aligned) {
  return ROIAlignFunction::apply(
      input,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      aligned)[0];
}