softmax.cu 21.9 KB
Newer Older
aiss's avatar
aiss committed
1
2
3
4
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

aiss's avatar
aiss committed
5
#include <limits>
aiss's avatar
aiss committed
6
#include "inference_cuda_layers.h"
aiss's avatar
aiss committed
7
8
9
10
11
12
13
14

#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include <cstdio>
#include <cstdlib>
#include <ctime>

aiss's avatar
aiss committed
15
// Threads per block used by launch_attn_softmax_v2 (block_dim). The kernels derive warp
// counts with `blockDim.x >> 5`, so this should stay a multiple of WARP_SIZE.
#define ATTN_THREADS 256
aiss's avatar
aiss committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
// Capacity of the per-thread register arrays in attn_softmax_v2: each thread holds at most
// MAX_REG_SIZE groups of 4 scores, so a launch must use iterations <= MAX_REG_SIZE.
#define MAX_REG_SIZE 8

// Finite sentinel used for masked-out scores. It is a float literal (and parenthesized) so
// that `cond ? float_expr : minus_infinity` stays in single precision instead of being
// promoted to double on the GPU.
#define minus_infinity (-10000.0f)

// Reads (and thereby clears) the CUDA runtime's last error. On failure, reports
// "<message>(<code>) at <file>:<line>" on stderr and throws std::runtime_error.
// Returns normally when no error is pending.
void CheckCudaErrorAux(const char* file, unsigned line)
{
    const cudaError_t status = cudaGetLastError();
    if (status != cudaSuccess) {
        std::cerr << cudaGetErrorString(status) << "(" << status << ") at " << file << ":"
                  << line << std::endl;
        throw std::runtime_error("CUDA ERROR!!!\n");
    }
}

// Checks the CUDA runtime's last-error state via CheckCudaErrorAux, tagging the report with
// the file and line where the macro is expanded; throws std::runtime_error on failure.
#define CUDA_CHECK_ERROR() CheckCudaErrorAux(__FILE__, __LINE__)

// Short alias for the cooperative groups API used by the kernels below.
namespace cg = cooperative_groups;

