// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include <torch/script.h>

#include <ATen/autocast_mode.h>

#ifndef NO_PYBIND
#include <torch/extension.h>
#endif

#include "edge_grad_kernel.h"

// Dispatch function
torch::Tensor edge_grad_estimator(
    const torch::Tensor& v_pix,
    const torch::Tensor& v_pix_img,
    const torch::Tensor& vi,
    const torch::Tensor& img,
    const torch::Tensor& index_img) {
  static auto op = torch::Dispatcher::singleton()
                       .findSchemaOrThrow("edge_grad_ext::edge_grad_estimator", "")
                       .typed<decltype(edge_grad_estimator)>();
  return op.call(v_pix, v_pix_img, vi, img, index_img);
}

// Forward implementation registered for the CUDA key. The forward pass is an
// identity on `img`: it only validates the inputs (definedness, device, dtype,
// layout, rank, and matching sizes) and returns `img` unchanged. The actual
// edge-gradient computation happens in the custom backward pass.
//
// Fix: the dtype error message printed " v_pix has " twice; the second
// occurrence now correctly labels v_pix_img's dtype.
torch::Tensor edge_grad_estimator_fwd(
    const torch::Tensor& v_pix,
    const torch::Tensor& v_pix_img,
    const torch::Tensor& vi,
    const torch::Tensor& img,
    const torch::Tensor& index_img) {
  TORCH_CHECK(
      v_pix.defined() && v_pix_img.defined() && vi.defined() && img.defined() &&
          index_img.defined(),
      "edge_grad_estimator(): expected all inputs to be defined");
  TORCH_CHECK(
      (v_pix.device() == v_pix_img.device()) && (v_pix.device() == vi.device()) &&
          (v_pix.device() == img.device()) && (v_pix.device() == index_img.device()) &&
          (v_pix.is_cuda()),
      "edge_grad_estimator(): expected all inputs to be on same cuda device");
  TORCH_CHECK(
      v_pix.is_floating_point() && v_pix_img.is_floating_point() && img.is_floating_point(),
      "edge_grad_estimator(): expected v_pix, v_pix_img, and img to have floating point type, but v_pix has ",
      v_pix.dtype(),
      " v_pix_img has ",
      v_pix_img.dtype(),
      " img has ",
      img.dtype());
  TORCH_CHECK(
      vi.dtype() == torch::kInt32,
      "edge_grad_estimator(): expected vi to have int32 type, but vi has ",
      vi.dtype());
  TORCH_CHECK(
      index_img.dtype() == torch::kInt32,
      "edge_grad_estimator(): expected index_img to have int32 type, but index_img has ",
      index_img.dtype());
  TORCH_CHECK(
      v_pix.layout() == torch::kStrided && v_pix_img.layout() == torch::kStrided &&
          vi.layout() == torch::kStrided && img.layout() == torch::kStrided &&
          index_img.layout() == torch::kStrided,
      "edge_grad_estimator(): expected all inputs to have torch.strided layout");
  TORCH_CHECK(
      (v_pix.dim() == 3) && (v_pix_img.dim() == 4) && (vi.dim() == 2) && (img.dim() == 4) &&
          (index_img.dim() == 3),
      "edge_grad_estimator(): expected v_pix.ndim == 3, v_pix_img.ndim == 4, vi.ndim == 2, img.ndim == 4, index_img.ndim == 3, "
      "but got v_pix with sizes ",
      v_pix.sizes(),
      " and v_pix_img with sizes ",
      v_pix_img.sizes(),
      " and vi with sizes ",
      vi.sizes(),
      " and img with sizes ",
      img.sizes(),
      " and index_img with sizes ",
      index_img.sizes());
  TORCH_CHECK(
      v_pix.size(0) == v_pix_img.size(0) && v_pix.size(0) == img.size(0) &&
          v_pix.size(0) == index_img.size(0),
      "edge_grad_estimator(): expected v and index_img to have same batch size, "
      "but got v_pix with sizes ",
      v_pix.sizes(),
      ", v_pix_img with sizes ",
      v_pix_img.sizes(),
      ", img with sizes ",
      img.sizes(),
      " and index_img with sizes ",
      index_img.sizes());
  TORCH_CHECK(
      v_pix.size(2) == 3 && v_pix_img.size(1) == 3 && vi.size(1) == 3,
      "edge_grad_estimator(): expected third dim of v_pix to be of size 3, and second dim of vi to be of size 3, but got ",
      v_pix.size(2),
      " in the third dim of v_pix, and ",
      v_pix_img.size(1),
      " in the second dim of v_pix_img, and ",
      vi.size(1),
      " in the second dim of vi");
  TORCH_CHECK(
      v_pix_img.size(3) == img.size(3) && v_pix_img.size(3) == index_img.size(2) &&
          v_pix_img.size(2) == img.size(2) && v_pix_img.size(2) == index_img.size(1),
      "edge_grad_estimator(): expected width and height of v_pix_img, img, and index_img to match, but got size of v_pix_img: ",
      v_pix_img.sizes(),
      ", size of img: ",
      img.sizes(),
      ", size of index_img: ",
      index_img.sizes());
  // Identity forward: gradients are injected via the custom backward.
  return img;
}

