utils.cuh 11.2 KB
Newer Older
Muyang Li's avatar
Muyang Li committed
1
2
// Adapted from FasterTransformer,
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
Zhekai Zhang's avatar
Zhekai Zhang committed
3
4
#pragma once

5
6
7
#include <cassert>
#include <cstdint>
#include <cfloat>
Zhekai Zhang's avatar
Zhekai Zhang committed
8
9
#include <type_traits>

10
11
#include <cstdio>

Zhekai Zhang's avatar
Zhekai Zhang committed
12
13
#include <cuda_fp16.h>

14
15
16
17
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif

Muyang Li's avatar
Muyang Li committed
18
// Reports that the running kernel is not supported on this GPU and then aborts
// it with __trap(). Never returns normally.
//
// The message is printed by exactly one thread of the launch: thread (0,0,0)
// of block (0,0,0). All three dimensions are checked so that 2D/3D blocks and
// 3D grids do not emit the message once per row/plane.
__device__ __forceinline__ static void trap_unsupported_arch() {
    const bool first_block  = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0;
    const bool first_thread = threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0;
    if (first_block && first_thread) {
        printf("This kernel is not supported on your GPU\n");
    }
    __syncthreads();
    // NOTE(review): the sleep presumably gives the printing thread time to emit
    // the message before any thread traps - confirm.
    __nanosleep(1000000);
    __trap();
}

#if defined(ENABLE_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// Placeholder overload of the packed-bfloat16 fused multiply-add, compiled only
// for device architectures below sm_80. It performs no arithmetic: the
// arguments are ignored and the kernel is aborted via trap_unsupported_arch().
__device__ __forceinline__ static __nv_bfloat162
__hfma2(const __nv_bfloat162 a, const __nv_bfloat162 b, const __nv_bfloat162 c) {
    trap_unsupported_arch();
    // Not reached at run time (the call above traps); only satisfies the
    // function's return type.
    const __nv_bfloat162 zero(0.0f, 0.0f);
    return zero;
}
#endif

Muyang Li's avatar
Muyang Li committed
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
// num_elems<T>::value is the number of scalar lanes held by T:
// 1 for scalar types, 2 or 4 for the packed vector types listed below.
// The primary template is intentionally left undefined so that using an
// unlisted type is a compile-time error.
template<typename T>
struct num_elems;

// --- float family ---
template<>
struct num_elems<float> {
    static constexpr int value = 1;
};
template<>
struct num_elems<float2> {
    static constexpr int value = 2;
};
template<>
struct num_elems<float4> {
    static constexpr int value = 4;
};

// --- half family ---
template<>
struct num_elems<half> {
    static constexpr int value = 1;
};
template<>
struct num_elems<half2> {
    static constexpr int value = 2;
};

#ifdef ENABLE_BF16
// --- bfloat16 family ---
template<>
struct num_elems<__nv_bfloat16> {
    static constexpr int value = 1;
};
template<>
struct num_elems<__nv_bfloat162> {
    static constexpr int value = 2;
};
#endif

#ifdef ENABLE_FP8
// --- fp8 family (both encodings, mirroring the packed_as specializations) ---
template<>
struct num_elems<__nv_fp8_e4m3> {
    static constexpr int value = 1;
};
template<>
struct num_elems<__nv_fp8x2_e4m3> {
    static constexpr int value = 2;
};
template<>
struct num_elems<__nv_fp8_e5m2> {
    static constexpr int value = 1;
};
template<>
struct num_elems<__nv_fp8x2_e5m2> {
    static constexpr int value = 2;
};
#endif

Muyang Li's avatar
Muyang Li committed
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
// packed_as<T, num>::type maps an element type T and a lane count to the
// matching vector type. With num == 2 it yields the two-lane packed type; with
// num == 1 it yields T itself, except that two-lane types are explicitly
// mapped back to their scalar element type.
// The primary template is undefined: unsupported combinations fail to compile.
template<typename T, int num>
struct packed_as;