// Numerically stable, in-place softmax over rows of fp16 attention scores (math in fp32).
//
// Rows: row r (0 <= r < total_count) holds `sequence_length` scores at
// vals + r * sequence_length; its query position is r % num_seq.
//
// Threads: 1-D blocks; each row is processed by `reduceWidth` consecutive threads, where
// reduceWidth is a multiple of WARP_SIZE chosen by launch_attn_softmax_v2 so that it divides
// blockDim.x. A block therefore covers blockDim.x / reduceWidth rows. Each thread keeps
// `iterations` (<= MAX_REG_SIZE) groups of 4 consecutive columns in registers.
//
// Per column: score = vals * layer_scale, then + alibi (if non-null), then + mask (if non-null).
// Columns excluded by the causal (`triangular`) or local-attention window hold the finite
// sentinel minus_infinity (possibly offset by alibi/mask). Result written back to vals:
// exp(score - row_max) / (row_sum + 1e-6). `recompute` is accepted but unused.
//
// Every thread of the block reaches the b.sync() barriers, including threads whose row index
// is >= total_count; such threads only skip the global loads and stores.
__global__ void attn_softmax_v2(__half* vals,
                                __half* mask,
                                __half* alibi,
                                float layer_scale,
                                bool triangular,
                                bool recompute,
                                bool local_attention,
                                int window_size,
                                int total_count,
                                int heads,
                                int sequence_length,
                                int num_seq,
                                int head_offset,
                                int mask_stride,
                                int mp_size,
                                int iterations,
                                int reduceWidth)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    float2 low_data[MAX_REG_SIZE];
    float2 high_data[MAX_REG_SIZE];

    int wid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;

    int reduce_blocks = reduceWidth >> 5;       // warps cooperating on one row
    int seq_lane = threadIdx.x % reduceWidth;  // this thread's index within its row

    __shared__ float partialSum[MAX_WARP_NUM];

    // Row handled by this thread.
    int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
    // reduceWidth is a multiple of WARP_SIZE, so a warp never spans two rows: this flag is
    // uniform within each warp and the warp shuffles below always run on whole warps.
    const bool row_active = iter_offset < total_count;

    // 64-bit offsets: row * sequence_length overflows int for large batch * heads * seq^2.
    int batch_idx = iter_offset / (num_seq * heads);
    const long long alibi_offset =
        ((long long)batch_idx * heads * mp_size + head_offset + ((iter_offset / num_seq) % heads)) *
        sequence_length;
    const long long mask_offset =
        ((long long)batch_idx * mask_stride + (iter_offset % mask_stride)) * sequence_length;
    if (row_active) vals += (long long)iter_offset * sequence_length;

    int seq_id = iter_offset % num_seq;
    int seq_id4 = seq_id >> 2;

    int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
    int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
                             ? (real_seq_id >> 2) - (window_size >> 2)
                             : 0;
    int window_stride =
        (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;

    float max_val = minus_infinity;

    // Pass 1: load scaled/biased/masked scores into registers, tracking the thread-local max.
    for (int i = 0; i < iterations; i++) {
        int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
        if (row_active && (!triangular || ((data_id >> 2) <= seq_id4)) &&
            (data_id >> 2) >= window_stride4 && data_id < sequence_length) {
            if ((sequence_length - data_id) >= 4) {
                low_data[i].x = data_id > window_stride
                                    ? __half2float(vals[data_id]) * layer_scale
                                    : minus_infinity;
                low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
                                 (data_id + 1) > window_stride)
                                    ? __half2float(vals[data_id + 1]) * layer_scale
                                    : minus_infinity;
                high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) &&
                                  (data_id + 2) > window_stride)
                                     ? __half2float(vals[data_id + 2]) * layer_scale
                                     : minus_infinity;
                high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) &&
                                  (data_id + 3) > window_stride)
                                     ? __half2float(vals[data_id + 3]) * layer_scale
                                     : minus_infinity;
                if (alibi) {
                    low_data[i].x += __half2float(alibi[data_id + alibi_offset]);
                    low_data[i].y += __half2float(alibi[data_id + alibi_offset + 1]);
                    high_data[i].x += __half2float(alibi[data_id + alibi_offset + 2]);
                    high_data[i].y += __half2float(alibi[data_id + alibi_offset + 3]);
                }
                if (mask) {
                    low_data[i].x += __half2float(mask[data_id + mask_offset]);
                    low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
                    high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
                    high_data[i].y += __half2float(mask[data_id + mask_offset + 3]);
                }
            } else {
                // Row tail: fewer than 4 columns remain; never read past sequence_length.
                low_data[i].x = data_id > window_stride
                                    ? __half2float(vals[data_id]) * layer_scale
                                    : minus_infinity;
                low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) &&
                                  (data_id + 1) > window_stride) &&
                                 (data_id + 1) < sequence_length)
                                    ? __half2float(vals[data_id + 1]) * layer_scale
                                    : minus_infinity;
                high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) &&
                                   (data_id + 2) > window_stride) &&
                                  (data_id + 2) < sequence_length)
                                     ? __half2float(vals[data_id + 2]) * layer_scale
                                     : minus_infinity;
                if (alibi) {
                    low_data[i].x += __half2float(alibi[data_id + alibi_offset]);
                    if ((data_id + 1) < sequence_length)
                        low_data[i].y += __half2float(alibi[data_id + alibi_offset + 1]);
                    if ((data_id + 2) < sequence_length)
                        high_data[i].x += __half2float(alibi[data_id + alibi_offset + 2]);
                }
                high_data[i].y = minus_infinity;
                if (mask) {
                    low_data[i].x += __half2float(mask[data_id + mask_offset]);
                    if ((data_id + 1) < sequence_length)
                        low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
                    if ((data_id + 2) < sequence_length)
                        high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
                }
            }
            max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
            max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
            max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
            max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
        } else {
            low_data[i].x = minus_infinity;
            low_data[i].y = minus_infinity;
            high_data[i].x = minus_infinity;
            high_data[i].y = minus_infinity;
        }
    }

    // Row max: butterfly within the warp, then across the row's warps through shared memory.
    for (int i = 1; i < WARP_SIZE; i *= 2) {
        auto temp = g.shfl_xor(max_val, i);
        max_val = (temp > max_val ? temp : max_val);
    }

    if (reduceWidth > WARP_SIZE) {
        // Reached by every thread of the block (block-uniform condition).
        if (lane == 0) partialSum[wid] = max_val;
        b.sync();

        if (lane < warp_num) max_val = partialSum[lane];

        b.sync();

        for (int i = 1; i < reduce_blocks; i *= 2) {
            auto temp = g.shfl_xor(max_val, i);
            max_val = (temp > max_val ? temp : max_val);
        }

        max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
    }

    // Pass 2: exponentiate in registers and accumulate the row sum.
    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        low_data[i].x = __expf(low_data[i].x - max_val);
        low_data[i].y = __expf(low_data[i].y - max_val);
        high_data[i].x = __expf(high_data[i].x - max_val);
        high_data[i].y = __expf(high_data[i].y - max_val);

        sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);

    if (reduceWidth > WARP_SIZE) {
        if (lane == 0) partialSum[wid] = sum;
        b.sync();

        if (lane < warp_num) sum = partialSum[lane];

        b.sync();

        for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }

        sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
    }
    sum += 1e-6f;  // float literal: avoids a double-precision add

    // Pass 3: normalise and store (active rows only).
    if (row_active) {
        for (int i = 0; i < iterations; i++) {
            int data_id = i * (reduceWidth << 2) + (seq_lane << 2);

            if (data_id < sequence_length) {
                if ((sequence_length - data_id) >= 4) {
                    vals[data_id] = __float2half(low_data[i].x / sum);
                    vals[data_id + 1] = __float2half(low_data[i].y / sum);
                    vals[data_id + 2] = __float2half(high_data[i].x / sum);
                    vals[data_id + 3] = __float2half(high_data[i].y / sum);
                } else {
                    vals[data_id] = __float2half(low_data[i].x / sum);
                    if ((data_id + 1) < sequence_length)
                        vals[data_id + 1] = __float2half(low_data[i].y / sum);
                    if ((data_id + 2) < sequence_length)
                        vals[data_id + 2] = __float2half(high_data[i].x / sum);
                }
            }
        }
    }
}

