// empty_tensor_op.h
#pragma once

// All pure C++ headers for the C++ frontend.
#include <torch/all.h>

/// Autograd function that allocates an uninitialized tensor of a requested
/// shape (inheriting `input`'s dtype/device via `Tensor::new_empty`) while
/// keeping the result attached to `input` in the autograd graph.
class NewEmptyTensorOp : public torch::autograd::Function<NewEmptyTensorOp> {
 public:
  /// @param ctx        autograd context; the input's shape is stored in
  ///                   `saved_data["shape"]` for use by backward().
  /// @param input      tensor whose options the new tensor inherits.
  /// @param new_shape  shape of the tensor to allocate.
  /// @return single-element list holding the new, uninitialized tensor.
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::Variable input,
      c10::List<int64_t> new_shape) {
    // Remember the input's shape so backward can emit a gradient of that size.
    ctx->saved_data["shape"] = input.sizes();
    std::vector<int64_t> shape(new_shape.begin(), new_shape.end());
    return {input.new_empty(shape, at::TensorOptions())};
  }

  /// @param ctx          autograd context populated by forward().
  /// @param grad_output  gradients w.r.t. forward()'s outputs; only [0] is used.
  /// @return {gradient w.r.t. `input`, undefined tensor for `new_shape`}.
  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    // Only READ the shape saved in forward. Re-entering forward() here would
    // overwrite saved_data["shape"] with the gradient's (output) shape, so a
    // second backward over the same graph (retain_graph=true) would return a
    // gradient of the wrong size.
    auto saved_shape = ctx->saved_data["shape"].toIntList();
    std::vector<int64_t> shape(saved_shape.begin(), saved_shape.end());
    // The output's values never depend on the input's values, so the true
    // gradient is zero; new_zeros avoids handing uninitialized memory back
    // to autograd as a gradient.
    return {grad_output[0].new_zeros(shape, at::TensorOptions()), at::Tensor()};
  }
};

27
at::Tensor new_empty_tensor(const at::Tensor& input, c10::List<int64_t> shape) {
eellison's avatar
eellison committed
28
29
  return NewEmptyTensorOp::apply(input, shape)[0];
}