transform_kernels.cu 16.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
#include "custom_cuda_layers.h"

#define rows_trans 16
#define cols_trans 16

// Tiled transpose of a row-major matrix through shared memory.
//
//   inp: matrix with `row_width` elements per row and `col_width` rows.
//   out: its transpose, with `col_width` elements per row.
//
// Each block transposes one rows_trans x cols_trans tile. blockIdx.x / m is the
// tile row and blockIdx.x % m the tile column, with m = row_width / cols_trans.
// Thread (r, c) = (threadIdx.x / cols_trans, threadIdx.x % cols_trans) copies
// tile rows r, r + row_stride, ... into shared memory, then writes the
// transposed tile back out.
//
// There are no bounds checks: both matrix dimensions are assumed to be
// multiples of 16. The grid is expected to hold exactly one block per tile,
// which with a rows * cols / THREADS grid only holds when
// THREADS == rows_trans * cols_trans.
// NOTE(review): confirm THREADS (custom_cuda_layers.h) is 256.
template <typename T>
__global__ void Transpose_Kernel(const T* inp, T* out, int row_width, int col_width)
{
    // Tile rows are padded to cols_trans + 1 elements. Without the padding, the
    // column-wise reads of the second phase (stride of one tile row) would send
    // many lanes of a warp to the same shared-memory bank.
    __shared__ T data_block[rows_trans * (cols_trans + 1)];
    const int tile_pitch = cols_trans + 1;

    int r = threadIdx.x / cols_trans;
    int c = threadIdx.x % cols_trans;

    int m = row_width / cols_trans;  // tiles per input row

    // Phase 1: element (i + k, j) of the input lands at tile[k + r][c].
    int i = blockIdx.x / m * rows_trans + r;
    int j = blockIdx.x % m * cols_trans + c;

    int row_stride = rows_trans / ((rows_trans * cols_trans + THREADS - 1) / THREADS);

    for (int k = 0; k < rows_trans; k += row_stride)
        data_block[(k + r) * tile_pitch + c] = inp[(i + k) * row_width + j];

    // Phase 2 reads tile entries written by other threads.
    __syncthreads();

    // Phase 2: output element (i + k, j) is tile[c][r + k] (tile row/column swapped).
    i = blockIdx.x % m * rows_trans + r;
    j = blockIdx.x / m * cols_trans + c;

    for (int k = 0; k < rows_trans; k += row_stride)
        out[(i + k) * col_width + j] = data_block[c * tile_pitch + r + k];
}

// Shared host-side launch logic for the Transpose<T> specializations below.
// Transposes a row-major `rows` x `cols` matrix on `stream` with
// Transpose_Kernel<T>, using THREADS threads per block.
template <typename T>
static void transpose_launch_impl(const T* inp_mat,
                                  T* out_mat,
                                  int rows,
                                  int cols,
                                  cudaStream_t stream)
{
    // An empty matrix would produce a zero-block grid, which CUDA rejects with
    // cudaErrorInvalidConfiguration; there is nothing to transpose anyway.
    if (rows <= 0 || cols <= 0) return;

    const int threads = THREADS;
    const int blocks = (rows * cols + threads - 1) / threads;

    // The kernel takes (row_width, col_width) = (cols, rows) of the input.
    Transpose_Kernel<T><<<blocks, threads, 0, stream>>>(inp_mat, out_mat, cols, rows);
}

template <>
void Transpose<__half>(const __half* inp_mat,
                       __half* out_mat,
                       int rows,
                       int cols,
                       cudaStream_t stream)
{
    transpose_launch_impl<__half>(inp_mat, out_mat, rows, cols, stream);
}

template <>
void Transpose<float>(const float* inp_mat, float* out_mat, int rows, int cols, cudaStream_t stream)
{
    transpose_launch_impl<float>(inp_mat, out_mat, rows, cols, stream);
}

template <typename T>
__global__ void transform_0213(T* output, const T* vals, int hidden_dim, int seq_length, int heads);

