discounted_cumsum_kernel.cu 3.66 KB
Newer Older
1
2
3
4
#include <torch/extension.h>


/**
 * Returns `a` plus `b` scaled by `gamma` raised to `power`.
 *
 * @param a      value being accumulated into
 * @param b      value being added after discounting
 * @param gamma  discount base
 * @param power  integer exponent applied to gamma (converted to scalar_t before calling pow)
 */
template <typename scalar_t>
__device__ __forceinline__
scalar_t discounted_sum_pow(scalar_t a, scalar_t b, scalar_t gamma, int power) {
    // pow is called with both arguments in scalar_t.
    const scalar_t exponent = static_cast<scalar_t>(power);
    return a + b * pow(gamma, exponent);
}


anton's avatar
anton committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
// Compile-time selector for the direction in which the discounted sum is accumulated.
enum SumDirection {
    // Each position accumulates a discounted value taken from a higher index
    // (see the resolve_positions<SUM_RIGHT> specialization).
    SUM_RIGHT = 0,
    // Counterpart direction; NOTE(review): no position-resolution code for it
    // is visible next to this enum - confirm before using.
    SUM_LEFT = 1
};


/**
 * Maps a thread's group coordinates to the element indices it works on.
 * Only declared here; each supported direction supplies its own specialization.
 *
 * Inputs:
 *   gr_prev_stride  - group stride of the previous stage
 *   gr_cur_stride   - group stride of the current stage
 *   gr_of_thread    - index of the group this thread belongs to
 *   thread_in_gr    - index of the thread inside its group
 * Outputs:
 *   change_pos      - index of the element to update
 *   discounted_pos  - index of the element whose value is discounted and added
 *   discount_power  - exponent applied to the discount factor
 */
template <SumDirection d>
__device__ __forceinline__
void resolve_positions(
        const int &gr_prev_stride,
        const int &gr_cur_stride,
        const int &gr_of_thread,
        const int &thread_in_gr,
        int &change_pos,
        int &discounted_pos,
        int &discount_power
);


/**
 * SUM_RIGHT position mapping.
 *
 * Within a group starting at gr_of_thread * gr_cur_stride, the thread updates
 * the element at offset thread_in_gr and reads the element at offset
 * gr_prev_stride; the discount exponent is the distance between the two.
 */
template <>
__device__ __forceinline__
void resolve_positions<SUM_RIGHT>(
        const int &gr_prev_stride, const int &gr_cur_stride, const int &gr_of_thread, const int &thread_in_gr,
        int &change_pos, int &discounted_pos, int &discount_power
) {
    // First element index of the group this thread belongs to.
    const int group_base = gr_of_thread * gr_cur_stride;

    change_pos = group_base + thread_in_gr;
    discounted_pos = group_base + gr_prev_stride;
    discount_power = gr_prev_stride - thread_in_gr;
}


anton's avatar
anton committed
37
38
39
/**
 * Performs one stage of the in-place discounted cumulative sum on a 2-D accessor.
 *
 * Launch layout: the y grid/block dimension indexes rows (x.size(0)); the x
 * dimension indexes worker threads within a row. In stage `stage`, workers are
 * split into groups of (1 << stage) threads; each worker updates one element
 * and reads one other element, both chosen by resolve_positions<d>.
 *
 * @param x      2-D accessor updated in place
 * @param gamma  discount factor
 * @param stage  zero-based stage index
 */
template <typename scalar_t, SumDirection d>
__global__
void discounted_cumsum_kernel_stage(
        torch::PackedTensorAccessor32<scalar_t, 2> x,
        const scalar_t gamma,
        int stage
) {
    // Row handled by this thread; surplus rows of the grid do nothing.
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    if (row >= x.size(0)) {
        return;
    }

    const int worker = blockIdx.x * blockDim.x + threadIdx.x;
    const int row_len = x.size(1);

    // Group geometry for this stage.
    const int prev_stride = 1 << stage;
    const int cur_stride = prev_stride << 1;
    const int group = worker >> stage;
    const int lane = worker & (prev_stride - 1);

    int dst, src, exponent;
    resolve_positions<d>(prev_stride, cur_stride, group, lane, dst, src, exponent);

    // Only act when both the updated and the read element lie inside the row.
    if (dst < row_len && src < row_len) {
        x[row][dst] = discounted_sum_pow(x[row][dst], x[row][src], gamma, exponent);
    }
}


anton's avatar
anton committed
81
82
83
84
85
86
87
88
/**
 * Returns the smallest non-negative n such that (1 << n) >= x, i.e. ceil(log2(x)).
 *
 * Computed with integer arithmetic only: routing the argument through `float`
 * rounds values above 2^24 and can under-count by one (e.g. 2^24 + 1).
 *
 * @param x  value to take the logarithm of; any x <= 1 yields 0
 * @return   ceil(log2(x)) for x >= 1, otherwise 0
 */
inline
int log2ceil(int x) {
    int exponent = 0;
    // 64-bit accumulator so the doubling cannot overflow for x close to INT_MAX.
    for (long long reach = 1; reach < x; reach <<= 1) {
        ++exponent;
    }
    return exponent;
}


/**
 * Computes a discounted cumulative sum along dimension 1 of a 2-D CUDA tensor.
 *
 * The input is cloned and the clone is updated in place by log2ceil(x.size(1))
 * launches of discounted_cumsum_kernel_stage, one per stage. For d == SUM_RIGHT
 * the result is y[i][t] = sum over k >= t of gamma^(k - t) * x[i][k].
 *
 * Kernels are launched without an explicit stream argument.
 *
 * @param x      contiguous 2-D CUDA tensor of a floating type; not modified
 * @param gamma  discount factor, must lie in [0, 1]
 * @return       a new tensor shaped like x (x itself when x.size(1) == 0)
 * @throws c10::Error if an input check fails or a kernel launch is rejected
 *         (e.g. x.size(0) exceeds the device's grid y-dimension limit)
 */
template <SumDirection d>
torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
    // Minimum required number of threads, assigns them dynamically to respective positions upon each iteration.
    // Results in uncoalesced writes, which is still faster than coalesced writes with half threads idling.

    TORCH_CHECK(x.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(x.is_contiguous(), "Input must be contiguous");
    TORCH_CHECK(x.dim() == 2, "Input must be 2-dimensional");
    TORCH_CHECK(0.0 <= gamma && gamma <= 1.0, "Gamma must be in the range [0,1]");

    if (x.size(1) == 0) {
        return x;
    }

    auto y = x.clone();

    // No rows, or rows of a single element: there is nothing to combine, and
    // launching would mean a zero-sized grid or a negative shift below.
    if (x.size(0) == 0 || x.size(1) == 1) {
        return y;
    }

    const int threads = 64;
    const int nstages = log2ceil(x.size(1));
    // Rows have at least two elements here, so nstages >= 1 and the shift is valid.
    const int threads_total_x = 1 << (nstages - 1);
    const dim3 blocks((threads_total_x + threads - 1) / threads, x.size(0));

    // Dispatch on dtype once; the accessor and typed gamma are the same for every stage.
    AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "discounted_cumsum_kernel_stage", ([&] {
        auto y_accessor = y.packed_accessor32<scalar_t, 2>();
        const scalar_t gamma_typed = scalar_t(gamma);

        for (int stage = 0; stage < nstages; stage++) {
            discounted_cumsum_kernel_stage<scalar_t, d><<<blocks, threads>>>(
                y_accessor,
                gamma_typed,
                stage
            );
            // A rejected launch would otherwise leave y silently unprocessed.
            const cudaError_t launch_status = cudaGetLastError();
            TORCH_CHECK(
                launch_status == cudaSuccess,
                "discounted_cumsum_kernel_stage launch failed: ",
                cudaGetErrorString(launch_status)
            );
        }
    }));

    return y;
}
anton's avatar
anton committed
120
121
122
123
124
125
126
127
128


/**
 * Entry point for the SUM_RIGHT variant: forwards to discounted_cumsum<SUM_RIGHT>
 * and returns its result unchanged.
 *
 * @param x      input tensor, passed through as-is
 * @param gamma  discount factor, passed through as-is
 */
torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
    constexpr SumDirection direction = SUM_RIGHT;
    return discounted_cumsum<direction>(x, gamma);
}


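// NOTE: the left-direction wrapper below is an unfinished stub. Enabling it would also
// require a resolve_positions<SUM_LEFT> specialization, which is not defined in the code shown here.
//torch::Tensor discounted_cumsum_left(torch::Tensor x, double gamma) {
//}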