// NOTE(review): this file appears to be corrupted. Every "<...>" span looks
// like it was stripped, and the whole file is collapsed onto one line.
// Visible symptoms:
//   * `#include template inline ...` -- the #include has no target and
//     `template` has no parameter list, so this is not valid C++.
//   * `for (int i=0; i();` and `for (int j=0; j();` -- the loop conditions,
//     the body of discounted_sum_update, and everything up to the next
//     surviving text are missing.
//   * No definition of discounted_cumsum_left_cpu is visible, even though
//     PYBIND11_MODULE below takes its address.
//   * Only the tail of the right-to-left loop (presumably inside
//     discounted_cumsum_right_cpu) survives.
// Restore this file from version control rather than hand-reconstructing it.
//
// discounted_sum_update(accessor, batchsz, gamma, change_pos, discounted_pos):
// the visible signature shows an index loop over i in [0, batchsz). Its body
// is lost. It presumably folds gamma * accessor[i][discounted_pos] into
// accessor[i][change_pos] for each row. TODO confirm against the original.
//
// BUG(review): in the surviving right-to-left scan, j runs from y.size(1)-1
// down to 0 and j_right = j+1. So j_right is always in [1, y.size(1)], and
// the `if (j_right == 0) continue;` guard can never fire. Whenever y has at
// least one column, the first iteration therefore calls
// discounted_sum_update(..., j, j_right) with j_right == y.size(1), one past
// the last valid column index. That is presumably an out-of-bounds access.
// The guard should likely be `if (j_right == y.size(1))`, so the last column
// is skipped. Confirm once the file is restored.
//
// PYBIND11_MODULE exposes discounted_cumsum_left_cpu and
// discounted_cumsum_right_cpu to Python under those same names.
#include template inline void discounted_sum_update(T_accessor &accessor, int batchsz, scalar_t gamma, int change_pos, int discounted_pos) { for (int i=0; i(); for (int j=0; j(); for (int j=y.size(1)-1; j>=0; j--) { int j_right = j+1; if (j_right == 0) { continue; } discounted_sum_update(ya, y.size(0), gamma, j, j_right); } })); return y; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("discounted_cumsum_left_cpu", &discounted_cumsum_left_cpu, "Discounted Cumulative Sum CPU (Left)"); m.def("discounted_cumsum_right_cpu", &discounted_cumsum_right_cpu, "Discounted Cumulative Sum CPU (Right)"); }