// fp32 [B, S, H] -> [B, heads, S, H / heads] permutation (H = hidden_dim).
// Each thread copies one float4 (4 consecutive floats).
//
// Launch layout (launch_transform_0213<float>): grid (batch, seq_length),
// block (H / heads / 4, heads). H / heads must be a multiple of 4 and both
// pointers 16-byte aligned, since they are reinterpreted as float4.
template <>
__global__ void transform_0213<float>(float* output,
                                      const float* vals,
                                      int hidden_dim,
                                      int seq_length,
                                      int heads)
{
    // Sizes below are in float4 units.
    int head_vecs = hidden_dim / heads / 4;        // one head slice of a token
    int token_vecs = hidden_dim / 4;               // one token
    int batch_vecs = hidden_dim * seq_length / 4;  // one batch entry

    int batch = blockIdx.x;
    int token = blockIdx.y;
    int head = threadIdx.y;
    int lane = threadIdx.x;

    const float4* src = reinterpret_cast<const float4*>(vals);
    float4* dst = reinterpret_cast<float4*>(output);

    int src_idx = batch * batch_vecs + token * token_vecs + head * head_vecs + lane;
    int dst_idx = batch * batch_vecs + token * head_vecs + head * (head_vecs * seq_length) + lane;

    float4 chunk = src[src_idx];
    dst[dst_idx] = chunk;
}

// fp16 [B, S, H] -> [B, heads, S, H / heads] permutation (H = hidden_dim).
// Each thread copies one float4, i.e. 8 consecutive halves; no fp16
// arithmetic is involved.
//
// Launch layout (launch_transform_0213<__half>): grid (batch, seq_length),
// block (H / heads / 8, heads). H / heads must be a multiple of 8 and both
// pointers 16-byte aligned, since they are reinterpreted as float4.
//
// The body is pure data movement, so it is intentionally not wrapped in an
// `#if __CUDA_ARCH__ >= 700` guard: such a guard would make the kernel an
// empty no-op on pre-Volta GPUs and leave `output` unwritten.
template <>
__global__ void transform_0213<__half>(__half* output,
                                       const __half* vals,
                                       int hidden_dim,
                                       int seq_length,
                                       int heads)
{
    int d0_stride = hidden_dim * seq_length / 8;
    int d1_stride = hidden_dim / 8;
    int d2_stride = hidden_dim / heads / 8;

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;   // Batch
    int d1 = blockIdx.y;   // Sequence position
    int d2 = threadIdx.y;  // Head
    int d3 = threadIdx.x;  // Values (groups of 8 halves, one float4)

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    float4* output_vec = reinterpret_cast<float4*>(output);

    output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] =
        vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
}

template <>
void launch_transform_0213<float>(float* output,
                                  const float* vals,
                                  int batch_size,
                                  int seq_length,
                                  int hidden_dim,
                                  int heads,
                                  cudaStream_t stream)
{
    // One block per (batch, token). Inside a block, x walks the float4 lanes of
    // a head slice and y selects the head, so a block moves one token's H floats.
    const int lanes_per_head = hidden_dim / heads / 4;
    dim3 grid(batch_size, seq_length);
    dim3 block(lanes_per_head, heads);
    transform_0213<float><<<grid, block, 0, stream>>>(output, vals, hidden_dim, seq_length, heads);
}

template <>
void launch_transform_0213<__half>(__half* output,
                                   const __half* vals,
                                   int batch_size,
                                   int seq_length,
                                   int hidden_dim,
                                   int heads,
                                   cudaStream_t stream)
{
    // One block per (batch, token). x walks the float4 lanes (8 halves each) of
    // a head slice and y selects the head.
    const int lanes_per_head = hidden_dim / heads / 8;
    dim3 grid(batch_size, seq_length);
    dim3 block(lanes_per_head, heads);
    transform_0213<__half><<<grid, block, 0, stream>>>(output, vals, hidden_dim, seq_length, heads);
}

