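// discounted_cumsum.cpp
// PyTorch C++ extension entry point: declares the discounted cumulative-sum
// kernels and binds them into a Python module.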
#include <torch/extension.h>

anton's avatar
anton committed
3

anton's avatar
anton committed
4
torch::Tensor discounted_cumsum_left(torch::Tensor x, double gamma);
anton's avatar
anton committed
5
torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma);
6
7
8


// Python bindings for this extension. The module name is supplied at build
// time through the TORCH_EXTENSION_NAME macro. Each kernel declared above is
// exposed under the same name as its C++ function.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Naming the arguments lets Python callers pass them by keyword
    // (e.g. gamma=0.99). Positional calls behave exactly as before.
    m.def("discounted_cumsum_left", &discounted_cumsum_left, "Discounted Cumulative Sum (Left)",
          pybind11::arg("x"), pybind11::arg("gamma"));
    m.def("discounted_cumsum_right", &discounted_cumsum_right, "Discounted Cumulative Sum (Right)",
          pybind11::arg("x"), pybind11::arg("gamma"));
}