Commit 1613a3eb authored by anton

refactor sum direction into templates

parent 4c605c1e
@@ -2,19 +2,41 @@
 template <typename scalar_t>
-__device__ __forceinline__ scalar_t discounted_sum_pow(scalar_t a, scalar_t b, scalar_t gamma, int power) {
+__device__ __forceinline__
+scalar_t discounted_sum_pow(scalar_t a, scalar_t b, scalar_t gamma, int power) {
     return a + b * pow(gamma, scalar_t(power));
 }
 
-__inline__
-int log2ceil(int x) {
-    return (int)ceil(log2((float)x));
-}
+enum SumDirection {
+    SUM_RIGHT,
+    SUM_LEFT
+};
+
+template <SumDirection d>
+__device__ __forceinline__
+void resolve_positions(
+    const int &gr_prev_stride, const int &gr_cur_stride, const int &gr_of_thread, const int &thread_in_gr,
+    int &change_pos, int &discounted_pos, int &discount_power
+);
+
+template <>
+__device__ __forceinline__
+void resolve_positions<SUM_RIGHT>(
+    const int &gr_prev_stride, const int &gr_cur_stride, const int &gr_of_thread, const int &thread_in_gr,
+    int &change_pos, int &discounted_pos, int &discount_power
+) {
+    change_pos = gr_of_thread * gr_cur_stride + thread_in_gr;
+    discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride;
+    discount_power = gr_prev_stride - thread_in_gr;
+}
 
-template <typename scalar_t>
-__global__ void discounted_cumsum_right_kernel_stage(
+template <typename scalar_t, SumDirection d>
+__global__
+void discounted_cumsum_kernel_stage(
     torch::PackedTensorAccessor32<scalar_t, 2> x,
     const scalar_t gamma,
     int stage
@@ -33,9 +55,15 @@ __global__ void discounted_cumsum_right_kernel_stage(
     int gr_of_thread = threadidx >> stage;
     int thread_in_gr = threadidx - (gr_of_thread << stage);
-    int change_pos = gr_of_thread * gr_cur_stride + thread_in_gr;
-    int discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride;
-    int discount_power = gr_prev_stride - thread_in_gr;
+    //int change_pos = gr_of_thread * gr_cur_stride + thread_in_gr;
+    //int discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride;
+    //int discount_power = gr_prev_stride - thread_in_gr;
+    int change_pos, discounted_pos, discount_power;
+    resolve_positions<d>(
+        gr_prev_stride, gr_cur_stride, gr_of_thread, thread_in_gr,
+        change_pos, discounted_pos, discount_power
+    );
 
     if (change_pos >= len || discounted_pos >= len) {
         return;
@@ -50,7 +78,14 @@ __global__ void discounted_cumsum_right_kernel_stage(
 }
 
-torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
+inline
+int log2ceil(int x) {
+    return (int)ceil(log2((float)x));
+}
+
+template <SumDirection d>
+torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
     // Minimum required number of threads, assigns them dynamically to respective positions upon each iteration.
     // Results in uncoalesced writes, which is still faster than coalesced writes with half threads idling.
@@ -71,8 +106,8 @@ torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
     const dim3 blocks((threads_total_x + threads - 1) / threads, x.size(0));
 
     for (int stage=0; stage<nstages; stage++) {
-        AT_DISPATCH_FLOATING_TYPES(x.type(), "discounted_cumsum_right_kernel_stage", ([&] {
-            discounted_cumsum_right_kernel_stage<scalar_t><<<blocks, threads>>>(
+        AT_DISPATCH_FLOATING_TYPES(x.type(), "discounted_cumsum_kernel_stage", ([&] {
+            discounted_cumsum_kernel_stage<scalar_t, d><<<blocks, threads>>>(
                 y.packed_accessor32<scalar_t, 2>(),
                 scalar_t(gamma),
                 stage
@@ -82,3 +117,12 @@ torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
     return y;
 }
+
+torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
+    return discounted_cumsum<SUM_RIGHT>(x, gamma);
+}
+
+//torch::Tensor discounted_cumsum_left(torch::Tensor x, double gamma) {
+//}
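
The staged kernel is an in-place doubling scan over each row: at stage s, every group of 2^(s+1) elements folds its right half into its left half, with each contribution discounted by gamma raised to the distance between the two positions. As a purely illustrative CPU-only sketch (not part of this commit; the helper names, the std::vector interface, and the driver in main are assumptions of this sketch), the following C++ replays the same index arithmetic as resolve_positions<SUM_RIGHT> plus discounted_sum_pow, and compares it against the sequential recurrence y[i] = x[i] + gamma * y[i + 1] that the right discounted cumulative sum should satisfy:

#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative sketch only; all names below are hypothetical helpers for this example.

// Sequential reference: y[i] = x[i] + gamma * y[i + 1],
// i.e. y[i] = sum over j >= i of gamma^(j - i) * x[j].
static std::vector<double> reference_right(const std::vector<double>& x, double gamma) {
    std::vector<double> y(x.size());
    double running = 0.0;
    for (int i = (int)x.size() - 1; i >= 0; --i) {
        running = x[i] + gamma * running;  // accumulate right-to-left
        y[i] = running;
    }
    return y;
}

// CPU emulation of the staged in-place scheme: at stage s, each "thread" merges the
// right half of a group of 2^(s+1) elements into its left half, discounted by
// gamma^(distance between the two positions) -- the same arithmetic as
// resolve_positions<SUM_RIGHT> followed by discounted_sum_pow.
static std::vector<double> staged_right(std::vector<double> x, double gamma) {
    int len = (int)x.size();
    int nstages = (int)std::ceil(std::log2((double)len));
    for (int stage = 0; stage < nstages; ++stage) {
        int gr_prev_stride = 1 << stage;
        int gr_cur_stride = gr_prev_stride << 1;
        // One loop iteration per CUDA thread; surplus iterations are skipped
        // by the same bounds check the kernel uses.
        for (int threadidx = 0; threadidx < len; ++threadidx) {
            int gr_of_thread = threadidx >> stage;
            int thread_in_gr = threadidx - (gr_of_thread << stage);
            int change_pos = gr_of_thread * gr_cur_stride + thread_in_gr;
            int discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride;
            int discount_power = gr_prev_stride - thread_in_gr;
            if (change_pos >= len || discounted_pos >= len) continue;
            x[change_pos] += x[discounted_pos] * std::pow(gamma, (double)discount_power);
        }
    }
    return x;
}

int main() {
    std::vector<double> x = {1.0, 2.0, 3.0, 4.0, 5.0};
    double gamma = 0.9;
    std::vector<double> a = reference_right(x, gamma);
    std::vector<double> b = staged_right(x, gamma);
    for (size_t i = 0; i < x.size(); ++i)
        std::printf("%zu: reference=%.6f staged=%.6f\n", i, a[i], b[i]);
    return 0;
}

Dispatching the direction through the SumDirection template parameter, rather than a runtime flag inside the kernel, lets the compiler specialize resolve_positions<d> per instantiation, so the per-element index math carries no extra branching; the commented-out discounted_cumsum_left wrapper suggests a SUM_LEFT specialization is still to be added.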