// Bias add
template <typename T>
__global__ void bias_add_transform_0213(T* output,
                                        const T* vals,
                                        const T* bias,
                                        int hidden_dim,
                                        int seq_length,
                                        int heads);

// fp32 bias add fused with the [B, S, C*H] -> C x [B, heads, S, H / heads]
// permutation (C = gridDim.z, H = hidden_dim). Each thread loads one float4 of
// input plus the matching float4 of the [C*H] bias and stores their sum.
//
// Launch layout (launch_bias_add_transform_0213<float>): grid
// (batch, seq_length, C), block (H / heads / 4, heads).
template <>
__global__ void bias_add_transform_0213<float>(float* output,
                                               const float* vals,
                                               const float* bias,
                                               int hidden_dim,
                                               int seq_length,
                                               int heads)
{
    // Sizes below are in float4 units.
    int head_vecs = hidden_dim / heads / 4;
    int token_vecs = hidden_dim / 4;
    int batch_vecs = hidden_dim * seq_length / 4;

    int batch = blockIdx.x;
    int token = blockIdx.y;
    int part = blockIdx.z;
    int head = threadIdx.y;
    int lane = threadIdx.x;

    unsigned int parts = gridDim.z;
    unsigned int batches = gridDim.x;

    const float4* src = reinterpret_cast<const float4*>(vals);
    const float4* bias_src = reinterpret_cast<const float4*>(bias);
    float4* dst = reinterpret_cast<float4*>(output);

    // Input row (batch, token) holds C consecutive H-wide parts.
    const float4 x = src[batch * batch_vecs * parts + part * token_vecs +
                         token * token_vecs * parts + head * head_vecs + lane];
    const float4 b = bias_src[part * token_vecs + head * head_vecs + lane];

    const float4 sum = make_float4(x.x + b.x, x.y + b.y, x.z + b.z, x.w + b.w);

    // Each part is written to its own [B, heads, S, H / heads] tensor.
    dst[part * batch_vecs * batches + batch * batch_vecs + token * head_vecs +
        head * (head_vecs * seq_length) + lane] = sum;
}

// NOTE(review): ATTN_H and MAX_SEQ_LINE are not referenced anywhere else in
// this file; confirm whether they can be removed.
#define ATTN_H 3
#define MAX_SEQ_LINE 10