// Generic single-lane case: a type packed once is itself.
template<typename T>
struct packed_as<T, 1> {
    using type = T;
};

// half <-> half2
template<>
struct packed_as<half, 2> {
    using type = half2;
};
template<>
struct packed_as<half2, 1> {
    using type = half;
};

// float <-> float2
template<>
struct packed_as<float, 2> {
    using type = float2;
};
template<>
struct packed_as<float2, 1> {
    using type = float;
};

// Integer pairs: two int8 lanes travel in one int16, two int32 lanes in an int2.
template<>
struct packed_as<int8_t, 2> {
    using type = int16_t;
};
template<>
struct packed_as<int32_t, 2> {
    using type = int2;
};

#ifdef ENABLE_BF16
// __nv_bfloat16 <-> __nv_bfloat162
template<>
struct packed_as<__nv_bfloat16, 2> {
    using type = __nv_bfloat162;
};
template<>
struct packed_as<__nv_bfloat162, 1> {
    using type = __nv_bfloat16;
};
#endif

#ifdef ENABLE_FP8
// fp8 e4m3 <-> fp8x2 e4m3
template<>
struct packed_as<__nv_fp8_e4m3, 2> {
    using type = __nv_fp8x2_e4m3;
};
template<>
struct packed_as<__nv_fp8x2_e4m3, 1> {
    using type = __nv_fp8_e4m3;
};

// fp8 e5m2 <-> fp8x2 e5m2
template<>
struct packed_as<__nv_fp8_e5m2, 2> {
    using type = __nv_fp8x2_e5m2;
};
template<>
struct packed_as<__nv_fp8x2_e5m2, 1> {
    using type = __nv_fp8_e5m2;
};
#endif

Muyang Li's avatar
Muyang Li committed
137
138
139
140
141
142
143
144
145
// Lane-wise arithmetic between two float2 values: each operator applies the
// scalar operation independently to the .x and .y components.

// Lane-wise product.
inline __device__ float2 operator*(float2 a, float2 b) {
    float2 out;
    out.x = a.x * b.x;
    out.y = a.y * b.y;
    return out;
}

// Lane-wise sum.
inline __device__ float2 operator+(float2 a, float2 b) {
    float2 out;
    out.x = a.x + b.x;
    out.y = a.y + b.y;
    return out;
}

// Lane-wise difference (a - b).
inline __device__ float2 operator-(float2 a, float2 b) {
    float2 out;
    out.x = a.x - b.x;
    out.y = a.y - b.y;
    return out;
}
Zhekai Zhang's avatar
Zhekai Zhang committed
146

Muyang Li's avatar
Muyang Li committed
147
148
149
150
151
152
153
154
155
// Arithmetic between a float2 and a scalar: the scalar b is applied to both
// the .x and .y components of a.

// Scales both lanes by b.
inline __device__ float2 operator*(float2 a, float b) {
    float2 out;
    out.x = a.x * b;
    out.y = a.y * b;
    return out;
}

// Adds b to both lanes.
inline __device__ float2 operator+(float2 a, float b) {
    float2 out;
    out.x = a.x + b;
    out.y = a.y + b;
    return out;
}

// Subtracts b from both lanes.
inline __device__ float2 operator-(float2 a, float b) {
    float2 out;
    out.x = a.x - b;
    out.y = a.y - b;
    return out;
}
Zhekai Zhang's avatar
Zhekai Zhang committed
156

Muyang Li's avatar
Muyang Li committed
157
// Converts a float to int8 with the PTX instruction cvt.rni.sat.s8.f32, i.e.
// rounding to the nearest integer and saturating to the signed 8-bit range.
// The instruction writes a 32-bit register; the low byte carries the result.
static inline __device__ int8_t float_to_int8_rn(float x) {
    uint32_t raw;
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(raw) : "f"(x));
    // Keep only the least-significant byte of the destination register.
    return static_cast<int8_t>(raw);
}

