"vscode:/vscode.git/clone" did not exist on "eaa0c1dfb50dff7679c42c145a55bad7c0a728bf"
new_empty_tensor_op.cpp 1.2 KB
Newer Older
1
#include "new_empty_tensor_op.h"
2
3
4

#include <torch/autograd.h>
#include <torch/types.h>
5
6
7
8
9

namespace vision {
namespace ops {

namespace {
eellison's avatar
eellison committed
10
11
12

// Custom autograd function that allocates an uninitialized tensor of a
// requested shape via `input.new_empty(...)`, and whose backward pass
// allocates an uninitialized tensor with the shape `input` had in forward.
class NewEmptyTensorOp : public torch::autograd::Function<NewEmptyTensorOp> {
 public:
  // Forward pass.
  //
  // ctx:       autograd context; the sizes of `input` are stored under the
  //            "shape" key of `ctx->saved_data` for use in backward().
  // input:     tensor whose `new_empty` is used to create the result.
  // new_shape: sizes of the tensor to allocate.
  //
  // Returns a one-element list holding the newly allocated tensor.
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::Variable& input,
      const c10::List<int64_t>& new_shape) {
    ctx->saved_data["shape"] = input.sizes();
    std::vector<int64_t> shape(new_shape.begin(), new_shape.end());
    return {input.new_empty(shape, at::TensorOptions())};
  }

  // Backward pass.
  //
  // ctx:         autograd context populated by forward().
  // grad_output: incoming gradients; only element 0 is read.
  //
  // Returns two entries: a tensor created with `grad_output[0].new_empty`
  // using the shape saved in forward(), and an undefined tensor for the
  // `new_shape` argument.
  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_output) {
    // Use data saved in forward
    const auto saved_shape = ctx->saved_data["shape"].toIntList();
    const std::vector<int64_t> shape(saved_shape.begin(), saved_shape.end());

    // Allocate directly rather than re-entering forward(): forward() writes
    // ctx->saved_data["shape"], which would replace the saved input shape
    // with the sizes of grad_output[0] and make a later backward pass over
    // the same context read the wrong shape.
    return {grad_output[0].new_empty(shape, at::TensorOptions()), at::Tensor()};
  }
};

32
33
34
35
36
} // namespace

// Runs NewEmptyTensorOp on `input` with the requested `shape` and returns the
// first (only) tensor of the resulting list.
at::Tensor new_empty_tensor(
    const at::Tensor& input,
    const c10::List<int64_t>& shape) {
  const auto outputs = NewEmptyTensorOp::apply(input, shape);
  return outputs[0];
}
39

40
41
42
43
// Registers new_empty_tensor under the operator name "_new_empty_tensor_op"
// in the `torchvision` library fragment.
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
  m.def("_new_empty_tensor_op", &new_empty_tensor);
}

44
45
} // namespace ops
} // namespace vision