// fp16 bias add fused with the [B, S, C*H] -> C x [B, heads, S, H / heads]
// permutation (C = blockDim.z, H = hidden_dim). Data moves as float4, i.e.
// 8 halves per access, and the bias is added with __half2 arithmetic.
//
// Expected launch layout (see launch_bias_add_transform_0213<__half>):
//   grid  = (batch, seq_length / 2)
//   block = (H / heads / 8, heads, C)
// Each block handles two consecutive sequence positions, 2 * blockIdx.y and
// 2 * blockIdx.y + 1.
//   Phase 1: reads both tokens' C*H input values, adds the bias and stages the
//            sums in shared memory.
//   Phase 2: re-partitions the threads so they write the staged values into
//            the C output tensors.
//
// Preconditions implied by the indexing (not checked here):
//   - blockDim.y (heads) is even. Phase 2 pairs heads up with >> 1 and uses
//     d2 % 2 to pick one of the two tokens.
//   - The two staged tokens (2 * H * C / 8 float4 values) fit in the
//     3072-entry in_data buffer.
//   - H / heads is a multiple of 8 and all pointers are 16-byte aligned, since
//     they are reinterpreted as float4.
// The body is compiled only when __CUDA_ARCH__ >= 700; for older targets the
// kernel is empty and writes nothing.
// NOTE(review): the __half2 '+' is the only arch-dependent operation here;
// confirm whether the 700 threshold is actually required.
template <>
__global__ void bias_add_transform_0213<__half>(__half* output,
                                                const __half* vals,
                                                const __half* bias,
                                                int hidden_dim,
                                                int seq_length,
                                                int heads)
{
#if __CUDA_ARCH__ >= 700
    // Staging buffer for this block's two tokens, laid out as
    // [token (2)][C][H / 8] float4 values (3072 * 16 bytes = 48 KB).
    __shared__ float4 in_data[3072];

    int d0_stride = hidden_dim * seq_length / 8;
    int d1_stride = hidden_dim / 8;
    int d2_stride = hidden_dim / heads / 8;
    int iteration_stride = d1_stride * blockDim.z;  // Hidden * C / 8: one input token
    int batch_stride = d0_stride * blockDim.z;      // Hidden * S * C / 8: one input batch entry

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = d2_stride * seq_length;

    int d0 = blockIdx.x;    // Batch
    int d1 = blockIdx.y;    // Sequence pair: tokens 2 * d1 and 2 * d1 + 1
    int cnt = threadIdx.z;  // Which of the C parts (was blockIdx.z in an earlier layout)
    int d2 = threadIdx.y;   // Head (phase 1)
    int d3 = threadIdx.x;   // Values (groups of 8 halves, one float4)

    // Single-element float4 buffers viewed as four __half2 each for the add.
    float4 vals_arr[1];
    float4 bias_arr[1];
    float4 output_arr[1];
    __half2* vals_half = reinterpret_cast<__half2*>(vals_arr);
    __half2* bias_half = reinterpret_cast<__half2*>(bias_arr);
    __half2* output_half = reinterpret_cast<__half2*>(output_arr);

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    const float4* bias_vec = reinterpret_cast<const float4*>(bias);
    float4* output_vec = reinterpret_cast<float4*>(output);

    // Phase 1: this thread's (part, head, lane) slot within a token. The same
    // bias value applies to both tokens, so it is loaded once.
    int iter_index = cnt * d1_stride + d2 * d2_stride + d3;
    int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1);
    bias_arr[0] = bias_vec[iter_index];
    for (int iter = 0; iter < 2; iter++) {
        // iter selects the token (2 * d1 + iter); iter_id is also its in_data slot.
        int iter_id = iter * iteration_stride + iter_index;
        vals_arr[0] = vals_vec[input_offset + iter_id];

        output_half[0] = vals_half[0] + bias_half[0];
        output_half[1] = vals_half[1] + bias_half[1];
        output_half[2] = vals_half[2] + bias_half[2];
        output_half[3] = vals_half[3] + bias_half[3];

        in_data[iter_id] = output_arr[0];
    }
    // Phase 2 reads slots written by other threads of the block.
    __syncthreads();

    // Phase 2: thread (d3, d2, cnt) writes token 2 * d1 + (d2 % 2). Together,
    // (d2 >> 1) and cnt enumerate C * heads / 2 flattened (part, head) rows per
    // iteration, so the two iterations cover all C * heads rows.
    iteration_stride = blockDim.z * (blockDim.y >> 1);
    int matrix_stride = (d0_out_stride * gridDim.x);  // size of one of the C output tensors
    int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1);

    // Output base for this thread's token and lane within batch entry d0.
    int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride;
    for (int iter = 0; iter < 2; iter++) {
        // Flattened row = part * heads + head.
        int iter_row = (iter * iteration_stride) + head_count;
        // Head offset within an output tensor plus the offset of the part's tensor.
        int iter_offset =
            (iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride;
        output_vec[out_index + iter_offset] =
            in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)];
    }
#endif
}

// [B, S, C*H] -> C x [B, A, S, N]   (A = heads, N = H / heads, C = trans_count)
template <>
void launch_bias_add_transform_0213<float>(float* output,
                                           const float* vals,
                                           const float* bias,
                                           int batch_size,
                                           int seq_length,
                                           int hidden_dim,
                                           int heads,
                                           cudaStream_t stream,
                                           int trans_count)
{
    // One block per (batch, token, part). x walks the float4 lanes of a head
    // slice and y selects the head.
    const int lanes_per_head = hidden_dim / heads / 4;
    dim3 grid(batch_size, seq_length, trans_count);
    dim3 block(lanes_per_head, heads);
    bias_add_transform_0213<float>
        <<<grid, block, 0, stream>>>(output, vals, bias, hidden_dim, seq_length, heads);
}

