"vscode:/vscode.git/clone" did not exist on "cba1cdbc46013dd34ab14e9e04ef5adec7c8d5d3"
cuda_bf16_fallbacks.cuh 7.98 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include <cuda_fp16.h>

namespace fastertransformer {

#ifdef ENABLE_BF16
AllentDan's avatar
AllentDan committed
25
26
// Widens a packed bf16x2 value to float2: the low lane lands in .x, the high lane in .y.
// Device code compiled for __CUDA_ARCH__ < 800 extracts each lane separately;
// every other compilation pass uses the __bfloat1622float2 intrinsic.
inline __device__ float2 bf1622float2(const __nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float2 widened = {__low2float(val), __high2float(val)};
    return widened;
#else
    return __bfloat1622float2(val);
#endif
}

AllentDan's avatar
AllentDan committed
37
38
// Converts both lanes of a bf16x2 value to int8 and packs them into one int16_t.
//
// Each lane is first saturated to [-128, 127], then converted to an integer through
// `short` and narrowed to int8_t.
// Packing: the low lane goes to the least-significant byte and the high lane to the
// most-significant byte. This is the same layout the previous union-based code
// produced on little-endian targets.
//
// Pre-SM80 device code clamps in fp32 using __low2float/__high2float. All other passes
// clamp natively with __hmin2/__hmax2.
inline __device__ int16_t bf1622int16(__nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float lo_f = max(min(__low2float(val), 127.f), -128.f);
    const float hi_f = max(min(__high2float(val), 127.f), -128.f);
    // float -> short truncates toward zero; the value is already inside the int8 range.
    const int8_t lo = static_cast<int8_t>(static_cast<short>(lo_f));
    const int8_t hi = static_cast<int8_t>(static_cast<short>(hi_f));
#else
    // Build the bounds from float literals. 127 and -128 are exactly representable in bf16.
    val = __hmin2(val, __float2bfloat162_rn(127.f));
    val = __hmax2(val, __float2bfloat162_rn(-128.f));
    // Relies on the toolkit's __nv_bfloat16 -> short conversion, as the original did.
    const int8_t lo = static_cast<int8_t>(static_cast<short>(val.x));
    const int8_t hi = static_cast<int8_t>(static_cast<short>(val.y));
#endif
    // Pack the bytes explicitly. Reading an inactive union member (the previous
    // approach) is undefined behaviour in C++.
    const uint16_t packed =
        static_cast<uint16_t>((static_cast<uint16_t>(static_cast<uint8_t>(hi)) << 8) | static_cast<uint8_t>(lo));
    return static_cast<int16_t>(packed);
}

AllentDan's avatar
AllentDan committed
63
64
// Narrows a float2 to a packed bf16x2 with round-to-nearest (.x -> low lane, .y -> high lane).
// Pre-SM80 device code passes the two components separately to __floats2bfloat162_rn.
inline __device__ __nv_bfloat162 float22bf162(const float2 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float low  = val.x;
    const float high = val.y;
    return __floats2bfloat162_rn(low, high);
#else
    return __float22bfloat162_rn(val);
#endif
}

AllentDan's avatar
AllentDan committed
72
73
// Broadcasts one bf16 scalar into both lanes of a bf16x2 value.
// Pre-SM80 device code fills the two members by hand instead of calling
// __bfloat162bfloat162.
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    __nv_bfloat162 splat;
    splat.y = val;
    splat.x = val;
    return splat;
#else
    return __bfloat162bfloat162(val);
#endif
}

AllentDan's avatar
AllentDan committed
84
85
// Lane-wise x + y on bf16x2 values.
// Pre-SM80 device code adds each lane in fp32 and rounds back with __floats2bfloat162_rn;
// otherwise the native __hadd2 is used.
inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float sum_lo = __low2float(x) + __low2float(y);
    const float sum_hi = __high2float(x) + __high2float(y);
    return __floats2bfloat162_rn(sum_lo, sum_hi);
#else
    return __hadd2(x, y);
#endif
}