// Thin wrapper around the __ldg() intrinsic: loads and returns the value that
// val points to. Type-specific specializations may override this.
template<typename T>
inline __device__ T ldg(const T *val) {
    const T loaded = __ldg(val);
    return loaded;
}

// Guarded with #ifdef (not #if) to match every other ENABLE_BF16 check in this
// file; "#if ENABLE_BF16" is ill-formed when the macro is defined without a value.
#ifdef ENABLE_BF16
// Short aliases for the bfloat16 conversion intrinsics.
#define bf1622float2 __bfloat1622float2
#define float22bf162 __float22bfloat162_rn
#define bf162bf162 __bfloat162bfloat162

// Converts the two bfloat16 lanes of val to int8 and packs them into one
// int16_t: the low lane goes into byte 0, the high lane into byte 1.
// Each lane is first clamped to [-128, 127] and then converted through short,
// which drops the fractional part.
// Below sm_80 the clamp is done in float; otherwise the packed bfloat16
// min/max intrinsics are used.
inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
    union {
        int8_t int8[2];
        int16_t int16;
    };

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float2 f_val;
    f_val.x = max(min(__low2float(val), 127.f), -128.f);
    f_val.y = max(min(__high2float(val), 127.f), -128.f);

    int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
    int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
#else
    // Both bounds are built from float literals; the previous double literals
    // (127., -128.) forced a double -> bfloat16 conversion.
    val = __hmin2(val, __float2bfloat162_rn(127.f));
    val = __hmax2(val, __float2bfloat162_rn(-128.f));

    int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
    int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
#endif
    return int16;
}
#endif

// Guarded with #ifdef (not #if) to match the other ENABLE_BF16 checks in this
// file; "#if ENABLE_BF16" is ill-formed when the macro is defined without a value.
#ifdef ENABLE_BF16
// ldg() specialization for packed bfloat16 pairs.
// Below sm_80 this is a plain dereference; otherwise it uses __ldg().
template<>
inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162 *val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return val[0];
#else
    return __ldg(val);
#endif
}

// ldg() specialization for a single bfloat16 value; same architecture split
// as the packed variant above.
template<>
inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16 *val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return val[0];
#else
    return __ldg(val);
#endif
}
#endif // ENABLE_BF16

Muyang Li's avatar
Muyang Li committed
222
223
// Generic value conversion. The primary template relies on the implicit
// conversion from T_IN to T_OUT; the specializations below handle the vector
// types, which have no implicit conversions between them.
template<typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val) {
    return val;
}

// int2 -> float2: converts each integer lane to float.
template<>
__device__ inline float2 cuda_cast<float2, int2>(int2 val) {
    const float x = static_cast<float>(val.x);
    const float y = static_cast<float>(val.y);
    return make_float2(x, y);
}

// float -> float2: broadcasts the scalar into both lanes.
template<>
__device__ inline float2 cuda_cast<float2, float>(float val) {
    float2 out;
    out.x = val;
    out.y = val;
    return out;
}

// half2 -> float2: widens both lanes.
template<>
__device__ inline float2 cuda_cast<float2, half2>(half2 val) {
    const float2 widened = __half22float2(val);
    return widened;
}

// float2 -> half2: narrows both lanes with round-to-nearest.
template<>
__device__ inline half2 cuda_cast<half2, float2>(float2 val) {
    const half2 narrowed = __float22half2_rn(val);
    return narrowed;
}

// float -> half2: narrows the scalar and broadcasts it into both lanes.
template<>
__device__ inline half2 cuda_cast<half2, float>(float val) {
    const half2 narrowed = __float2half2_rn(val);
    return narrowed;
}

// half -> half2: broadcasts the scalar into both lanes.
template<>
__device__ inline half2 cuda_cast<half2, half>(half val) {
    const half2 duplicated = __half2half2(val);
    return duplicated;
}

