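// Dispatcher-facing wrappers for the position-sensitive RoI pooling
// (PSRoIPool) operator: each function looks up the registered
// torchvision schema and forwards the call to whichever backend
// kernel the dispatcher selects.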
#include "ps_roi_pool.h"

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
#include <torch/types.h>

namespace vision {
namespace ops {
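
// Forward op: returns the pooled output together with the channel-mapping
// tensor that the backward pass needs.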
std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width) {
  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_pool.ps_roi_pool");
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("torchvision::ps_roi_pool", "")
                       .typed<decltype(ps_roi_pool)>();
  return op.call(input, rois, spatial_scale, pooled_height, pooled_width);
}

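// SymInt overload: identical to ps_roi_pool but accepts symbolic sizes, so
// the op can be traced or compiled without forcing the pooled dimensions to
// concrete integers.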
std::tuple<at::Tensor, at::Tensor> ps_roi_pool_symint(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    c10::SymInt pooled_height,
    c10::SymInt pooled_width) {
  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_pool.ps_roi_pool");
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("torchvision::ps_roi_pool", "")
                       .typed<decltype(ps_roi_pool_symint)>();
  return op.call(input, rois, spatial_scale, pooled_height, pooled_width);
}

namespace detail {

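// Backward op: consumes the channel mapping produced by the forward pass and
// computes the gradient w.r.t. the input feature map. Intended to be called
// from the autograd kernel rather than by users directly.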
at::Tensor _ps_roi_pool_backward(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& channel_mapping,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_ps_roi_pool_backward", "")
          .typed<decltype(_ps_roi_pool_backward)>();
  return op.call(
      grad,
      rois,
      channel_mapping,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width);
}

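// SymInt overload of the backward op, mirroring ps_roi_pool_symint.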
at::Tensor _ps_roi_pool_backward_symint(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& channel_mapping,
    double spatial_scale,
    c10::SymInt pooled_height,
    c10::SymInt pooled_width,
    c10::SymInt batch_size,
    c10::SymInt channels,
    c10::SymInt height,
    c10::SymInt width) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_ps_roi_pool_backward", "")
          .typed<decltype(_ps_roi_pool_backward_symint)>();
  return op.call(
      grad,
      rois,
      channel_mapping,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width);
}

} // namespace detail

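// Register the operator schemas with the dispatcher; the per-backend kernel
// implementations are registered separately via TORCH_LIBRARY_IMPL.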
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torchvision::ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width) -> (Tensor, Tensor)"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torchvision::_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor"));
}

} // namespace ops
} // namespace vision