#include "roi_pool.h"

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
#include <torch/types.h>

namespace vision {
namespace ops {
std::tuple<at::Tensor, at::Tensor> roi_pool(
11
12
    const at::Tensor& input,
    const at::Tensor& rois,
13
14
15
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width) {
limm's avatar
limm committed
16
  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_pool.roi_pool");
17
18
19
20
21
22
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("torchvision::roi_pool", "")
                       .typed<decltype(roi_pool)>();
  return op.call(input, rois, spatial_scale, pooled_height, pooled_width);
}

// SymInt overload of roi_pool: identical dispatch glue, but the pooled output
// sizes are c10::SymInt rather than int64_t so symbolic shapes can flow
// through. Resolves the same "torchvision::roi_pool" schema.
std::tuple<at::Tensor, at::Tensor> roi_pool_symint(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    c10::SymInt pooled_height,
    c10::SymInt pooled_width) {
  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_pool.roi_pool");
  // Cached typed handle; the typed<> signature here uses the SymInt overload.
  static auto dispatch_handle =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::roi_pool", "")
          .typed<decltype(roi_pool_symint)>();
  return dispatch_handle.call(
      input, rois, spatial_scale, pooled_height, pooled_width);
}

namespace detail {

// Backward pass of RoI max pooling: computes the gradient w.r.t. the input
// from the output gradient and the argmax indices recorded by the forward
// pass. Like the forward wrapper, this only resolves the registered
// "torchvision::_roi_pool_backward" kernel through the dispatcher.
//
// batch_size/channels/height/width describe the shape of the gradient tensor
// to materialize (the original input's shape).
at::Tensor _roi_pool_backward(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& argmax,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width) {
  // One-time schema lookup; the typed handle is reused on every call.
  static auto backward_handle =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_roi_pool_backward", "")
          .typed<decltype(_roi_pool_backward)>();
  return backward_handle.call(
      grad,
      rois,
      argmax,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width);
}

// SymInt overload of the backward wrapper: mirrors _roi_pool_backward but
// takes all size arguments as c10::SymInt so symbolic shapes can flow
// through. Resolves the same "torchvision::_roi_pool_backward" schema.
at::Tensor _roi_pool_backward_symint(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& argmax,
    double spatial_scale,
    c10::SymInt pooled_height,
    c10::SymInt pooled_width,
    c10::SymInt batch_size,
    c10::SymInt channels,
    c10::SymInt height,
    c10::SymInt width) {
  // Cached typed handle for the SymInt signature.
  static auto backward_handle =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_roi_pool_backward", "")
          .typed<decltype(_roi_pool_backward_symint)>();
  return backward_handle.call(
      grad,
      rois,
      argmax,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width);
}

} // namespace detail

// Register the operator schemas with the dispatcher. Only the schemas are
// declared here; the concrete kernels (CPU, CUDA, autograd) are registered
// elsewhere against these names. TORCH_SELECTIVE_SCHEMA allows selective
// build to strip unused operators.
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
  // Forward: (input, rois, scale, pooled H, pooled W) -> (output, argmax).
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torchvision::roi_pool(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width) -> (Tensor, Tensor)"));
  // Backward: consumes grad + saved argmax and the original input shape.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torchvision::_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor"));
}

} // namespace ops
} // namespace vision