AllentDan's avatar
AllentDan committed
98
99
// Scalar bf16 x + y.
// Pre-SM80 device code adds in fp32 and rounds with __float2bfloat16;
// otherwise the native __hadd is used.
inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float sum = __bfloat162float(x) + __bfloat162float(y);
    return __float2bfloat16(sum);
#else
    return __hadd(x, y);
#endif
}

AllentDan's avatar
AllentDan committed
107
108
// Lane-wise x - y on bf16x2 values.
// Pre-SM80 device code subtracts each lane in fp32 and rounds back with
// __floats2bfloat162_rn; otherwise the native __hsub2 is used.
inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float diff_lo = __low2float(x) - __low2float(y);
    const float diff_hi = __high2float(x) - __high2float(y);
    return __floats2bfloat162_rn(diff_lo, diff_hi);
#else
    return __hsub2(x, y);
#endif
}

AllentDan's avatar
AllentDan committed
121
122
// Scalar bf16 x - y.
// Pre-SM80 device code subtracts in fp32 and rounds with __float2bfloat16;
// otherwise the native __hsub is used.
inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float diff = __bfloat162float(x) - __bfloat162float(y);
    return __float2bfloat16(diff);
#else
    return __hsub(x, y);
#endif
}

AllentDan's avatar
AllentDan committed
130
131
// Lane-wise x * y on bf16x2 values.
// Pre-SM80 device code multiplies each lane in fp32 and rounds back with
// __floats2bfloat162_rn; otherwise the native __hmul2 is used.
inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float prod_lo = __low2float(x) * __low2float(y);
    const float prod_hi = __high2float(x) * __high2float(y);
    return __floats2bfloat162_rn(prod_lo, prod_hi);
#else
    return __hmul2(x, y);
#endif
}

AllentDan's avatar
AllentDan committed
144
145
// Scalar bf16 x * y.
// Pre-SM80 device code multiplies in fp32 and rounds with __float2bfloat16;
// otherwise the native __hmul is used.
inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float prod = __bfloat162float(x) * __bfloat162float(y);
    return __float2bfloat16(prod);
#else
    return __hmul(x, y);
#endif
}

AllentDan's avatar
AllentDan committed
153
154
// Lane-wise x * y + z on bf16x2 values.
// Pre-SM80 device code evaluates each lane in fp32 and rounds once with
// __floats2bfloat162_rn; otherwise the native __hfma2 is used.
inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float acc_lo = __low2float(x) * __low2float(y) + __low2float(z);
    const float acc_hi = __high2float(x) * __high2float(y) + __high2float(z);
    return __floats2bfloat162_rn(acc_lo, acc_hi);
#else
    return __hfma2(x, y, z);
#endif
}

AllentDan's avatar
AllentDan committed
169
170
// Scalar bf16 x * y + z.
// Pre-SM80 device code evaluates in fp32 and rounds once with __float2bfloat16;
// otherwise the native __hfma is used.
inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float acc = __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z);
    return __float2bfloat16(acc);
#else
    return __hfma(x, y, z);
#endif
}

AllentDan's avatar
AllentDan committed
178
179
// Lane-wise natural exponential e^x on a bf16x2 value.
// Pre-SM80 device code applies expf to each lane in fp32 and rounds back with
// __floats2bfloat162_rn; otherwise h2exp is used.
inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float exp_lo = expf(__low2float(x));
    const float exp_hi = expf(__high2float(x));
    return __floats2bfloat162_rn(exp_lo, exp_hi);
#else
    return h2exp(x);
#endif
}

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
AllentDan's avatar
AllentDan committed
192
193
194
195
196
197
198
199
// Lane-wise bf16x2 product for device code compiled for __CUDA_ARCH__ < 800 (enclosing #if).
// Forwards to the bf16hmul2 fallback in this file.
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
    const __nv_bfloat162 product = bf16hmul2(x, y);
    return product;
}
// Lane-wise bf16x2 sum for device code compiled for __CUDA_ARCH__ < 800 (enclosing #if).
// Forwards to the bf16hadd2 fallback in this file.
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
    const __nv_bfloat162 sum = bf16hadd2(x, y);
    return sum;
}
Li Zhang's avatar
Li Zhang committed
200
201
202