Muyang Li's avatar
Muyang Li committed
257
258
259
// half -> int8 using the PTX instruction cvt.rni.sat.s8.f16 (round to nearest
// integer, saturate to the signed 8-bit range).
// The instruction works on 16-bit registers, so the half is passed as its raw
// bit pattern and the low byte of the 16-bit result is returned.
template<>
__device__ inline int8_t cuda_cast<int8_t, half>(half val) {
    const int16_t src_bits = __half_as_short(val);
    int16_t dst_bits;
    asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(dst_bits) : "h"(src_bits));
    return static_cast<int8_t>(dst_bits);
}

Muyang Li's avatar
Muyang Li committed
274
275
276
// half2 -> int16_t holding two int8 lanes: val.x is converted into the low
// byte and val.y into the high byte, each via cuda_cast<int8_t>.
template<>
__device__ inline int16_t cuda_cast<int16_t, half2>(half2 val) {
    const int8_t lo = cuda_cast<int8_t>(val.x);
    const int8_t hi = cuda_cast<int8_t>(val.y);
    // Assemble the two bytes explicitly (byte 0 = lo, byte 1 = hi).
    const uint16_t lo_bits = static_cast<uint8_t>(lo);
    const uint16_t hi_bits = static_cast<uint8_t>(hi);
    return static_cast<int16_t>(static_cast<uint16_t>(lo_bits | (hi_bits << 8)));
}

Muyang Li's avatar
Muyang Li committed
286
287
288
// float -> int8 using the PTX instruction cvt.rni.sat.s8.f32 (round to nearest
// integer, saturate to the signed 8-bit range).
// The result is produced in a 16-bit register; its low byte is returned.
template<>
__device__ inline int8_t cuda_cast<int8_t, float>(float val) {
    int16_t converted;
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(converted) : "f"(val));
    return static_cast<int8_t>(converted);
}

Muyang Li's avatar
Muyang Li committed
297
298
299
// float2 -> int16_t holding two int8 lanes: val.x is converted into the low
// byte and val.y into the high byte, each via cuda_cast<int8_t>.
template<>
__device__ inline int16_t cuda_cast<int16_t, float2>(float2 val) {
    const int8_t lo = cuda_cast<int8_t>(val.x);
    const int8_t hi = cuda_cast<int8_t>(val.y);
    // Assemble the two bytes explicitly (byte 0 = lo, byte 1 = hi).
    const uint16_t lo_bits = static_cast<uint8_t>(lo);
    const uint16_t hi_bits = static_cast<uint8_t>(hi);
    return static_cast<int16_t>(static_cast<uint16_t>(lo_bits | (hi_bits << 8)));
}

Muyang Li's avatar
Muyang Li committed
309
310
311
// int16_t (two packed int8 lanes) -> half2: the low byte becomes the first
// lane and the high byte the second, each interpreted as a signed 8-bit value.
template<>
__device__ inline half2 cuda_cast<half2, int16_t>(int16_t val) {
    const int8_t lo = static_cast<int8_t>(val & 0xFF);
    const int8_t hi = static_cast<int8_t>((val >> 8) & 0xFF);
    return make_half2(lo, hi);
}

Muyang Li's avatar
Muyang Li committed
320
321
322
// int16_t (two packed int8 lanes) -> float2: the low byte becomes .x and the
// high byte .y, each interpreted as a signed 8-bit value.
template<>
__device__ inline float2 cuda_cast<float2, int16_t>(int16_t val) {
    const int8_t lo = static_cast<int8_t>(val & 0xFF);
    const int8_t hi = static_cast<int8_t>((val >> 8) & 0xFF);
    return make_float2(lo, hi);
}

#ifdef ENABLE_BF16
Muyang Li's avatar
Muyang Li committed
332
333
// int32 -> bfloat16, going through float.
template<>
__device__ inline __nv_bfloat16 cuda_cast(int32_t val) {
    const float as_float = static_cast<float>(val);
    return __float2bfloat16(as_float);
}

// int8 -> bfloat16, going through float.
template<>
__device__ inline __nv_bfloat16 cuda_cast(int8_t val) {
    const float as_float = static_cast<float>(val);
    return __float2bfloat16(as_float);
}

