// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#include <torch/script.h>

#include <ATen/autocast_mode.h>

#ifndef NO_PYBIND
#include <torch/extension.h>
#endif

#include "edge_grad_kernel.h"

// Dispatch function
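// Looks up the schema registered in TORCH_LIBRARY below and re-dispatches
// through the PyTorch dispatcher, so the Autocast/Autograd/CUDA
// implementations registered further down are selected by dispatch key.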
torch::Tensor edge_grad_estimator(
    const torch::Tensor& v_pix,
    const torch::Tensor& v_pix_img,
    const torch::Tensor& vi,
    const torch::Tensor& img,
    const torch::Tensor& index_img) {
  static auto op = torch::Dispatcher::singleton()
                       .findSchemaOrThrow("edge_grad_ext::edge_grad_estimator", "")
                       .typed<decltype(edge_grad_estimator)>();
  return op.call(v_pix, v_pix_img, vi, img, index_img);
}

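// Input validation plus the CUDA-key forward. The forward pass is an identity
// on `img`; the edge-aware gradient w.r.t. `v_pix_img` is produced only in the
// autograd backward below.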
torch::Tensor edge_grad_estimator_fwd(
    const torch::Tensor& v_pix,
    const torch::Tensor& v_pix_img,
    const torch::Tensor& vi,
    const torch::Tensor& img,
    const torch::Tensor& index_img) {
  TORCH_CHECK(
      v_pix.defined() && v_pix_img.defined() && vi.defined() && img.defined() &&
          index_img.defined(),
      "edge_grad_estimator(): expected all inputs to be defined");
  TORCH_CHECK(
      (v_pix.device() == v_pix_img.device()) && (v_pix.device() == vi.device()) &&
          (v_pix.device() == img.device()) && (v_pix.device() == index_img.device()) &&
          (v_pix.is_cuda()),
      "edge_grad_estimator(): expected all inputs to be on same cuda device");
  TORCH_CHECK(
      v_pix.is_floating_point() && v_pix_img.is_floating_point() && img.is_floating_point(),
      "edge_grad_estimator(): expected v_pix, v_pix_img, and img to have floating point type, but v_pix has ",
      v_pix.dtype(),
      " v_pix has ",
      v_pix_img.dtype(),
      " img has ",
      img.dtype());
  TORCH_CHECK(
      vi.dtype() == torch::kInt32,
      "edge_grad_estimator(): expected vi to have int32 type, but vi has ",
      vi.dtype());
  TORCH_CHECK(
      index_img.dtype() == torch::kInt32,
      "edge_grad_estimator(): expected index_img to have int32 type, but index_img has ",
      index_img.dtype());
  TORCH_CHECK(
      v_pix.layout() == torch::kStrided && v_pix_img.layout() == torch::kStrided &&
          vi.layout() == torch::kStrided && img.layout() == torch::kStrided &&
          index_img.layout() == torch::kStrided,
      "edge_grad_estimator(): expected all inputs to have torch.strided layout");
  TORCH_CHECK(
      (v_pix.dim() == 3) && (v_pix_img.dim() == 4) && (vi.dim() == 3) && (img.dim() == 4) &&
          (index_img.dim() == 3),
      "edge_grad_estimator(): expected v_pix.ndim == 3, v_pix_img.ndim == 4, vi.ndim == 3, img.ndim == 4, index_img.ndim == 3, "
      "but got v_pix with sizes ",
      v_pix.sizes(),
      " and v_pix_img with sizes ",
      v_pix_img.sizes(),
      " and vi with sizes ",
      vi.sizes(),
      " and img with sizes ",
      img.sizes(),
      " and index_img with sizes ",
      index_img.sizes());
  TORCH_CHECK(
      v_pix.size(0) == v_pix_img.size(0) && v_pix.size(0) == img.size(0) &&
          v_pix.size(0) == index_img.size(0),
      "edge_grad_estimator(): expected v and index_img to have same batch size, "
      "but got v_pix with sizes ",
      v_pix.sizes(),
      ", v_pix_img with sizes ",
      v_pix_img.sizes(),
      ", img with sizes ",
      img.sizes(),
      " and index_img with sizes ",
      index_img.sizes());
  TORCH_CHECK(
      v_pix.size(2) == 3 && v_pix_img.size(1) == 3 && vi.size(2) == 3,
      "edge_grad_estimator(): expected third dim of v_pix to be of size 3, and third dim of vi to be of size 3, but got ",
      v_pix.size(2),
      " in the third dim of v_pix, and ",
      v_pix_img.size(1),
      " in the second dim of v_pix_img, and ",
      vi.size(2),
      " in the third dim of vi");
  TORCH_CHECK(
      v_pix_img.size(3) == img.size(3) && v_pix_img.size(3) == index_img.size(2) &&
          v_pix_img.size(2) == img.size(2) && v_pix_img.size(2) == index_img.size(1),
      "edge_grad_estimator(): expected width and height of v_pix_img, img, and index_img to match, but got size of v_pix_img: ",
      v_pix_img.sizes(),
      ", size of img: ",
      img.sizes(),
      ", size of index_img: ",
      index_img.sizes());
  return img;
}

// Ideally we would turn off autograd handling and re-dispatch here, but we just
// call the CUDA kernels directly.
class EdgeGradEstimatorFunction : public torch::autograd::Function<EdgeGradEstimatorFunction> {
 public:
  static torch::autograd::tensor_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& v_pix,
      const torch::Tensor& v_pix_img,
      const torch::Tensor& vi,
      const torch::Tensor& img,
      const torch::Tensor& index_img) {
    // Call edge_grad_estimator_fwd to check the input sizes
    edge_grad_estimator_fwd(v_pix, v_pix_img, vi, img, index_img);
    ctx->set_materialize_grads(false);
    ctx->save_for_backward({v_pix, img, index_img, vi});
    ctx->saved_data["v_pix_img_requires_grad"] = v_pix_img.requires_grad();
    return {img};
  }

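  // backward() must return one gradient per forward() input, in order:
  // (v_pix, v_pix_img, vi, img, index_img). Only v_pix_img receives an edge
  // gradient; the incoming gradient is passed straight through to img.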
  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs) {
    // If v_pix_img doesn't require grad, we don't need to do anything.
    if (!ctx->saved_data["v_pix_img_requires_grad"].toBool()) {
      return {torch::Tensor(), torch::Tensor(), torch::Tensor(), grad_outputs[0], torch::Tensor()};
    }
    const auto saved = ctx->get_saved_variables();
    const auto& v_pix = saved[0];
    const auto& img = saved[1];
    const auto& index_img = saved[2];
    const auto& vi = saved[3];

    auto grad_v_pix_img =
        edge_grad_estimator_cuda_backward(v_pix, img, index_img, vi, grad_outputs[0]);
    return {torch::Tensor(), grad_v_pix_img, torch::Tensor(), grad_outputs[0], torch::Tensor()};
  }
};

torch::Tensor edge_grad_estimator_autograd(
    const torch::Tensor& v_pix,
    const torch::Tensor& v_pix_img,
    const torch::Tensor& vi,
    const torch::Tensor& img,
    const torch::Tensor& index_img) {
  return EdgeGradEstimatorFunction::apply(v_pix, v_pix_img, vi, img, index_img)[0];
}

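// Autocast wrapper: excludes the Autocast dispatch key and promotes the
// floating-point inputs to float32 before re-dispatching, so the kernel
// always runs in full precision under autocast.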
torch::Tensor edge_grad_estimator_autocast(
    const torch::Tensor& v_pix,
    const torch::Tensor& v_pix_img,
    const torch::Tensor& vi,
    const torch::Tensor& img,
    const torch::Tensor& index_img) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  return edge_grad_estimator(
      at::autocast::cached_cast(torch::kFloat32, v_pix),
      at::autocast::cached_cast(torch::kFloat32, v_pix_img),
      vi,
      at::autocast::cached_cast(torch::kFloat32, img),
      index_img);
}

#ifndef NO_PYBIND
// Just so that we can import this file as a Python module to get the path and
// import the Torch ops.
PYBIND11_MODULE(edge_grad_ext, m) {}
#endif

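// Register the operator schema; once this library is loaded, the op is visible
// from Python as torch.ops.edge_grad_ext.edge_grad_estimator.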
TORCH_LIBRARY(edge_grad_ext, m) {
  m.def(
      "edge_grad_estimator(Tensor v_pix, Tensor v_pix_img, Tensor vi, Tensor img, Tensor index_img) -> Tensor");
}

TORCH_LIBRARY_IMPL(edge_grad_ext, Autograd, m) {
  m.impl("edge_grad_estimator", &edge_grad_estimator_autograd);
}

TORCH_LIBRARY_IMPL(edge_grad_ext, Autocast, m) {
  m.impl("edge_grad_estimator", edge_grad_estimator_autocast);
}

TORCH_LIBRARY_IMPL(edge_grad_ext, CUDA, m) {
  m.impl("edge_grad_estimator", &edge_grad_estimator_fwd);
}
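
// Minimal C++ usage sketch (shapes are illustrative and only chosen to satisfy
// the checks in edge_grad_estimator_fwd; the index values in `vi` and
// `index_img` are placeholders, not a meaningful mesh/rasterization):
//
//   auto opts_f = torch::dtype(torch::kFloat32).device(torch::kCUDA);
//   auto opts_i = torch::dtype(torch::kInt32).device(torch::kCUDA);
//   auto v_pix     = torch::rand({1, 100, 3}, opts_f);                     // [N, V, 3]
//   auto v_pix_img = torch::rand({1, 3, 64, 64}, opts_f).requires_grad_(); // [N, 3, H, W]
//   auto vi        = torch::zeros({1, 50, 3}, opts_i);                     // [N, F, 3]
//   auto img       = torch::rand({1, 4, 64, 64}, opts_f);                  // [N, C, H, W]
//   auto index_img = torch::zeros({1, 64, 64}, opts_i);                    // [N, H, W]
//   auto out = edge_grad_estimator(v_pix, v_pix_img, vi, img, index_img);
//   out.sum().backward();  // populates v_pix_img.grad() with the edge gradient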