// Generic fp16 bias add fused with the [B, S, C*H] -> C x [B, heads, S, H / heads]
// permutation (C = gridDim.z, H = hidden_dim). One float4 (8 halves) per
// thread, with no shared memory and no pairing of heads or sequence positions.
// It therefore works for any seq_length and head count, provided H / heads is
// a multiple of 8 (float4 access) and H / 8 fits in one block.
//
// Launch layout: grid (batch, seq_length, C), block (H / heads / 8, heads).
// The add is done in fp32 and rounded back to fp16, which is available on every
// architecture. fp32 carries at least 2 * 11 + 2 significand bits, so rounding
// the fp32 sum gives the same result as a correctly rounded fp16 addition.
static __global__ void bias_add_transform_0213_half_generic(__half* output,
                                                            const __half* vals,
                                                            const __half* bias,
                                                            int hidden_dim,
                                                            int seq_length,
                                                            int heads)
{
    // Sizes below are in float4 units.
    const int head_vecs = hidden_dim / heads / 8;        // one head slice
    const int token_vecs = hidden_dim / 8;               // one token of one part
    const int batch_vecs = hidden_dim * seq_length / 8;  // one batch entry of one part

    const int batch = blockIdx.x;
    const int token = blockIdx.y;
    const int part = blockIdx.z;
    const int head = threadIdx.y;
    const int lane = threadIdx.x;
    const int parts = gridDim.z;
    const int batches = gridDim.x;

    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
    const float4* bias_vec = reinterpret_cast<const float4*>(bias);
    float4* output_vec = reinterpret_cast<float4*>(output);

    // Input is [B, S, C * H]; the bias is [C * H].
    const int slice = part * token_vecs + head * head_vecs + lane;
    float4 in = vals_vec[batch * batch_vecs * parts + token * token_vecs * parts + slice];
    float4 b = bias_vec[slice];

    float4 sum;
    const __half2* in_h2 = reinterpret_cast<const __half2*>(&in);
    const __half2* b_h2 = reinterpret_cast<const __half2*>(&b);
    __half2* sum_h2 = reinterpret_cast<__half2*>(&sum);
#pragma unroll
    for (int k = 0; k < 4; ++k) {
        const float2 lhs = __half22float2(in_h2[k]);
        const float2 rhs = __half22float2(b_h2[k]);
        sum_h2[k] = __floats2half2_rn(lhs.x + rhs.x, lhs.y + rhs.y);
    }

    // Output is C x [B, heads, S, H / heads].
    output_vec[(part * batches + batch) * batch_vecs + head * (head_vecs * seq_length) +
               token * head_vecs + lane] = sum;
}

