interpolate_module.cpp 3.76 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include <torch/script.h>

#include <ATen/autocast_mode.h>

#ifndef NO_PYBIND
#include <torch/extension.h>
#endif

#include "interpolate_kernel.h"

// Public entry point. Looks up the `interpolate_ext::interpolate` operator in
// the dispatcher once (the typed handle is cached in a function-local static)
// and forwards all four tensors to it unchanged.
torch::Tensor interpolate(
    const torch::Tensor& vert_attributes,
    const torch::Tensor& vi,
    const torch::Tensor& index_img,
    const torch::Tensor& bary_img) {
  // The operator is typed with this very function's signature.
  using Signature = decltype(interpolate);
  static auto typed_handle =
      torch::Dispatcher::singleton()
          .findSchemaOrThrow("interpolate_ext::interpolate", "")
          .typed<Signature>();
  return typed_handle.call(vert_attributes, vi, index_img, bary_img);
}

// Custom autograd function wrapping the interpolate kernels.
// Ideally autograd handling would be turned off and the call re-dispatched;
// instead the CUDA kernels are simply invoked directly from here.
class InterpolateFunction : public torch::autograd::Function<InterpolateFunction> {
 public:
  // Stores all four inputs for the backward pass and returns the result of
  // `interpolate_cuda` wrapped in a single-element list.
  static torch::autograd::tensor_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& vert_attributes,
      const torch::Tensor& vi,
      const torch::Tensor& index_img,
      const torch::Tensor& bary_img) {
    // Undefined incoming gradients stay undefined; backward() checks for that.
    ctx->set_materialize_grads(false);
    ctx->save_for_backward({vert_attributes, vi, index_img, bary_img});
    auto result = interpolate_cuda(vert_attributes, vi, index_img, bary_img);
    return {result};
  }

  // Returns one gradient slot per forward input, in the order
  // (vert_attributes, vi, index_img, bary_img). The slots for `vi` and
  // `index_img` are always undefined tensors.
  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs) {
    const auto saved = ctx->get_saved_variables();
    const torch::Tensor& attrs = saved[0];
    const torch::Tensor& tri_indices = saved[1];
    const torch::Tensor& pixel_indices = saved[2];
    const torch::Tensor& bary = saved[3];

    const torch::Tensor& upstream_grad = grad_outputs[0];
    const bool any_input_needs_grad = bary.requires_grad() || attrs.requires_grad();

    // Nothing to compute: hand back four undefined gradients.
    if (!any_input_needs_grad || !upstream_grad.defined()) {
      return torch::autograd::tensor_list(4);
    }

    auto grads =
        interpolate_cuda_backward(upstream_grad, attrs, tri_indices, pixel_indices, bary);

    // Element 0 of the kernel result goes to vert_attributes, element 1 to bary_img.
    return {std::get<0>(grads), torch::Tensor(), torch::Tensor(), std::get<1>(grads)};
  }
};

// Runs the inputs through InterpolateFunction and unwraps the single tensor
// from the list it returns.
torch::Tensor interpolate_autograd(
    const torch::Tensor& vert_attributes,
    const torch::Tensor& vi,
    const torch::Tensor& index_img,
    const torch::Tensor& bary_img) {
  const auto outputs = InterpolateFunction::apply(vert_attributes, vi, index_img, bary_img);
  return outputs[0];
}

// Autocast wrapper: casts `vert_attributes` and `bary_img` to float32 via
// `cached_cast`, leaves `vi` and `index_img` as they are, and calls
// `interpolate` again with the Autocast dispatch key excluded.
torch::Tensor interpolate_autocast(
    const torch::Tensor& vert_attributes,
    const torch::Tensor& vi,
    const torch::Tensor& index_img,
    const torch::Tensor& bary_img) {
  // Exclude Autocast for the rest of this scope so the call below does not
  // land back in this wrapper.
  c10::impl::ExcludeDispatchKeyGuard autocast_excluded(c10::DispatchKey::Autocast);

  const auto vert_attributes_fp32 = at::autocast::cached_cast(torch::kFloat32, vert_attributes);
  const auto bary_img_fp32 = at::autocast::cached_cast(torch::kFloat32, bary_img);

  return interpolate(vert_attributes_fp32, vi, index_img, bary_img_fp32);
}

#ifndef NO_PYBIND
// Intentionally empty pybind11 module. It exists only so that this extension
// can be imported as a Python module, which gives access to its file path and
// lets the Torch ops be imported. The operator itself is registered through
// the TORCH_LIBRARY / TORCH_LIBRARY_IMPL blocks in this file, not via pybind11.
PYBIND11_MODULE(interpolate_ext, m) {}
#endif

// Declares the schema of the `interpolate` operator in the `interpolate_ext`
// namespace: four Tensor arguments (vert_attributes, vi, index_img, bary_img)
// and a single Tensor result. Implementations for individual dispatch keys are
// attached by the TORCH_LIBRARY_IMPL blocks in this file.
TORCH_LIBRARY(interpolate_ext, m) {
  m.def(
      "interpolate(Tensor vert_attributes, Tensor vi, Tensor index_img, Tensor bary_img) -> Tensor");
}

// Autograd dispatch key: use `interpolate_autograd`, which runs the call
// through `InterpolateFunction::apply`.
TORCH_LIBRARY_IMPL(interpolate_ext, Autograd, m) {
  m.impl("interpolate", &interpolate_autograd);
}

// Autocast dispatch key: use `interpolate_autocast`, which casts
// `vert_attributes` and `bary_img` to float32 before calling `interpolate`
// with the Autocast key excluded.
TORCH_LIBRARY_IMPL(interpolate_ext, Autocast, m) {
  m.impl("interpolate", interpolate_autocast);
}

// CUDA dispatch key: bind the operator directly to `interpolate_cuda`. That
// function is not defined in this file (presumably declared in
// "interpolate_kernel.h" — confirm).
TORCH_LIBRARY_IMPL(interpolate_ext, CUDA, m) {
  m.impl("interpolate", &interpolate_cuda);
}