Muyang Li's avatar
Muyang Li committed
342
343
// bfloat16 -> int8, going through float.
// The value is clamped to [-128, 127] before the conversion: converting an
// out-of-range float to int8 is undefined and previously wrapped instead of
// saturating. In-range values convert exactly as before (fraction dropped),
// which also matches the clamp-then-convert scheme used by bf1622int16.
// A NaN input yields 127, because fminf returns its non-NaN operand.
template<>
__device__ inline int8_t cuda_cast(__nv_bfloat16 val) {
    const float clamped = fmaxf(fminf(__bfloat162float(val), 127.f), -128.f);
    return static_cast<int8_t>(clamped);
}

Muyang Li's avatar
Muyang Li committed
347
348
// bfloat16 -> float.
template<>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
    const float widened = __bfloat162float(val);
    return widened;
}

// bfloat162 -> float2: widens both lanes.
template<>
__device__ inline float2 cuda_cast<float2, __nv_bfloat162>(__nv_bfloat162 val) {
    const float2 widened = bf1622float2(val);
    return widened;
}

// bfloat16 -> half, converting through float.
template<>
__device__ inline half cuda_cast<half, __nv_bfloat16>(__nv_bfloat16 val) {
    const float via_float = __bfloat162float(val);
    return __float2half(via_float);
}

// bfloat162 -> int16_t holding two int8 lanes; delegates to bf1622int16.
template<>
__device__ inline int16_t cuda_cast<int16_t, __nv_bfloat162>(__nv_bfloat162 val) {
    const int16_t packed = bf1622int16(val);
    return packed;
}

Muyang Li's avatar
Muyang Li committed
367
368
// float -> bfloat16.
template<>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) {
    const __nv_bfloat16 narrowed = __float2bfloat16(val);
    return narrowed;
}

// half -> bfloat16, converting through float.
template<>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val) {
    const float via_float = __half2float(val);
    return __float2bfloat16(via_float);
}

// bfloat16 -> bfloat162: broadcasts the scalar into both lanes.
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val) {
    const __nv_bfloat162 duplicated = bf162bf162(val);
    return duplicated;
}

// float -> bfloat162: narrows the scalar and broadcasts it into both lanes.
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val) {
    const __nv_bfloat162 duplicated = __float2bfloat162_rn(val);
    return duplicated;
}

// float2 -> bfloat162: narrows both lanes with round-to-nearest.
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val) {
    const __nv_bfloat162 narrowed = float22bf162(val);
    return narrowed;
}

Muyang Li's avatar
Muyang Li committed
392
393
394
// int16_t (two packed int8 lanes) -> bfloat162: the low byte becomes .x and
// the high byte .y, each converted via cuda_cast<__nv_bfloat16>.
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val) {
    const int8_t lo = static_cast<int8_t>(val & 0xFF);
    const int8_t hi = static_cast<int8_t>((val >> 8) & 0xFF);

    __nv_bfloat162 res;
    res.x = cuda_cast<__nv_bfloat16>(lo);
    res.y = cuda_cast<__nv_bfloat16>(hi);
    return res;
}

// half2 -> bfloat162, converting both lanes through float.
template<>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val) {
    const float2 via_float = __half22float2(val);
    return float22bf162(via_float);
}

#endif // ENABLE BF16

Muyang Li's avatar
Muyang Li committed
413
414
// Broadcasts a 16-bit floating-point scalar into both lanes of its packed
// two-lane type (packed_as<f16_t, 2>::type). Only the specializations below
// are defined.
// "typename" is required in front of the dependent return type before C++20;
// adding it keeps the declaration valid in every language mode.
template<typename f16_t>
__device__ __forceinline__ typename packed_as<f16_t, 2>::type f162f162(f16_t x);

// half -> half2 broadcast.
template<>
__device__ __forceinline__ packed_as<half, 2>::type f162f162<half>(half x) {
    return __half2half2(x);
}

#ifdef ENABLE_BF16
// bfloat16 -> bfloat162 broadcast.
template<>
__device__ __forceinline__ packed_as<__nv_bfloat16, 2>::type f162f162<__nv_bfloat16>(__nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
}
#endif

