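// layernorm.cu
// Fused LayerNorm forward and input-gradient CUDA kernels for bf16 / fp16 /
// fp32, one warp per row. Supported hidden sizes are the ones enumerated in
// the dispatch switches at the bottom of this file.
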
#include <iostream>
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include "util.h"

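// Compile-time tiling parameters. A block processes BatchesPerBlock rows, one
// warp per row (every instantiation below uses WarpsForOneBatchPerBlock = 1).
// Each thread loads VecSize contiguous elements per iteration, so a warp
// covers WarpStride = 32 * VecSize elements per iteration and a full row of
// Dim elements in Iterations steps.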
template <int Dim_, int VecSize_, int BatchesPerBlock_, int WarpsForOneBatchPerBlock_>
struct LNParameters {
    static constexpr int Dim = Dim_;
    static constexpr int VecSize = VecSize_;
    static constexpr int WarpSize = 32;
    static constexpr int BatchesPerBlock = BatchesPerBlock_;
    static constexpr int WarpStride = WarpSize * VecSize;
    static constexpr int WarpsForOneBatchPerBlock = WarpsForOneBatchPerBlock_;
    static constexpr int Iterations = Dim / WarpStride / WarpsForOneBatchPerBlock;
    static constexpr int BatchStride = Dim / WarpsForOneBatchPerBlock;
    static constexpr int ThreadsPerBlock = BatchesPerBlock * WarpSize * WarpsForOneBatchPerBlock;
    static_assert(Dim == WarpsForOneBatchPerBlock * WarpStride * Iterations, "");
    static_assert(Dim == BatchStride * WarpsForOneBatchPerBlock, "");
};

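// Forward LayerNorm: y = (x - mean) * rsqrt(var + eps) * gamma + beta.
// One warp owns one row: mean and variance are computed in two passes over
// registers, each reduced across the warp with butterfly (XOR) shuffles, and
// the per-row mean and reciprocal standard deviation are stored for reuse by
// the backward pass.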
template <typename IndexType, typename input_t, typename output_t, typename acc_t, typename Parameters>
__global__ void layernorm_forward(output_t *dst, const input_t *src, const input_t *gamma, const input_t *beta,
    acc_t *mean, acc_t *invvar, IndexType bsz, acc_t epsilon) {
    static_assert(Parameters::WarpsForOneBatchPerBlock == 1, "");
    IndexType batch = blockIdx.x * Parameters::BatchesPerBlock + threadIdx.y;
    if (batch < bsz) {
        src += batch * Parameters::Dim + threadIdx.x * Parameters::VecSize;
        dst += batch * Parameters::Dim + threadIdx.x * Parameters::VecSize;
        gamma += threadIdx.x * Parameters::VecSize;
        beta += threadIdx.x * Parameters::VecSize;
        using VecInType = VecType<input_t, Parameters::VecSize>;
        VecInType elements[Parameters::Iterations];
        VecInType gamma_reg[Parameters::Iterations];
        VecInType beta_reg[Parameters::Iterations];
        // Vectorized loads: each thread pulls VecSize contiguous elements per iteration.
        #pragma unroll
        for (int i = 0; i < Parameters::Iterations; ++i) {
            elements[i] = *(const VecInType *)(src + i * Parameters::WarpStride);
            gamma_reg[i] = *(const VecInType *)(gamma + i * Parameters::WarpStride);
            beta_reg[i] = *(const VecInType *)(beta + i * Parameters::WarpStride);
        }
        input_t *elements_l = (input_t *)elements;
        input_t *gamma_l = (input_t *)gamma_reg;
        input_t *beta_l = (input_t *)beta_reg;
        
        // Pass 1: row mean from per-thread partial sums, reduced across the
        // warp with a butterfly (XOR) shuffle so every lane holds the total.
        acc_t sum = 0.0;
        #pragma unroll
        for (int i = 0; i < Parameters::Iterations * Parameters::VecSize; ++i) {
            sum += (acc_t)elements_l[i];
        }
        #pragma unroll
        for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2) {
            sum += SHFL_XOR(sum, offset, Parameters::WarpSize);
        }
        
        acc_t mu = sum / Parameters::Dim;
        // Pass 2: variance around the mean, reduced the same way.
        acc_t var = 0.0;
        #pragma unroll
        for (int i = 0; i < Parameters::Iterations * Parameters::VecSize; ++i) {
            acc_t diff = (acc_t)elements_l[i] - mu;
            var += diff * diff;
        }
        #pragma unroll
        for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2) {
            var += SHFL_XOR(var, offset, Parameters::WarpSize);
        }
        const acc_t rsigma = rsqrtf(var / Parameters::Dim + epsilon);  // 1 / sqrt(var + eps)
        
        // Normalize and apply the affine transform; gamma and beta are applied
        // in input precision, matching the stored output type.
        #pragma unroll
        for (int i = 0; i < Parameters::Iterations * Parameters::VecSize; ++i) {
            elements_l[i] = (input_t)(((acc_t)elements_l[i] - mu) * rsigma) * gamma_l[i] + beta_l[i];
        }
        
        #pragma unroll
        for (int i = 0; i < Parameters::Iterations; ++i) {
            *(VecInType *)(dst + i * Parameters::WarpStride) = elements[i];
        }
        
        if (threadIdx.x == 0) {  // lane 0 stores the per-row statistics for the backward pass
            mean[batch] = mu;
            invvar[batch] = rsigma;
        }
    }
}

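// Backward LayerNorm w.r.t. the input. With r = rsqrt(var + eps) and D = Dim:
//   dx = r * dy * gamma
//      - (r^3 / D) * sum((x - mu) * dy * gamma) * (x - mu)
//      - (r / D)   * sum(dy * gamma)
// using the mean and invvar saved by the forward kernel. Gradients for gamma
// and beta are not computed here.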
template <typename IndexType, typename input_t, typename output_t, typename acc_t, typename Parameters>
__global__ void layernorm_backward_x(output_t *dst, const input_t *input, const input_t *grad, const input_t *gamma,
    const acc_t *mean, const acc_t *invvar, IndexType bsz) {
    IndexType batch = blockIdx.x * Parameters::BatchesPerBlock + threadIdx.y;
    if (batch < bsz) {
        input += batch * Parameters::Dim + threadIdx.x * Parameters::VecSize;
        dst += batch * Parameters::Dim + threadIdx.x * Parameters::VecSize;
        grad += batch * Parameters::Dim + threadIdx.x * Parameters::VecSize;
        gamma += threadIdx.x * Parameters::VecSize;
        using VecInType = VecType<input_t, Parameters::VecSize>;
        VecInType elements[Parameters::Iterations];
        VecInType grad_reg[Parameters::Iterations];
        VecInType gamma_reg[Parameters::Iterations];
        // Vectorized loads of the row, its incoming gradient, and gamma.
        #pragma unroll
        for (int i = 0; i < Parameters::Iterations; ++i) {
            elements[i] = *(const VecInType *)(input + i * Parameters::WarpStride);
            grad_reg[i] = *(const VecInType *)(grad + i * Parameters::WarpStride);
            gamma_reg[i] = *(const VecInType *)(gamma + i * Parameters::WarpStride);
        }
        input_t *elements_l = (input_t *)elements;
        input_t *grad_l = (input_t *)grad_reg;
        input_t *gamma_l = (input_t *)gamma_reg;
        const acc_t mu = mean[batch];
        const acc_t rsigma = invvar[batch];  // reciprocal standard deviation saved by the forward pass
        
        // Per-row reductions: sum1 = sum((x - mu) * dy * gamma), sum2 = sum(dy * gamma).
        acc_t sum1 = 0.0, sum2 = 0.0;
        #pragma unroll
        for (int i = 0; i < Parameters::Iterations * Parameters::VecSize; ++i) {
            elements_l[i] = elements_l[i] - (input_t)mu;
            sum1 += (acc_t)(elements_l[i] * grad_l[i] * gamma_l[i]);
            sum2 += (acc_t)(grad_l[i] * gamma_l[i]);
        }
        
        #pragma unroll
        for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2) {
            sum1 += SHFL_XOR(sum1, offset, Parameters::WarpSize);
        }
        
        #pragma unroll
        for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2) {
            sum2 += SHFL_XOR(sum2, offset, Parameters::WarpSize);
        }
        
        sum1 *= rsigma * rsigma * rsigma / Parameters::Dim;
        sum2 *= rsigma / Parameters::Dim;
        
        #pragma unroll
        for (int i = 0; i < Parameters::Iterations * Parameters::VecSize; ++i) {
            elements_l[i] = grad_l[i] * gamma_l[i] * (input_t)rsigma - (input_t)sum1 * elements_l[i] - (input_t)sum2;
        }
        
        #pragma unroll
        for (int i = 0; i < Parameters::Iterations; ++i) {
            *(VecInType *)(dst + i * Parameters::WarpStride) = elements[i];
        }
    }
}

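// Launch helpers. A block of (32, batches) threads covers `batches` rows, so
// the grid needs DIV_CELL(n1, batches) = ceil(n1 / batches) blocks. The
// trailing `break` exits the switch statement at the expansion site.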
#define LAUNCH_FORWARD_KERNEL(len, vec, batches, type) \
{ \
    dim3 threads(32, batches); \
    int blocks = DIV_CELL(n1, batches); \
    layernorm_forward<size_t, type, type, float, LNParameters<len, vec, batches, 1>> \
    <<<blocks, threads, 0, stream>>> \
    ((type *)output->data_ptr(), (type *)input->data_ptr(), (type *)gamma->data_ptr(), \
        (type *)beta->data_ptr(), (float *)mean->data_ptr(), (float *)invvar->data_ptr(), n1, epsilon); \
    break; \
}

#define LAUNCH_BACKWARD_KERNEL(len, vec, batches, type) \
{ \
    dim3 threads(32, batches); \
    int blocks = DIV_CELL(n1, batches); \
    layernorm_backward_x<size_t, type, type, float, LNParameters<len, vec, batches, 1>> \
    <<<blocks, threads, 0, stream>>> \
    ((type *)grad_input->data_ptr(), (type *)input->data_ptr(), (type *)dout->data_ptr(), \
        (type *)gamma->data_ptr(), (float *)mean->data_ptr(), (float *)invvar->data_ptr(), n1); \
    break; \
}

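// Host entry point for the forward pass: dispatches on dtype and hidden size
// n2. A dtype or n2 without a matching case below falls through silently and
// launches nothing.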
void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    at::IntArrayRef normalized_shape,
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon)
{
    using namespace at;
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    auto type = input->scalar_type();
    
    if (type == at::ScalarType::BFloat16) {
        switch (n2) {
        case 64: LAUNCH_FORWARD_KERNEL(64, 2, 4, nv_bfloat16)
        case 128: LAUNCH_FORWARD_KERNEL(128, 2, 4, nv_bfloat16)
        case 192: LAUNCH_FORWARD_KERNEL(192, 2, 4, nv_bfloat16)
        case 256: LAUNCH_FORWARD_KERNEL(256, 2, 4, nv_bfloat16)
        case 320: LAUNCH_FORWARD_KERNEL(320, 2, 4, nv_bfloat16)
        case 384: LAUNCH_FORWARD_KERNEL(384, 2, 4, nv_bfloat16)
        case 512: LAUNCH_FORWARD_KERNEL(512, 2, 4, nv_bfloat16)
        case 640: LAUNCH_FORWARD_KERNEL(640, 2, 4, nv_bfloat16)
        case 768: LAUNCH_FORWARD_KERNEL(768, 2, 4, nv_bfloat16)
        case 1024: LAUNCH_FORWARD_KERNEL(1024, 2, 4, nv_bfloat16)
        case 1280: LAUNCH_FORWARD_KERNEL(1280, 2, 4, nv_bfloat16)
        case 1536: LAUNCH_FORWARD_KERNEL(1536, 2, 4, nv_bfloat16)
        case 1792: LAUNCH_FORWARD_KERNEL(1792, 2, 4, nv_bfloat16)
        case 2048: LAUNCH_FORWARD_KERNEL(2048, 2, 4, nv_bfloat16)
        case 2560: LAUNCH_FORWARD_KERNEL(2560, 2, 4, nv_bfloat16)
        case 5120: LAUNCH_FORWARD_KERNEL(5120, 2, 4, nv_bfloat16)
        }
    } else if (type == at::ScalarType::Half) {
        switch (n2) {
        case 64: LAUNCH_FORWARD_KERNEL(64, 2, 4, half)
        case 128: LAUNCH_FORWARD_KERNEL(128, 2, 4, half)
        case 192: LAUNCH_FORWARD_KERNEL(192, 2, 4, half)
        case 256: LAUNCH_FORWARD_KERNEL(256, 2, 4, half)
        case 320: LAUNCH_FORWARD_KERNEL(320, 2, 4, half)
        case 384: LAUNCH_FORWARD_KERNEL(384, 2, 4, half)
        case 512: LAUNCH_FORWARD_KERNEL(512, 2, 4, half)
        case 640: LAUNCH_FORWARD_KERNEL(640, 2, 4, half)
        case 768: LAUNCH_FORWARD_KERNEL(768, 2, 4, half)
        case 1024: LAUNCH_FORWARD_KERNEL(1024, 2, 4, half)
        case 1280: LAUNCH_FORWARD_KERNEL(1280, 2, 4, half)
        case 1536: LAUNCH_FORWARD_KERNEL(1536, 2, 4, half)
        case 1792: LAUNCH_FORWARD_KERNEL(1792, 2, 4, half)
        case 2048: LAUNCH_FORWARD_KERNEL(2048, 2, 4, half)
        case 2560: LAUNCH_FORWARD_KERNEL(2560, 2, 4, half)
        case 5120: LAUNCH_FORWARD_KERNEL(5120, 2, 4, half)
        }
    } else if (type == at::ScalarType::Float) {
        switch (n2) {
        case 64: LAUNCH_FORWARD_KERNEL(64, 1, 4, float)
        case 128: LAUNCH_FORWARD_KERNEL(128, 1, 4, float)
        case 192: LAUNCH_FORWARD_KERNEL(192, 1, 4, float)
        case 256: LAUNCH_FORWARD_KERNEL(256, 1, 4, float)
        case 320: LAUNCH_FORWARD_KERNEL(320, 1, 4, float)
        case 384: LAUNCH_FORWARD_KERNEL(384, 1, 4, float)
        case 512: LAUNCH_FORWARD_KERNEL(512, 1, 4, float)
        case 640: LAUNCH_FORWARD_KERNEL(640, 1, 4, float)
        case 768: LAUNCH_FORWARD_KERNEL(768, 1, 4, float)
        case 1024: LAUNCH_FORWARD_KERNEL(1024, 1, 4, float)
        case 1280: LAUNCH_FORWARD_KERNEL(1280, 1, 4, float)
        case 1536: LAUNCH_FORWARD_KERNEL(1536, 1, 4, float)
        case 1792: LAUNCH_FORWARD_KERNEL(1792, 1, 4, float)
        case 2048: LAUNCH_FORWARD_KERNEL(2048, 1, 4, float)
        case 2560: LAUNCH_FORWARD_KERNEL(2560, 1, 4, float)
        case 5120: LAUNCH_FORWARD_KERNEL(5120, 1, 4, float)
        }
    }
}

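// Host entry point for the input gradient; mirrors the forward dispatch,
// including the silent fall-through for unsupported dtypes and sizes.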
void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    at::IntArrayRef normalized_shape,
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
    at::Tensor* grad_input)
{   
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    auto type = input->scalar_type();
    
    if (type == at::ScalarType::BFloat16) {
        switch (n2) {
        case 64: LAUNCH_BACKWARD_KERNEL(64, 2, 4, nv_bfloat16)
        case 128: LAUNCH_BACKWARD_KERNEL(128, 2, 4, nv_bfloat16)
        case 192: LAUNCH_BACKWARD_KERNEL(192, 2, 4, nv_bfloat16)
        case 256: LAUNCH_BACKWARD_KERNEL(256, 2, 4, nv_bfloat16)
        case 320: LAUNCH_BACKWARD_KERNEL(320, 2, 4, nv_bfloat16)
        case 384: LAUNCH_BACKWARD_KERNEL(384, 2, 4, nv_bfloat16)
        case 512: LAUNCH_BACKWARD_KERNEL(512, 2, 4, nv_bfloat16)
        case 640: LAUNCH_BACKWARD_KERNEL(640, 2, 4, nv_bfloat16)
        case 768: LAUNCH_BACKWARD_KERNEL(768, 2, 4, nv_bfloat16)
        case 1024: LAUNCH_BACKWARD_KERNEL(1024, 2, 4, nv_bfloat16)
        case 1280: LAUNCH_BACKWARD_KERNEL(1280, 2, 4, nv_bfloat16)
        case 1536: LAUNCH_BACKWARD_KERNEL(1536, 2, 4, nv_bfloat16)
        case 1792: LAUNCH_BACKWARD_KERNEL(1792, 2, 4, nv_bfloat16)
        case 2048: LAUNCH_BACKWARD_KERNEL(2048, 2, 4, nv_bfloat16)
        case 2560: LAUNCH_BACKWARD_KERNEL(2560, 2, 4, nv_bfloat16)
        case 5120: LAUNCH_BACKWARD_KERNEL(5120, 2, 4, nv_bfloat16)
        }
    } else if (type == at::ScalarType::Half) {
        switch (n2) {
        case 64: LAUNCH_BACKWARD_KERNEL(64, 2, 4, half)
        case 128: LAUNCH_BACKWARD_KERNEL(128, 2, 4, half)
        case 192: LAUNCH_BACKWARD_KERNEL(192, 2, 4, half)
        case 256: LAUNCH_BACKWARD_KERNEL(256, 2, 4, half)
        case 320: LAUNCH_BACKWARD_KERNEL(320, 2, 4, half)
        case 384: LAUNCH_BACKWARD_KERNEL(384, 2, 4, half)
        case 512: LAUNCH_BACKWARD_KERNEL(512, 2, 4, half)
        case 640: LAUNCH_BACKWARD_KERNEL(640, 2, 4, half)
        case 768: LAUNCH_BACKWARD_KERNEL(768, 2, 4, half)
        case 1024: LAUNCH_BACKWARD_KERNEL(1024, 2, 4, half)
        case 1280: LAUNCH_BACKWARD_KERNEL(1280, 2, 4, half)
        case 1536: LAUNCH_BACKWARD_KERNEL(1536, 2, 4, half)
        case 1792: LAUNCH_BACKWARD_KERNEL(1792, 2, 4, half)
        case 2048: LAUNCH_BACKWARD_KERNEL(2048, 2, 4, half)
        case 2560: LAUNCH_BACKWARD_KERNEL(2560, 2, 4, half)
        case 5120: LAUNCH_BACKWARD_KERNEL(5120, 2, 4, half)
        }
    } else if (type == at::ScalarType::Float) {
        switch (n2) {
        case 64: LAUNCH_BACKWARD_KERNEL(64, 1, 4, float)
        case 128: LAUNCH_BACKWARD_KERNEL(128, 1, 4, float)
        case 192: LAUNCH_BACKWARD_KERNEL(192, 1, 4, float)
        case 256: LAUNCH_BACKWARD_KERNEL(256, 1, 4, float)
        case 320: LAUNCH_BACKWARD_KERNEL(320, 1, 4, float)
        case 384: LAUNCH_BACKWARD_KERNEL(384, 1, 4, float)
        case 512: LAUNCH_BACKWARD_KERNEL(512, 1, 4, float)
        case 640: LAUNCH_BACKWARD_KERNEL(640, 1, 4, float)
        case 768: LAUNCH_BACKWARD_KERNEL(768, 1, 4, float)
        case 1024: LAUNCH_BACKWARD_KERNEL(1024, 1, 4, float)
        case 1280: LAUNCH_BACKWARD_KERNEL(1280, 1, 4, float)
        case 1536: LAUNCH_BACKWARD_KERNEL(1536, 1, 4, float)
        case 1792: LAUNCH_BACKWARD_KERNEL(1792, 1, 4, float)
        case 2048: LAUNCH_BACKWARD_KERNEL(2048, 1, 4, float)
        case 2560: LAUNCH_BACKWARD_KERNEL(2560, 1, 4, float)
        case 5120: LAUNCH_BACKWARD_KERNEL(5120, 1, 4, float)
        }
    }
}
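
// A minimal binding sketch, kept in comments: it assumes this file is built
// as a PyTorch C++/CUDA extension, and the wrapper name and argument layout
// below are illustrative rather than part of this source.
//
//   #include <torch/extension.h>
//
//   void layernorm_forward_py(at::Tensor output, at::Tensor mean,
//                             at::Tensor invvar, at::Tensor input,
//                             at::Tensor gamma, at::Tensor beta, double eps) {
//       const int64_t n2 = input.size(-1);        // hidden size
//       const int64_t n1 = input.numel() / n2;    // number of rows
//       cuda_layer_norm(&output, &mean, &invvar, &input, (int)n1, (int)n2,
//                       {n2}, &gamma, &beta, eps);
//   }
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//       m.def("forward", &layernorm_forward_py, "fused LayerNorm forward (CUDA)");
//   }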