// Builds a bf16x2 from two scalars (x -> low lane .x, y -> high lane .y).
// Only compiled for device code with __CUDA_ARCH__ < 800 (enclosing #if).
// NOTE(review): presumably provided because the targeted CUDA toolkit does not declare
// make_bfloat162 for pre-SM80 device code. If a newer toolkit does, this may become a
// redefinition; confirm against the CUDA version in use.
inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
    __nv_bfloat162 pair;
    pair.x = x;
    pair.y = y;
    return pair;
}

#endif

AllentDan's avatar
AllentDan committed
211
212
// Three-operand scalar bf16 sum, (a + b) + c.
// Pre-SM80 device code accumulates in fp32 and rounds once with __float2bfloat16.
// The other path uses __nv_bfloat16 operator+, so the intermediate a + b is itself a
// __nv_bfloat16 value.
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float total = __bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c);
    return __float2bfloat16(total);
#else
    return a + b + c;
#endif
}

AllentDan's avatar
AllentDan committed
220
221
// Four-operand scalar bf16 sum, ((a + b) + c) + d.
// Every operand is widened to fp32, the sum is accumulated in fp32, and the result is
// rounded to bf16 once. Both former architecture branches computed exactly this: the
// pre-SM80 branch with explicit __bfloat162float/__float2bfloat16 calls, the other
// branch through the equivalent (float)/(__nv_bfloat16) casts.
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
{
    const float total = __bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d);
    return __float2bfloat16(total);
}

AllentDan's avatar
AllentDan committed
229
230
// Three-operand lane-wise bf16x2 sum, (a + b) + c.
// Pre-SM80 device code sums each lane in fp32 and rounds once with __floats2bfloat162_rn;
// otherwise the toolkit's __nv_bfloat162 operator+ is used.
inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float total_lo = __low2float(a) + __low2float(b) + __low2float(c);
    const float total_hi = __high2float(a) + __high2float(b) + __high2float(c);
    return __floats2bfloat162_rn(total_lo, total_hi);
#else
    return a + b + c;
#endif
}

AllentDan's avatar
AllentDan committed
245
246
// Three-operand scalar bf16 product, (a * b) * c.
// Pre-SM80 device code multiplies in fp32 and rounds once with __float2bfloat16;
// otherwise __nv_bfloat16 operator* is used.
inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float product = __bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c);
    return __float2bfloat16(product);
#else
    return a * b * c;
#endif
}

AllentDan's avatar
AllentDan committed
254
255
// Three-operand lane-wise bf16x2 product, (a * b) * c.
// Pre-SM80 device code multiplies each lane in fp32 and rounds once with
// __floats2bfloat162_rn; otherwise __nv_bfloat162 operator* is used.
inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float product_lo = __low2float(a) * __low2float(b) * __low2float(c);
    const float product_hi = __high2float(a) * __high2float(b) * __high2float(c);
    return __floats2bfloat162_rn(product_lo, product_hi);
#else
    return a * b * c;
#endif
}

AllentDan's avatar
AllentDan committed
270
271
// Lane-wise bf16x2 a * b * c + d.
// Pre-SM80 device code evaluates each lane in fp32 and rounds once with
// __floats2bfloat162_rn; otherwise __nv_bfloat162 operator* / operator+ are used.
inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    const float acc_lo = __low2float(a) * __low2float(b) * __low2float(c) + __low2float(d);
    const float acc_hi = __high2float(a) * __high2float(b) * __high2float(c) + __high2float(d);
    return __floats2bfloat162_rn(acc_lo, acc_hi);
#else
    return a * b * c + d;
#endif
}

AllentDan's avatar
AllentDan committed
288
#endif  // ENABLE_BF16
Li Zhang's avatar
Li Zhang committed
289

AllentDan's avatar
AllentDan committed
290
}  // namespace fastertransformer