// [B, S, C*H] -> C x [B, A, S, N] for fp16 (A = heads, N = H / heads, C = trans_count).
template <>
void launch_bias_add_transform_0213<__half>(__half* output,
                                            const __half* vals,
                                            const __half* bias,
                                            int batch_size,
                                            int seq_length,
                                            int hidden_dim,
                                            int heads,
                                            cudaStream_t stream,
                                            int trans_count)
{
    // Nothing to do. A zero-sized grid would be rejected by CUDA with
    // cudaErrorInvalidConfiguration.
    if (batch_size <= 0 || seq_length <= 0 || trans_count <= 0) return;

    const int head_vecs = hidden_dim / heads / 8;
    const int paired_block_threads = head_vecs * heads * trans_count;

    // The shared-memory kernel processes two sequence positions per block
    // (grid.y = seq_length / 2), pairs heads up inside the block, and puts all
    // trans_count parts in blockDim.z. It therefore needs:
    //   - an even seq_length;
    //   - an even head count;
    //   - a block within CUDA's limits (1024 threads, blockDim.z <= 64).
    // Its 3072-entry float4 staging buffer holds 2 * paired_block_threads
    // values, which always fits when the thread limit holds.
    const bool paired_ok = (seq_length % 2 == 0) && (heads % 2 == 0) && (trans_count <= 64) &&
                           (paired_block_threads <= 1024);

    if (paired_ok) {
        dim3 block_dim(head_vecs, heads, trans_count);
        dim3 grid_dim(batch_size, seq_length / 2);
        bias_add_transform_0213<__half>
            <<<grid_dim, block_dim, 0, stream>>>(output, vals, bias, hidden_dim, seq_length, heads);
    } else {
        // For example, with an odd seq_length a seq_length / 2 grid would leave
        // the last position unwritten; the generic kernel covers every position.
        dim3 block_dim(head_vecs, heads);
        dim3 grid_dim(batch_size, seq_length, trans_count);
        bias_add_transform_0213_half_generic<<<grid_dim, block_dim, 0, stream>>>(
            output, vals, bias, hidden_dim, seq_length, heads);
    }
}

template <typename T>
__global__ void transform4d_0213(T* out, const T* in, int heads, int seq_length, int hidden_dim);

// fp32 C x [B, heads, S, H / heads] -> [B, S, C * H] permutation
// (C = gridDim.z, H = hidden_dim), one float4 per thread.
//
// Launch layout (launch_transform4d_0213<float>): grid
// (batch, heads * ceil(seq_length / blockDim.y), C), block (H / heads / 4, rows).
// blockIdx.y enumerates (head, sequence tile) pairs and threadIdx.y picks the
// position inside the tile; positions past seq_length in a head's last tile
// are skipped. H / heads must be a multiple of 4 and pointers 16-byte aligned.
template <>
__global__ void transform4d_0213<float>(float* out,
                                        const float* in,
                                        int heads,
                                        int seq_length,
                                        int hidden_dim)
{
    int d0_stride = hidden_dim * seq_length / 4;
    int d1_stride = d0_stride / heads;
    int d2_stride = hidden_dim / heads / 4;

    int d0_out_stride = d0_stride;
    int d1_out_stride = d2_stride;
    int d2_out_stride = hidden_dim / 4;

    // Number of blockDim.y-sized sequence tiles per head.
    int seq_tiles = (seq_length + blockDim.y - 1) / blockDim.y;

    int d0 = blockIdx.x;                                           // Batch
    int d1 = blockIdx.y / seq_tiles;                               // Head
    int d2 = (blockIdx.y % seq_tiles) * blockDim.y + threadIdx.y;  // Sequence position
    int cnt = blockIdx.z;                                          // Which of the C parts
    int d3 = threadIdx.x;                                          // Values (groups of 4 floats)

    // Only a head's last tile can run past seq_length. The position is taken
    // within the head's own tiles (not modulo seq_length), so this check is
    // meaningful and no position is written twice.
    if (d2 < seq_length) {
        const float4* in_vec = reinterpret_cast<const float4*>(in);
        float4* out_vec = reinterpret_cast<float4*>(out);

        float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride +
                                 d2 * d2_stride + d3];
        out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride +
                d2 * d2_out_stride * gridDim.z + d3] = vals_vec;
    }
}

