// roi_align.cpp
// Committed by Hang Zhang.
#include <torch/torch.h>
// CPU declarations

// Prototype of the CPU forward entry point for ROI Align. Only the
// declaration lives in this file; the definition must be supplied by another
// translation unit at link time.
//
// Parameter meanings below are inferred from the names only -- confirm them
// against the definition before relying on them:
//   input          - tensor the regions are pooled from.
//   bottom_rois    - tensor describing the regions of interest (layout is not
//                    visible from this file).
//   pooled_height,
//   pooled_width   - presumably the output grid size produced per region.
//   spatial_scale  - presumably the factor mapping ROI coordinates onto the
//                    coordinate frame of `input`.
//   sampling_ratio - presumably the number of sample points used per bin.
//
// Returns an at::Tensor (presumably the pooled features -- TODO confirm).
at::Tensor ROIAlignForwardCPU(
  const at::Tensor& input,
  const at::Tensor& bottom_rois,
  int64_t pooled_height,
  int64_t pooled_width,
  double spatial_scale,
  int64_t sampling_ratio);

// Prototype of the CPU backward entry point for ROI Align. Only the
// declaration lives in this file; the definition must be supplied by another
// translation unit at link time.
//
// Parameter meanings below (other than grad_output, which is documented
// inline) are inferred from the names only -- confirm against the definition:
//   bottom_rois    - tensor describing the regions of interest (layout is not
//                    visible from this file).
//   b_size, channels, height, width
//                  - presumably the dimensions of the forward-pass input, so
//                    the returned gradient can be shaped to match it.
//   pooled_height,
//   pooled_width   - presumably the per-region output grid size used in the
//                    forward pass.
//   spatial_scale,
//   sampling_ratio - presumably the same values that were passed to the
//                    forward pass.
//
// Returns an at::Tensor (presumably the gradient with respect to the forward
// input -- TODO confirm).
at::Tensor ROIAlignBackwardCPU(
  const at::Tensor& bottom_rois,
  const at::Tensor& grad_output, // gradient of the output of the layer
  int64_t b_size,
  int64_t channels,
  int64_t height,
  int64_t width,
  int64_t pooled_height,
  int64_t pooled_width,
  double spatial_scale,
  int64_t sampling_ratio);


// Python bindings for the extension module named by TORCH_EXTENSION_NAME.
// Registers the two CPU entry points declared above; each registration is
// (Python-visible name, C++ function pointer, docstring).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // def() returns the module itself, so both registrations are chained in
  // the same order as before: forward first, then backward.
  m.def("roi_align_forward",
        &ROIAlignForwardCPU,
        "ROI Align forward (CPU)")
   .def("roi_align_backward",
        &ROIAlignBackwardCPU,
        "ROI Align backward (CPU)");
}