// discounted_cumsum_kernel.cu
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <algorithm>


template <typename scalar_t>
__device__ __forceinline__ scalar_t discounted_sum_pow(scalar_t a, scalar_t b, scalar_t gamma, int power) {
    // Returns a + b * gamma^power, i.e. `a` plus `b` discounted by `power` steps.
    const scalar_t discount = pow(gamma, static_cast<scalar_t>(power));
    return a + b * discount;
}


// Returns ceil(log2(x)) for x >= 1, i.e. the smallest k with 2^k >= x; returns 0 for x <= 1.
// Uses exact integer arithmetic: the former float-based log2 lost precision once x exceeded
// 2^24 (float has a 24-bit mantissa), e.g. x = 2^24 + 1 was rounded to 2^24 and gave 24
// instead of 25, and x <= 0 converted -inf/NaN to int (undefined behavior).
__inline__
int log2ceil(int x) {
    int k = 0;
    // 64-bit shift so the comparison cannot overflow for x close to INT_MAX.
    while ((1LL << k) < static_cast<long long>(x)) {
        ++k;
    }
    return k;
}


// One stage of an in-place, right-to-left discounted scan over each row of `x`.
//
// At stage `stage` a row is split into consecutive groups of 2^(stage+1) elements. Every
// element in the left half of a group is incremented by gamma^d times the first element of
// the group's right half, where d is the distance between the two positions. Running stages
// 0, 1, ..., ceil(log2(len)) - 1 in order leaves
//     x[r][i] = sum_{j >= i} gamma^(j - i) * x_original[r][j].
//
// Launch layout: the y thread coordinate selects the row (rows >= x.size(0) exit). The x thread
// coordinate selects which left-half element to update. Covering a row of length len needs at
// least 2^(ceil(log2(len)) - 1) threads along x; threads whose target lies outside the row exit
// without writing. Within a stage, reads (right-half heads) and writes (left halves) touch
// disjoint positions, so no intra-kernel synchronization is needed. The kernel has no
// cross-stage barrier, so stages must be ordered by separate launches.
template <typename scalar_t>
__global__ void discounted_cumsum_right_kernel_stage(
        torch::PackedTensorAccessor32<scalar_t, 2> x,
        const scalar_t gamma,
        int stage
) {
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    if (row >= x.size(0)) {
        return;
    }

    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    const int half = 1 << stage;                 // size of each half-group at this stage
    const int group = tid >> stage;              // which group this thread serves
    const int offset = tid & (half - 1);         // position inside the group's left half
    const int group_start = group * (half << 1);

    const int dst = group_start + offset;        // element being updated
    const int src = group_start + half;          // head of the right half (already scanned)

    const int row_len = x.size(1);
    if (dst >= row_len || src >= row_len) {
        return;
    }

    x[row][dst] = discounted_sum_pow(x[row][dst], x[row][src], gamma, half - offset);
}


torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
    // Minimum required number of threads, assigns them dynamically to respective positions upon each iteration.
    // Results in uncoalesced writes, which is still faster than coalesced writes with half threads idling.
    //
    // Returns a new tensor y where, for every row r,
    //     y[r][i] = sum_{j >= i} gamma^(j - i) * x[r][j],
    // computed by launching discounted_cumsum_right_kernel_stage once per doubling stage.
    //
    // x:     2-D, contiguous, floating-point (float/double per the dispatch below) CUDA tensor.
    // gamma: discount factor in [0, 1].
    // Errors (invalid input, failed kernel launch) are reported via TORCH_CHECK.

    TORCH_CHECK(x.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(x.is_contiguous(), "Input must be contiguous");
    TORCH_CHECK(x.dim() == 2, "Input must be 2-dimensional");
    TORCH_CHECK(0.0 <= gamma && gamma <= 1.0, "Gamma must be in the range [0,1]");

    // Nothing to compute. Return a fresh tensor so the result never aliases the input, and avoid
    // launching a grid with a zero dimension (an invalid launch configuration).
    if (x.size(0) == 0 || x.size(1) == 0) {
        return x.clone();
    }

    // Launch on the input's device and on PyTorch's current stream, so the kernels are ordered
    // with surrounding PyTorch work instead of racing on the legacy default stream.
    const c10::cuda::CUDAGuard device_guard(x.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    auto y = x.clone();

    const int nstages = log2ceil(x.size(1));
    if (nstages == 0) {
        // Single column: each element already is its own discounted suffix sum.
        // (Also avoids the negative shift `1 << -1` below.)
        return y;
    }

    const int threads = 64;
    const int threads_total_x = 1 << (nstages - 1);
    const int blocks_x = (threads_total_x + threads - 1) / threads;

    // gridDim.y is limited to 65535, so process rows in chunks of at most that many per launch.
    const int64_t max_rows_per_launch = 65535;
    const int64_t nrows = y.size(0);

    for (int64_t row_start = 0; row_start < nrows; row_start += max_rows_per_launch) {
        const int64_t rows = std::min(max_rows_per_launch, nrows - row_start);
        // A view into y: the accessor carries its strides, so writes land in y.
        auto y_chunk = y.narrow(0, row_start, rows);
        const dim3 blocks(blocks_x, static_cast<unsigned int>(rows));

        // Stages must run in order; launches on the same stream are serialized.
        for (int stage = 0; stage < nstages; stage++) {
            AT_DISPATCH_FLOATING_TYPES(y.scalar_type(), "discounted_cumsum_right_kernel_stage", ([&] {
                discounted_cumsum_right_kernel_stage<scalar_t><<<blocks, threads, 0, stream>>>(
                    y_chunk.packed_accessor32<scalar_t, 2>(),
                    scalar_t(gamma),
                    stage
                );
            }));
            // Kernel launches do not return errors directly; surface bad configurations here.
            const cudaError_t launch_err = cudaGetLastError();
            TORCH_CHECK(launch_err == cudaSuccess,
                        "discounted_cumsum_right kernel launch failed: ", cudaGetErrorString(launch_err));
        }
    }

    return y;
}