Commit 1613a3eb authored by anton

refactor sum direction into templates

parent 4c605c1e
@@ -2,19 +2,41 @@
template <typename scalar_t>
__device__ __forceinline__ scalar_t discounted_sum_pow(scalar_t a, scalar_t b, scalar_t gamma, int power) {
__device__ __forceinline__
scalar_t discounted_sum_pow(scalar_t a, scalar_t b, scalar_t gamma, int power) {
return a + b * pow(gamma, scalar_t(power));
}
__inline__
int log2ceil(int x) {
return (int)ceil(log2((float)x));
enum SumDirection {
SUM_RIGHT,
SUM_LEFT
};
template <SumDirection d>
__device__ __forceinline__
void resolve_positions(
const int &gr_prev_stride, const int &gr_cur_stride, const int &gr_of_thread, const int &thread_in_gr,
int &change_pos, int &discounted_pos, int &discount_power
);
template <>
__device__ __forceinline__
void resolve_positions<SUM_RIGHT>(
const int &gr_prev_stride, const int &gr_cur_stride, const int &gr_of_thread, const int &thread_in_gr,
int &change_pos, int &discounted_pos, int &discount_power
) {
change_pos = gr_of_thread * gr_cur_stride + thread_in_gr;
discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride;
discount_power = gr_prev_stride - thread_in_gr;
}
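For comparison, the left-to-right direction would need its own specialization, which this commit leaves unimplemented. The following is only a sketch of what a mirror-image SUM_LEFT specialization could look like, assuming the left sum follows the recurrence y[i] = x[i] + gamma * y[i - 1]; it is not part of this commit.

// Hypothetical SUM_LEFT specialization -- not part of this commit.
// Mirrors SUM_RIGHT: threads update the second half of each group by folding in
// the discounted running sum that ends at the last element of the first half.
template <>
__device__ __forceinline__
void resolve_positions<SUM_LEFT>(
    const int &gr_prev_stride, const int &gr_cur_stride, const int &gr_of_thread, const int &thread_in_gr,
    int &change_pos, int &discounted_pos, int &discount_power
) {
    change_pos = gr_of_thread * gr_cur_stride + gr_prev_stride + thread_in_gr;
    discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride - 1;
    discount_power = thread_in_gr + 1;
}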
template <typename scalar_t>
__global__ void discounted_cumsum_right_kernel_stage(
template <typename scalar_t, SumDirection d>
__global__
void discounted_cumsum_kernel_stage(
torch::PackedTensorAccessor32<scalar_t, 2> x,
const scalar_t gamma,
int stage
@@ -33,9 +55,15 @@ __global__ void discounted_cumsum_right_kernel_stage(
int gr_of_thread = threadidx >> stage;
int thread_in_gr = threadidx - (gr_of_thread << stage);
int change_pos = gr_of_thread * gr_cur_stride + thread_in_gr;
int discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride;
int discount_power = gr_prev_stride - thread_in_gr;
//int change_pos = gr_of_thread * gr_cur_stride + thread_in_gr;
//int discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride;
//int discount_power = gr_prev_stride - thread_in_gr;
int change_pos, discounted_pos, discount_power;
resolve_positions<d>(
gr_prev_stride, gr_cur_stride, gr_of_thread, thread_in_gr,
change_pos, discounted_pos, discount_power
);
if (change_pos >= len || discounted_pos >= len) {
return;
@@ -50,7 +78,14 @@ __global__ void discounted_cumsum_right_kernel_stage(
}
torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
inline
int log2ceil(int x) {
return (int)ceil(log2((float)x));
}
template <SumDirection d>
torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
// Minimum required number of threads, assigns them dynamically to respective positions upon each iteration.
// Results in uncoalesced writes, which is still faster than coalesced writes with half threads idling.
@@ -71,8 +106,8 @@ torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
const dim3 blocks((threads_total_x + threads - 1) / threads, x.size(0));
for (int stage=0; stage<nstages; stage++) {
AT_DISPATCH_FLOATING_TYPES(x.type(), "discounted_cumsum_right_kernel_stage", ([&] {
discounted_cumsum_right_kernel_stage<scalar_t><<<blocks, threads>>>(
AT_DISPATCH_FLOATING_TYPES(x.type(), "discounted_cumsum_kernel_stage", ([&] {
discounted_cumsum_kernel_stage<scalar_t, d><<<blocks, threads>>>(
y.packed_accessor32<scalar_t, 2>(),
scalar_t(gamma),
stage
@@ -82,3 +117,12 @@ torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
return y;
}
torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
return discounted_cumsum<SUM_RIGHT>(x, gamma);
}
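As a sanity check on the staged scheme the kernel uses (doubling group strides per stage, with the discount raised to the distance between positions), here is a small host-only C++ sketch. It is illustrative and not part of the commit; the helper names and the reference recurrence y[i] = x[i] + gamma * y[i + 1] are assumptions about what the right-direction kernel computes.

// Standalone host-side sketch (hypothetical, not part of this commit): checks that
// the staged grouping reproduces the sequential recurrence y[i] = x[i] + gamma * y[i + 1].
#include <cmath>
#include <cstdio>
#include <vector>

// Sequential reference.
std::vector<double> reference_right(std::vector<double> y, double gamma) {
    for (int i = (int)y.size() - 2; i >= 0; i--) {
        y[i] += gamma * y[i + 1];
    }
    return y;
}

// Serial replay of the kernel's stages: one inner-loop iteration per "thread".
// Within a stage, written positions (first half of each group) never overlap read
// positions (start of each second half), so serial order matches the parallel kernel.
std::vector<double> staged_right(std::vector<double> y, double gamma) {
    int len = (int)y.size();
    int nstages = (int)std::ceil(std::log2((float)len));
    for (int stage = 0; stage < nstages; stage++) {
        int gr_prev_stride = 1 << stage;
        int gr_cur_stride = 2 * gr_prev_stride;
        for (int gr_of_thread = 0; gr_of_thread * gr_cur_stride < len; gr_of_thread++) {
            for (int thread_in_gr = 0; thread_in_gr < gr_prev_stride; thread_in_gr++) {
                int change_pos = gr_of_thread * gr_cur_stride + thread_in_gr;
                int discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride;
                int discount_power = gr_prev_stride - thread_in_gr;
                if (change_pos >= len || discounted_pos >= len) continue;
                // Same update as discounted_sum_pow in the kernel.
                y[change_pos] += y[discounted_pos] * std::pow(gamma, (double)discount_power);
            }
        }
    }
    return y;
}

int main() {
    std::vector<double> x = {1.0, 2.0, 3.0, 4.0, 5.0};
    std::vector<double> a = reference_right(x, 0.9);
    std::vector<double> b = staged_right(x, 0.9);
    for (size_t i = 0; i < x.size(); i++) {
        std::printf("%zu: reference=%.6f staged=%.6f\n", i, a[i], b[i]);
    }
    return 0;
}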
//torch::Tensor discounted_cumsum_left(torch::Tensor x, double gamma) {
//}
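If the SUM_LEFT specialization sketched above were filled in, the left-direction entry point could reuse the templated host function directly. Again, this is only a hedged sketch of where the refactor appears to be headed, not code from this commit.

// Hypothetical wrapper -- not part of this commit; assumes a SUM_LEFT
// specialization of resolve_positions exists.
torch::Tensor discounted_cumsum_left(torch::Tensor x, double gamma) {
    return discounted_cumsum<SUM_LEFT>(x, gamma);
}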