// fp16 C x [B, heads, S, H / heads] -> [B, S, C * H] permutation
// (C = blockDim.z, H = hidden_dim), moving float4 values (8 halves each).
//
// Expected launch layout (launch_transform4d_0213<__half>):
//   grid  = (batch, seq_length / 2)
//   block = (H / heads / 8, heads, C)
// Each block handles sequence positions 2 * blockIdx.y and 2 * blockIdx.y + 1.
//   Phase 1: gathers both tokens' values for all C parts and heads into shared
//            memory in [token][C][H / 8] order.
//   Phase 2: writes the staged values out contiguously.
//
// Preconditions implied by the indexing (not checked):
//   - blockDim.y (heads) is even. Phase 1 pairs heads up with >> 1 and uses
//     d1 % 2 to pick one of the two tokens.
//   - 2 * H * C / 8 float4 values fit in the 3072-entry in_data buffer.
//   - H / heads is a multiple of 8 and pointers are 16-byte aligned.
//
// The body only moves data (no fp16 arithmetic), so it is compiled for every
// architecture. Restricting it to __CUDA_ARCH__ >= 700 would turn it into an
// empty kernel on older GPUs and leave `out` unwritten.
template <>
__global__ void transform4d_0213<__half>(__half* out,
                                         const __half* in,
                                         int heads,
                                         int seq_length,
                                         int hidden_dim)
{
    // [token (2)][C][H / 8] staging buffer, 3072 * 16 bytes = 48 KB.
    __shared__ float4 in_data[3072];

    int d0_stride = hidden_dim * seq_length / 8;
    int d1_stride = hidden_dim / 8;
    int d2_stride = hidden_dim / heads / 8;

    int d0 = blockIdx.x;    // Batch
    int d1 = threadIdx.y;   // Head in phase 2; d1 % 2 selects the token in phase 1
    int d2 = blockIdx.y;    // Sequence pair: tokens 2 * d2 and 2 * d2 + 1
    int cnt = threadIdx.z;  // Which of the C parts (phase 2)
    int d3 = threadIdx.x;   // Values (groups of 8 halves, one float4)

    const float4* in_vec = reinterpret_cast<const float4*>(in);
    float4* out_vec = reinterpret_cast<float4*>(out);

    // Phase 1: (d1 >> 1, cnt) enumerate C * heads / 2 flattened (part, head)
    // rows per iteration, so two iterations cover all C * heads rows for
    // token 2 * d2 + d1 % 2.
    int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + d1 % 2 * d2_stride;
    int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1);
    int iteration_stride = blockDim.z * (blockDim.y >> 1);
    int matrix_stride = (d0_stride * gridDim.x);  // size of one of the C input tensors

    for (int iter = 0; iter < 2; iter++) {
        int iter_row = iter * iteration_stride + head_count;  // part * heads + head
        int iter_offset = (iter_row % blockDim.y) * d2_stride;

        in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] =
            in_vec[input_offset + iter_offset * seq_length +
                   (iter_row / blockDim.y) * matrix_stride];
    }
    // Phase 2 reads slots written by other threads of the block.
    __syncthreads();

    // Phase 2: each token occupies C * H / 8 contiguous float4 values in both
    // in_data and the [B, S, C * H] output.
    iteration_stride = d1_stride * blockDim.z;
    int iter_index = cnt * d1_stride + d1 * d2_stride + d3;
    int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1);

    for (int iter = 0; iter < 2; iter++) {
        int iter_id = iter * iteration_stride + iter_index;
        out_vec[output_offset + iter_id] = in_data[iter_id];
    }
}

// C x [B, A, S, N] -> [B, S, C*H]   (A = heads, N = H / heads, C = trans_count)
template <>
void launch_transform4d_0213<float>(float* out,
                                    const float* in,
                                    int batch_size,
                                    int heads,
                                    int seq_length,
                                    int hidden_dim,
                                    cudaStream_t stream,
                                    int trans_count)
{
    // Sequence positions are handled in tiles of 8 rows; blockIdx.y enumerates
    // (head, tile) pairs and x walks the float4 lanes of a head slice.
    const int tile_rows = 8;
    const int tiles_per_head = (seq_length + tile_rows - 1) / tile_rows;
    dim3 grid(batch_size, heads * tiles_per_head, trans_count);
    dim3 block(hidden_dim / heads / 4, tile_rows);
    transform4d_0213<float><<<grid, block, 0, stream>>>(out, in, heads, seq_length, hidden_dim);
}