// Ideally we would need to turn off autograd handling and re-dispatch, but we just call
// cuda kernels directly
// Ideally we would need to turn off autograd handling and re-dispatch, but we
// just call cuda kernels directly.
class EdgeGradEstimatorFunction : public torch::autograd::Function<EdgeGradEstimatorFunction> {
 public:
  // Forward is an identity on `img`; it only stashes what backward needs.
  static torch::autograd::tensor_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& v_pix,
      const torch::Tensor& v_pix_img,
      const torch::Tensor& vi,
      const torch::Tensor& img,
      const torch::Tensor& index_img) {
    // Keep undefined incoming grads undefined instead of zero-filling them.
    ctx->set_materialize_grads(false);
    // Record up front whether a grad for v_pix_img will be needed at all.
    ctx->saved_data["v_pix_img_requires_grad"] = v_pix_img.requires_grad();
    // Saved in this exact order; backward unpacks by index below.
    ctx->save_for_backward({v_pix, img, index_img, vi});
    return {img};
  }

  // Returns grads aligned with forward's inputs:
  // {v_pix, v_pix_img, vi, img, index_img}. Since forward passes `img`
  // straight through, its grad is grad_outputs[0] unchanged; only v_pix_img
  // receives a computed grad.
  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs) {
    // Fast path: nothing to compute when v_pix_img doesn't require grad.
    if (!ctx->saved_data["v_pix_img_requires_grad"].toBool()) {
      return {torch::Tensor(), torch::Tensor(), torch::Tensor(), grad_outputs[0], torch::Tensor()};
    }
    // Unpack in the order saved by forward: {v_pix, img, index_img, vi}.
    const auto vars = ctx->get_saved_variables();
    auto grad_v_pix_img = edge_grad_estimator_cuda_backward(
        /*v_pix=*/vars[0],
        /*img=*/vars[1],
        /*index_img=*/vars[2],
        /*vi=*/vars[3],
        grad_outputs[0]);
    return {torch::Tensor(), grad_v_pix_img, torch::Tensor(), grad_outputs[0], torch::Tensor()};
  }
};

// Autograd-key implementation: routes through EdgeGradEstimatorFunction so
// the custom backward above participates in differentiation.
torch::Tensor edge_grad_estimator_autograd(
    const torch::Tensor& v_pix,
    const torch::Tensor& v_pix_img,
    const torch::Tensor& vi,
    const torch::Tensor& img,
    const torch::Tensor& index_img) {
  auto outputs = EdgeGradEstimatorFunction::apply(v_pix, v_pix_img, vi, img, index_img);
  return outputs[0];
}

// Autocast-key implementation: casts floating-point inputs to fp32 (the
// integer tensors vi / index_img are passed through) and re-dispatches with
// autocast excluded so we don't recurse into this handler.
//
// Fix: edge_grad_estimator() returns a single torch::Tensor, not a list.
// The previous `...)[ 0]` invoked Tensor::operator[], which sliced off the
// first batch element of the result instead of returning the whole tensor.
torch::Tensor edge_grad_estimator_autocast(
    const torch::Tensor& v_pix,
    const torch::Tensor& v_pix_img,
    const torch::Tensor& vi,
    const torch::Tensor& img,
    const torch::Tensor& index_img) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  return edge_grad_estimator(
      at::autocast::cached_cast(torch::kFloat32, v_pix),
      at::autocast::cached_cast(torch::kFloat32, v_pix_img),
      vi,
      at::autocast::cached_cast(torch::kFloat32, img),
      index_img);
}

#ifndef NO_PYBIND
// Just so that we can import this file as a Python module to get the path and
// import the Torch ops. The module is intentionally empty: the actual ops are
// exposed through the TORCH_LIBRARY registrations below, not through pybind.
PYBIND11_MODULE(edge_grad_ext, m) {}
#endif

// Declare the operator schema for the dispatcher.
TORCH_LIBRARY(edge_grad_ext, m) {
  m.def(
      "edge_grad_estimator(Tensor v_pix, Tensor v_pix_img, Tensor vi, Tensor img, Tensor index_img) -> Tensor");
}

// Autograd key: wraps the op in the custom autograd Function.
TORCH_LIBRARY_IMPL(edge_grad_ext, Autograd, m) {
  m.impl("edge_grad_estimator", &edge_grad_estimator_autograd);
}

// Autocast key: casts floating inputs to fp32 before re-dispatching.
TORCH_LIBRARY_IMPL(edge_grad_ext, Autocast, m) {
  m.impl("edge_grad_estimator", edge_grad_estimator_autocast);
}

// CUDA key: input validation + identity forward.
TORCH_LIBRARY_IMPL(edge_grad_ext, CUDA, m) {
  m.impl("edge_grad_estimator", &edge_grad_estimator_fwd);
}