// Horizontal sum of a value's lanes, returned as To.
// Generic case: the input is passed straight through cuda_cast<To>.
template<typename To, typename Ti>
__device__ inline To cuda_sum(Ti val) {
    return cuda_cast<To>(val);
}

// float2 overload: adds the two lanes, then converts the total to To.
template<typename To>
__device__ inline To cuda_sum(float2 val) {
    const float total = val.x + val.y;
    return cuda_cast<To>(total);
}

// Unary maximum: compute the max of a vector type
Muyang Li's avatar
Muyang Li committed
439
440
// Generic case: the input is passed straight through cuda_cast<To>.
template<typename To, typename Ti>
__device__ inline To cuda_max(Ti val) {
    return cuda_cast<To>(val);
}

// float2 -> float: larger of the two lanes.
template<>
__device__ inline float cuda_max(float2 val) {
    const float larger = fmaxf(val.x, val.y);
    return larger;
}

// half2 -> half: larger of the two lanes.
template<>
__device__ inline half cuda_max(half2 val) {
    const half larger = __hmax(val.x, val.y);
    return larger;
}

#ifdef ENABLE_BF16
// bfloat162 -> bfloat16: larger of the two lanes.
// On sm_80 and newer the native bfloat16 __hmax is used. Elsewhere the lanes
// are compared in float and the winner is converted back; this is exact
// because the result is always one of the two bfloat16 inputs. The previous
// fallback was assert(false) followed by "return 0", which silently produced
// 0 whenever assertions were compiled out.
template<>
__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
    return __hmax(val.x, val.y);
#else
    return __float2bfloat16(fmaxf(__low2float(val), __high2float(val)));
#endif
}
#endif

// Binary maximum: compute the max of two scalar types
Muyang Li's avatar
Muyang Li committed
467
468
// Returns val1 when it compares strictly greater than val2, otherwise val2
// (so val2 is returned when the two compare equal or are unordered).
template<typename T>
__device__ inline T cuda_max(T val1, T val2) {
    if (val1 > val2) {
        return val1;
    }
    return val2;
}

Muyang Li's avatar
Muyang Li committed
472
473
// Absolute value. This primary template is the fallback for types that have no
// specialization: it fails an assertion and returns a value-initialized T
// (which is all a caller gets when assertions are compiled out).
template<typename T>
__device__ inline T cuda_abs(T val) {
    assert(false);
    return {};
}

Muyang Li's avatar
Muyang Li committed
478
479
// float: absolute value.
template<>
__device__ inline float cuda_abs(float val) {
    const float magnitude = fabsf(val);
    return magnitude;
}

// float2: absolute value of each lane.
template<>
__device__ inline float2 cuda_abs(float2 val) {
    float2 out;
    out.x = fabsf(val.x);
    out.y = fabsf(val.y);
    return out;
}

// half: absolute value.
template<>
__device__ inline half cuda_abs(half val) {
    const half magnitude = __habs(val);
    return magnitude;
}

// half2: absolute value of each lane.
template<>
__device__ inline half2 cuda_abs(half2 val) {
    const half2 magnitude = __habs2(val);
    return magnitude;
}

#ifdef ENABLE_BF16

// bfloat16 absolute value.
// The native __habs path is used for the host pass and for sm_80 and newer,
// exactly as before. Below sm_80 these specializations used to be compiled
// out, so calls fell into the asserting primary template; now the value is
// computed in float and converted back, which is exact because |x| of a
// bfloat16 is itself representable as a bfloat16.
template<>
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
    return __habs(val);
#else
    return __float2bfloat16(fabsf(__bfloat162float(val)));
#endif
}

// bfloat162 absolute value of each lane; same architecture split as above.
template<>
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
    return __habs2(val);
#else
    return __floats2bfloat162_rn(fabsf(__low2float(val)), fabsf(__high2float(val)));
#endif
}

#endif // ENABLE_BF16