// Generic fp16 C x [B, heads, S, H / heads] -> [B, S, C * H] permutation
// (C = gridDim.z, H = hidden_dim). One float4 (8 halves) per thread, with no
// shared memory and no pairing of heads or sequence positions.
//
// Launch layout: grid (batch, heads * ceil(seq_length / blockDim.y), C),
// block (H / heads / 8, rows). blockIdx.y enumerates (head, sequence tile)
// pairs; positions past seq_length in a head's last tile are skipped.
static __global__ void transform4d_0213_half_generic(__half* out,
                                                     const __half* in,
                                                     int heads,
                                                     int seq_length,
                                                     int hidden_dim)
{
    // Sizes below are in float4 units.
    const int head_vecs = hidden_dim / heads / 8;        // one head slice of a token
    const int token_vecs = hidden_dim / 8;               // one token of one part
    const int batch_vecs = hidden_dim * seq_length / 8;  // one batch entry of one part
    const int seq_tiles = (seq_length + blockDim.y - 1) / blockDim.y;

    const int batch = blockIdx.x;
    const int head = blockIdx.y / seq_tiles;
    const int pos = (blockIdx.y % seq_tiles) * blockDim.y + threadIdx.y;
    const int part = blockIdx.z;
    const int lane = threadIdx.x;

    if (pos >= seq_length) return;

    const int parts = gridDim.z;
    const int batches = gridDim.x;

    const float4* in_vec = reinterpret_cast<const float4*>(in);
    float4* out_vec = reinterpret_cast<float4*>(out);

    // in: C x [B, heads, S, H / heads]; out: [B, S, C * H].
    out_vec[batch * batch_vecs * parts + pos * token_vecs * parts + part * token_vecs +
            head * head_vecs + lane] =
        in_vec[(part * batches + batch) * batch_vecs + head * (seq_length * head_vecs) +
               pos * head_vecs + lane];
}

// C x [B, A, S, N] -> [B, S, C*H] for fp16 (A = heads, N = H / heads, C = trans_count).
template <>
void launch_transform4d_0213<__half>(__half* out,
                                     const __half* in,
                                     int batch_size,
                                     int heads,
                                     int seq_length,
                                     int hidden_dim,
                                     cudaStream_t stream,
                                     int trans_count)
{
    // Nothing to do. A zero-sized grid would be rejected by CUDA with
    // cudaErrorInvalidConfiguration.
    if (batch_size <= 0 || seq_length <= 0 || trans_count <= 0) return;

    const int head_vecs = hidden_dim / heads / 8;
    const int paired_block_threads = head_vecs * heads * trans_count;

    // The shared-memory kernel processes two sequence positions per block
    // (grid.y = seq_length / 2), pairs heads up inside the block, and puts all
    // trans_count parts in blockDim.z. It therefore needs:
    //   - an even seq_length;
    //   - an even head count;
    //   - a block within CUDA's limits (1024 threads, blockDim.z <= 64).
    // Its 3072-entry float4 staging buffer holds 2 * paired_block_threads
    // values, which always fits when the thread limit holds.
    const bool paired_ok = (seq_length % 2 == 0) && (heads % 2 == 0) && (trans_count <= 64) &&
                           (paired_block_threads <= 1024);

    if (paired_ok) {
        dim3 grid_dims(batch_size, seq_length / 2);
        dim3 block_dims(head_vecs, heads, trans_count);
        transform4d_0213<__half>
            <<<grid_dims, block_dims, 0, stream>>>(out, in, heads, seq_length, hidden_dim);
    } else {
        // For example, with an odd seq_length a seq_length / 2 grid would leave
        // the last position unwritten; the generic kernel covers every position.
        const int tile_rows = 8;
        dim3 grid_dims(
            batch_size, heads * ((seq_length + tile_rows - 1) / tile_rows), trans_count);
        dim3 block_dims(head_vecs, tile_rows);
        transform4d_0213_half_generic<<<grid_dims, block_dims, 0, stream>>>(
            out, in, heads, seq_length, hidden_dim);
    }
}