// Numerically stable, in-place softmax over rows of fp32 attention scores.
//
// Same contract as the __half overload. Rows of `sequence_length` scores start at
// vals + r * sequence_length for 0 <= r < total_count. `reduceWidth` threads (a multiple of
// WARP_SIZE dividing blockDim.x) cooperate on a row. Each thread keeps `iterations`
// (<= MAX_REG_SIZE) float4 groups in registers.
//
// Per column: score = vals * layer_scale, then + alibi (if non-null), then + attn_mask (if
// non-null). Window-/causally-excluded columns hold the finite sentinel minus_infinity
// (possibly offset by alibi/mask). Result written back to vals:
// exp(score - row_max) / (row_sum + 1e-6). `recompute` is accepted but unused.
//
// Earlier versions of this overload ignored `layer_scale` and `alibi`. They are now applied
// exactly as in the __half overload; layer_scale == 1 and alibi == nullptr reproduce the
// previous results.
//
// Every thread of the block reaches the b.sync() barriers, including threads whose row index
// is >= total_count; such threads only skip the global loads and stores.
__global__ void attn_softmax_v2(float* vals,
                                float* attn_mask,
                                float* alibi,
                                float layer_scale,
                                bool triangular,
                                bool recompute,
                                bool local_attention,
                                int window_size,
                                int total_count,
                                int heads,
                                int sequence_length,
                                int num_seq,
                                int head_offset,
                                int mask_stride,
                                int mp_size,
                                int iterations,
                                int reduceWidth)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    float4 data[MAX_REG_SIZE];

    int wid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;

    int reduce_blocks = reduceWidth >> 5;       // warps cooperating on one row
    int seq_lane = threadIdx.x % reduceWidth;  // this thread's index within its row

    __shared__ float partialSum[MAX_WARP_NUM];

    // Row handled by this thread.
    int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
    // reduceWidth is a multiple of WARP_SIZE, so this flag is uniform within each warp.
    const bool row_active = iter_offset < total_count;

    // 64-bit offsets: row * sequence_length overflows int for large batch * heads * seq^2.
    int batch_idx = iter_offset / (num_seq * heads);
    const long long alibi_offset =
        ((long long)batch_idx * heads * mp_size + head_offset + ((iter_offset / num_seq) % heads)) *
        sequence_length;
    const long long mask_offset =
        ((long long)batch_idx * mask_stride + (iter_offset % mask_stride)) * sequence_length;
    if (row_active) vals += (long long)iter_offset * sequence_length;

    int seq_id = iter_offset % num_seq;
    int seq_id4 = seq_id >> 2;

    int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
    int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
                             ? (real_seq_id >> 2) - (window_size >> 2)
                             : 0;
    int window_stride =
        (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;

    float max_val = minus_infinity;

    // Pass 1: load scaled/biased/masked scores into registers, tracking the thread-local max.
    for (int i = 0; i < iterations; i++) {
        int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
        if (row_active && (!triangular || ((data_id >> 2) <= seq_id4)) &&
            (data_id >> 2) >= window_stride4 && data_id < sequence_length) {
            if ((sequence_length - data_id) >= 4) {
                data[i].x = data_id > window_stride ? vals[data_id] * layer_scale : minus_infinity;
                data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
                             (data_id + 1) > window_stride)
                                ? vals[data_id + 1] * layer_scale
                                : minus_infinity;
                data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) &&
                             (data_id + 2) > window_stride)
                                ? vals[data_id + 2] * layer_scale
                                : minus_infinity;
                data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) &&
                             (data_id + 3) > window_stride)
                                ? vals[data_id + 3] * layer_scale
                                : minus_infinity;
                if (alibi) {
                    data[i].x += alibi[data_id + alibi_offset];
                    data[i].y += alibi[data_id + alibi_offset + 1];
                    data[i].z += alibi[data_id + alibi_offset + 2];
                    data[i].w += alibi[data_id + alibi_offset + 3];
                }
                if (attn_mask) {
                    data[i].x += attn_mask[data_id + mask_offset];
                    data[i].y += attn_mask[data_id + mask_offset + 1];
                    data[i].z += attn_mask[data_id + mask_offset + 2];
                    data[i].w += attn_mask[data_id + mask_offset + 3];
                }
            } else {
                // Row tail: fewer than 4 columns remain; never read past sequence_length.
                data[i].x = data_id > window_stride ? vals[data_id] * layer_scale : minus_infinity;
                data[i].y = (((!triangular || (data_id + 1) <= seq_id)) &&
                             (data_id + 1) > window_stride && (data_id + 1) < sequence_length)
                                ? vals[data_id + 1] * layer_scale
                                : minus_infinity;
                data[i].z = (((!triangular || (data_id + 2) <= seq_id)) &&
                             (data_id + 2) > window_stride && (data_id + 2) < sequence_length)
                                ? vals[data_id + 2] * layer_scale
                                : minus_infinity;
                data[i].w = minus_infinity;
                if (alibi) {
                    data[i].x += alibi[data_id + alibi_offset];
                    if ((data_id + 1) < sequence_length)
                        data[i].y += alibi[data_id + alibi_offset + 1];
                    if ((data_id + 2) < sequence_length)
                        data[i].z += alibi[data_id + alibi_offset + 2];
                }
                if (attn_mask) {
                    data[i].x += attn_mask[data_id + mask_offset];
                    if ((data_id + 1) < sequence_length)
                        data[i].y += attn_mask[data_id + mask_offset + 1];
                    if ((data_id + 2) < sequence_length)
                        data[i].z += attn_mask[data_id + mask_offset + 2];
                }
            }
            max_val = (data[i].x > max_val ? data[i].x : max_val);
            max_val = (data[i].y > max_val ? data[i].y : max_val);
            max_val = (data[i].z > max_val ? data[i].z : max_val);
            max_val = (data[i].w > max_val ? data[i].w : max_val);
        } else {
            data[i].x = minus_infinity;
            data[i].y = minus_infinity;
            data[i].z = minus_infinity;
            data[i].w = minus_infinity;
        }
    }

    // Row max: butterfly within the warp, then across the row's warps through shared memory.
    for (int i = 1; i < WARP_SIZE; i *= 2) {
        auto temp = g.shfl_xor(max_val, i);
        max_val = (temp > max_val ? temp : max_val);
    }

    if (reduceWidth > WARP_SIZE) {
        // Reached by every thread of the block (block-uniform condition).
        if (lane == 0) partialSum[wid] = max_val;
        b.sync();

        if (lane < warp_num) max_val = partialSum[lane];

        b.sync();

        for (int i = 1; i < reduce_blocks; i *= 2) {
            auto temp = g.shfl_xor(max_val, i);
            max_val = (temp > max_val ? temp : max_val);
        }

        max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
    }

    // Pass 2: exponentiate in registers and accumulate the row sum.
    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        data[i].x = __expf(data[i].x - max_val);
        data[i].y = __expf(data[i].y - max_val);
        data[i].z = __expf(data[i].z - max_val);
        data[i].w = __expf(data[i].w - max_val);

        sum += (data[i].x + data[i].y + data[i].z + data[i].w);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);

    if (reduceWidth > WARP_SIZE) {
        if (lane == 0) partialSum[wid] = sum;
        b.sync();

        if (lane < warp_num) sum = partialSum[lane];

        b.sync();

        for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }

        sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
    }
    sum += 1e-6f;  // float literal: avoids a double-precision add

    // Pass 3: normalise and store (active rows only).
    if (row_active) {
        for (int i = 0; i < iterations; i++) {
            int data_id = i * (reduceWidth << 2) + (seq_lane << 2);

            if (data_id < sequence_length) {
                if ((sequence_length - data_id) >= 4) {
                    vals[data_id] = data[i].x / sum;
                    vals[data_id + 1] = data[i].y / sum;
                    vals[data_id + 2] = data[i].z / sum;
                    vals[data_id + 3] = data[i].w / sum;
                } else {
                    vals[data_id] = data[i].x / sum;
                    if ((data_id + 1) < sequence_length) vals[data_id + 1] = data[i].y / sum;
                    if ((data_id + 2) < sequence_length) vals[data_id + 2] = data[i].z / sum;
                }
            }
        }
    }
}

