// discounted_cumsum_kernel.cu
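//
// Discounted cumulative sum along dim 1 of a 2D CUDA tensor, computed as a parallel inclusive
// scan over ceil(log2(len)) stages. At stage s, each row is split into groups of 2^(s+1)
// elements; every element in one half of a group absorbs the discounted running sum from the
// boundary element of the other half. Stages run as separate kernel launches, so stream
// ordering provides the synchronization between them.
//
//   LEFT:  y[i] = sum_{j <= i} gamma^(i - j) * x[j]
//   RIGHT: y[i] = sum_{j >= i} gamma^(j - i) * x[j]
//
// Example: gamma = 0.9 and a row [1, 1, 1, 1] give a LEFT result of [1, 1.9, 2.71, 3.439].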
#include <torch/extension.h>


// Scan update: add b, discounted by gamma^power, to the partial sum a.
template <typename scalar_t>
__device__ __forceinline__
scalar_t discounted_sum_power(scalar_t a, scalar_t b, scalar_t gamma, int power) {
    return a + b * pow(gamma, scalar_t(power));
}


// Direction of accumulation: LEFT sums discounted past elements into each position,
// RIGHT sums discounted future elements.
enum SumDirection {
    SUM_DIRECTION_LEFT,
    SUM_DIRECTION_RIGHT,
};


// For a given stage, map a thread to the element it updates (change_pos), the element whose
// discounted value it adds (discounted_pos), and the power of gamma to apply (discount_power).
template <SumDirection sum_direction>
__device__ __forceinline__
void resolve_positions(
    const int &stride_prev_group, const int &stride_cur_group, const int &group_of_thread, const int &thread_in_group,
    int &change_pos, int &discounted_pos, int &discount_power
);


// LEFT: threads cover the second half of each group and read the discounted running sum
// held in the last element of the first half.
template <>
__device__ __forceinline__
void resolve_positions<SUM_DIRECTION_LEFT>(
    const int &stride_prev_group, const int &stride_cur_group, const int &group_of_thread, const int &thread_in_group,
    int &change_pos, int &discounted_pos, int &discount_power
) {
    change_pos = group_of_thread * stride_cur_group + thread_in_group + stride_prev_group;
    discounted_pos = group_of_thread * stride_cur_group + stride_prev_group - 1;
    discount_power = thread_in_group + 1;
}
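// Worked example for the LEFT specialization above: at stage 0, thread t performs
// x[2t + 1] += gamma * x[2t]; at stage 1, threads 2g and 2g + 1 perform
// x[4g + 2] += gamma * x[4g + 1] and x[4g + 3] += gamma^2 * x[4g + 1] respectively.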


// RIGHT: threads cover the first half of each group and read the discounted running sum
// held in the first element of the second half.
template <>
__device__ __forceinline__
void resolve_positions<SUM_DIRECTION_RIGHT>(
    const int &stride_prev_group, const int &stride_cur_group, const int &group_of_thread, const int &thread_in_group,
    int &change_pos, int &discounted_pos, int &discount_power
) {
    change_pos = group_of_thread * stride_cur_group + thread_in_group;
    discounted_pos = group_of_thread * stride_cur_group + stride_prev_group;
    discount_power = stride_prev_group - thread_in_group;
}


// Applies one scan stage to every row of x: thread_idy selects the row, thread_idx selects
// (via resolve_positions) which element of that row to update.
template <typename scalar_t, SumDirection sum_direction>
__global__
void discounted_cumsum_kernel_stage(
    torch::PackedTensorAccessor32<scalar_t, 2> x,
    const scalar_t gamma,
    int stage
) {
    const int len = x.size(1);
    const int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int thread_idy = blockIdx.y * blockDim.y + threadIdx.y;

    if (thread_idy >= x.size(0)) {
        return;
    }

    int stride_prev_group = 1 << stage;             // group size of the previous stage
    int stride_cur_group = stride_prev_group << 1;  // group size of the current stage

    int group_of_thread = thread_idx >> stage;                      // index of this thread's group
    int thread_in_group = thread_idx - (group_of_thread << stage);  // position of the thread within its group

    int change_pos, discounted_pos, discount_power;
    resolve_positions<sum_direction>(
        stride_prev_group, stride_cur_group, group_of_thread, thread_in_group,
        change_pos, discounted_pos, discount_power
    );

    // Rows behave as if padded to a power of two; threads mapped past the real end do nothing.
    if (change_pos >= len || discounted_pos >= len) {
        return;
    }

    x[thread_idy][change_pos] = discounted_sum_power(
        x[thread_idy][change_pos],
        x[thread_idy][discounted_pos],
        gamma,
        discount_power
    );
}


// Smallest n such that 2^n >= x; integer arithmetic avoids float rounding for very large x.
inline
int log2ceil(int x) {
    int n = 0;
    while ((1LL << n) < x) n++;
    return n;
}


template <SumDirection sum_direction>
torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
    // Launch only the minimum number of threads required and reassign them to new positions at
    // each stage. This causes uncoalesced writes, which is still faster than coalesced writes
    // with half of the threads idling.

    TORCH_CHECK(x.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(x.is_contiguous(), "Input must be contiguous");
    TORCH_CHECK(x.dim() == 2, "Input must be 2-dimensional");
    TORCH_CHECK(0.0 <= gamma && gamma <= 1.0, "Gamma must be in the range [0,1]");

    if (x.size(1) <= 1) {
        // Nothing to accumulate for rows of length 0 or 1; returning early also avoids the
        // invalid shift 1 << (nstages - 1) when nstages would be 0.
        return x.clone();
    }

    auto y = x.clone();

    const int threads = 64;  // threads per block along the sequence dimension
    const int nstages = log2ceil(x.size(1));         // number of scan stages
    const int threads_total_x = 1 << (nstages - 1);  // half of the power-of-two-padded row length
    const dim3 blocks((threads_total_x + threads - 1) / threads, x.size(0));
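    // Illustration: a row length of 1000 gives nstages = 10, threads_total_x = 512 and
    // blocks.x = (512 + 63) / 64 = 8, so 8 * 64 = 512 threads are launched per row at each stage.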

    for (int stage=0; stage<nstages; stage++) {
        AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "discounted_cumsum_kernel_stage", ([&] {
            discounted_cumsum_kernel_stage<scalar_t, sum_direction><<<blocks, threads>>>(
                y.packed_accessor32<scalar_t, 2>(),
                scalar_t(gamma),
                stage
            );
        }));
    }

    return y;
}


torch::Tensor discounted_cumsum_left(torch::Tensor x, double gamma) {
    return discounted_cumsum<SUM_DIRECTION_LEFT>(x, gamma);
}


torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
    return discounted_cumsum<SUM_DIRECTION_RIGHT>(x, gamma);
}
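

// Note: no Python bindings appear in this file; they presumably live in a separate translation
// unit. A minimal sketch, assuming the bindings were to go here (docstrings are illustrative):
//
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("discounted_cumsum_left", &discounted_cumsum_left,
//           "Discounted cumulative sum accumulating from the left (CUDA)");
//     m.def("discounted_cumsum_right", &discounted_cumsum_right,
//           "Discounted cumulative sum accumulating from the right (CUDA)");
// }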