// Enqueues attn_softmax_v2 on `stream` to softmax, in place, batch_size * heads * num_seq rows
// of `sequence_length` scores in `vals` (mask / alibi may be null; see the kernel comments).
//
// Launch geometry: ATTN_THREADS threads per block; each row is reduced by `reduce_width`
// threads (WARP_SIZE times a power of two, capped at one block). Each thread keeps
// `iterations` groups of 4 columns in registers, which must not exceed MAX_REG_SIZE. With the
// current defines that limits sequence_length to ATTN_THREADS * 4 * MAX_REG_SIZE (8192).
//
// Throws std::runtime_error("Unsupport Seq_Length!") when the row does not fit.
template <typename T>
void launch_attn_softmax_v2(T* vals,
                            T* mask,
                            T* alibi,
                            float layer_scale,
                            bool triangular,
                            bool recompute,
                            bool local_attention,
                            int window_size,
                            int batch_size,
                            int heads,
                            int num_seq,
                            int sequence_length,
                            int head_offset,
                            int mask_stride,
                            int mp_size,
                            cudaStream_t stream)
{
    const int total_count = batch_size * heads * num_seq;

    // Threads per row: WARP_SIZE times the largest power of two <= ceil(sequence_length /
    // ATTN_THREADS), capped at one full block. Without the cap, ATTN_THREADS / reduce_width
    // below becomes 0 (host integer division by zero) once sequence_length > 15 * ATTN_THREADS.
    const int chunks = (sequence_length - 1) / ATTN_THREADS + 1;
    int pow2_chunks = 1;
    while (pow2_chunks * 2 <= chunks) pow2_chunks *= 2;
    int reduce_width = pow2_chunks * WARP_SIZE;
    if (reduce_width > ATTN_THREADS) reduce_width = ATTN_THREADS;

    const int rows_per_block = ATTN_THREADS / reduce_width;
    const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1;

    // The kernel's per-thread register arrays hold only MAX_REG_SIZE groups; more iterations
    // would index past them.
    if (iterations > MAX_REG_SIZE) throw std::runtime_error("Unsupport Seq_Length!");

    if (total_count <= 0) return;  // no rows to normalise

    dim3 grid_dim((total_count - 1) / rows_per_block + 1);
    dim3 block_dim(ATTN_THREADS);

    attn_softmax_v2<<<grid_dim, block_dim, 0, stream>>>(vals,
                                                        mask,
                                                        alibi,
                                                        layer_scale,
                                                        triangular,
                                                        recompute,
                                                        local_attention,
                                                        window_size,
                                                        total_count,
                                                        heads,
                                                        sequence_length,
                                                        num_seq,
                                                        head_offset,
                                                        mask_stride,
                                                        mp_size,
                                                        iterations,
                                                        reduce_width);
}

// Explicit instantiations of the launcher for the two element types that have an
// attn_softmax_v2 kernel overload in this file: fp32 and fp16.
template void launch_attn_softmax_v2<float>(float* vals,
                                            float* mask,
                                            float* alibi,
                                            float layer_scale,
                                            bool triangular,
                                            bool recompute,
                                            bool local_attention,
                                            int window_size,
                                            int batch_size,
                                            int heads,
                                            int num_seq,
                                            int sequence_length,
                                            int head_offset,
                                            int mask_stride,
                                            int mp_size,
                                            cudaStream_t stream);

template void launch_attn_softmax_v2<__half>(__half* vals,
                                             __half* mask,
                                             __half* alibi,
                                             float layer_scale,
                                             bool triangular,
                                             bool recompute,
                                             bool local_attention,
                                             int window_size,
                                             int batch_size,
                                             int heads,
                                             int num_seq,
                                             int sequence_length,
                                             int head_offset,
                                             int mask_stride,
                                             int mp_size,
                                             cudaStream_t stream);