gemm_utils.cuh 17.4 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
#pragma once

#include <cstdint>
#include "common.h"
muyangli's avatar
muyangli committed
5
6
7
#include "../utils.cuh"

namespace nunchaku::kernels {
Zhekai Zhang's avatar
Zhekai Zhang committed
8
9

// Restricts `val` to the closed interval [min, max].
// The lower bound is tested first, so when min > max a value below `min`
// still yields `min`.
static constexpr int clamp(int val, int min, int max) {
    return (val < min) ? min : ((val > max) ? max : val);
}

// Loads one T from `addr`.
//
//   shmem == true : 8- and 16-byte T are read with a single vectorized
//                   `ld.shared` instruction; every other size is a plain
//                   dereference.
//   shmem == false: 8- and 16-byte T are read through `__ldg` as uint2/uint4;
//                   every other size is a plain dereference.
//
// The vectorized paths reinterpret `addr` as a uint2/uint4 pointer, so they
// presumably require `addr` to be aligned to sizeof(T) -- TODO confirm at call sites.
// NOTE(review): the shared-memory paths hand the pointer to `ld.shared` as-is
// ("l" constraint, no explicit generic-to-shared address conversion) -- confirm
// this is what the target toolchain expects.
template<bool shmem = false, typename T>
__device__ __forceinline__ static T load(const T *addr) {
    if constexpr (shmem) {
        if constexpr (sizeof(T) == 8) {
            uint2 raw;
            asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];"
                         : "=r"(raw.x), "=r"(raw.y)
                         : "l"((addr)));
            return *reinterpret_cast<T *>(&raw);
        } else if constexpr (sizeof(T) == 16) {
            uint4 raw;
            asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];"
                         : "=r"(raw.x), "=r"(raw.y), "=r"(raw.z), "=r"(raw.w)
                         : "l"((addr)));
            return *reinterpret_cast<T *>(&raw);
        } else {
            return *addr;
        }
    } else if constexpr (sizeof(T) == 8) {
        uint2 raw = __ldg(reinterpret_cast<const uint2 *>(addr));
        return *reinterpret_cast<T *>(&raw);
    } else if constexpr (sizeof(T) == 16) {
        uint4 raw = __ldg(reinterpret_cast<const uint4 *>(addr));
        return *reinterpret_cast<T *>(&raw);
    } else {
        return *addr;
    }
}

fengzch's avatar
fengzch committed
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
// template<typename T>
// __device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
//     if constexpr (sizeof(T) == 4) {
//         uint32_t data;
//         // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
//         //              "@loadpred ld.global.nc.b32 %0, [%1];"
//         //              "}"
//         //              : "=r"(data)
//         //              : "l"(addr), "r"((int)pred));
//         return *reinterpret_cast<T *>(&data);
//     }
//     if constexpr (sizeof(T) == 8) {
//         uint2 data;
//         // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
//         //              "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
//         //              "}"
//         //              : "=r"(data.x), "=r"(data.y)
//         //              : "l"(addr), "r"((int)pred));
//         return *reinterpret_cast<T *>(&data);
//     }
//     if constexpr (sizeof(T) == 16) {
//         uint4 data;
//         // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
//         //              "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
//         //              "}"
//         //              : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
//         //              : "l"(addr), "r"((int)pred));
//         return *reinterpret_cast<T *>(&data);
//     }

//     T result;
//     if (pred) {
//         result = *addr;
//     }
//     return result;
// }

86
// Predicated load: returns `*addr` when `pred` is true and never touches
// `addr` when `pred` is false.
//
// For 4-, 8- and 16-byte T the value is copied byte by byte into a local
// uint32_t / uint2 / uint4 and reinterpreted as T; other sizes use a plain
// guarded dereference.
//
// When `pred` is false the result is all-zero bits (a value-initialized T for
// the generic path). Previously the local was left uninitialized, so callers
// received an indeterminate value.
template<typename T>
__device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
    if constexpr (sizeof(T) == 4) {
        uint32_t data = 0;
        if (pred) {
            const unsigned char *src = reinterpret_cast<const unsigned char *>(addr);
            unsigned char *dst       = reinterpret_cast<unsigned char *>(&data);
#pragma unroll
            for (int i = 0; i < 4; ++i)
                dst[i] = src[i];
        }
        return *reinterpret_cast<T *>(&data);
    }
    if constexpr (sizeof(T) == 8) {
        uint2 data = make_uint2(0, 0);
        if (pred) {
            const unsigned char *src = reinterpret_cast<const unsigned char *>(addr);
            unsigned char *dst       = reinterpret_cast<unsigned char *>(&data);
#pragma unroll
            for (int i = 0; i < 8; ++i)
                dst[i] = src[i];
        }
        return *reinterpret_cast<T *>(&data);
    }
    if constexpr (sizeof(T) == 16) {
        uint4 data = make_uint4(0, 0, 0, 0);
        if (pred) {
            const unsigned char *src = reinterpret_cast<const unsigned char *>(addr);
            unsigned char *dst       = reinterpret_cast<unsigned char *>(&data);
#pragma unroll
            for (int i = 0; i < 16; ++i)
                dst[i] = src[i];
        }
        return *reinterpret_cast<T *>(&data);
    }

    // Generic fallback for any other size.
    T result{};
    if (pred) {
        result = *addr;
    }
    return result;
}

Zhekai Zhang's avatar
Zhekai Zhang committed
126
// Stores `val` to `addr`.
//
//   shmem == true : 8- and 16-byte T are written with a single vectorized
//                   `st.shared` instruction; other sizes use a plain assignment.
//   shmem == false: 4-, 8- and 16-byte T are written as one unsigned int /
//                   uint2 / uint4 assignment; other sizes use a plain assignment.
//
// The vector paths reinterpret `addr`, so they presumably require sizeof(T)
// alignment -- TODO confirm at call sites.
// NOTE(review): the shared-memory paths hand the pointer to `st.shared` as-is
// ("l" constraint, no explicit generic-to-shared address conversion) -- confirm
// this is what the target toolchain expects.
template<bool shmem = false, typename T>
__device__ __forceinline__ static void store(T *addr, T val) {
    if constexpr (shmem) {
        if constexpr (sizeof(T) == 8) {
            uint2 raw = *reinterpret_cast<uint2 *>(&val);
            asm volatile(
                "st.shared.v2.b32 [%0], {%1, %2};" ::"l"((addr)), "r"(raw.x), "r"(raw.y));
        } else if constexpr (sizeof(T) == 16) {
            uint4 raw = *reinterpret_cast<uint4 *>(&val);
            asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"((addr)),
                         "r"(raw.x),
                         "r"(raw.y),
                         "r"(raw.z),
                         "r"(raw.w));
        } else {
            *addr = val;
        }
    } else if constexpr (sizeof(T) == 4) {
        // The cache-hinted __stcg variant is disabled; a plain store is used instead.
        *reinterpret_cast<unsigned int *>(addr) = *reinterpret_cast<unsigned int *>(&val);
    } else if constexpr (sizeof(T) == 8) {
        *reinterpret_cast<uint2 *>(addr) = *reinterpret_cast<uint2 *>(&val);
    } else if constexpr (sizeof(T) == 16) {
        *reinterpret_cast<uint4 *>(addr) = *reinterpret_cast<uint4 *>(&val);
    } else {
        *addr = val;
    }
}

166
// Predicated store: writes `val` to `addr` only when `pred` is true.
//
// 4-, 8- and 16-byte T are written with one predicated `st.global.cg`
// instruction (scalar, v2 or v4); the predicate register is derived from
// `pred != 0`. Any other size falls back to a guarded plain assignment.
// The PTX paths address the `.global` state space, so `addr` presumably must
// point to global memory -- TODO confirm at call sites.
template<typename T>
__device__ __forceinline__ static void store_pred(T *addr, T val, bool pred) {
    if constexpr (sizeof(T) == 4) {
        uint32_t raw = *reinterpret_cast<uint32_t *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.b32 [%1], %2;"
                     "}" ::"r"((int)pred),
                     "l"(addr),
                     "r"(raw));
    } else if constexpr (sizeof(T) == 8) {
        uint2 raw = *reinterpret_cast<uint2 *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
                     "}" ::"r"((int)pred),
                     "l"(addr),
                     "r"(raw.x),
                     "r"(raw.y));
    } else if constexpr (sizeof(T) == 16) {
        uint4 raw = *reinterpret_cast<uint4 *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
                     "}" ::"r"((int)pred),
                     "l"(addr),
                     "r"(raw.x),
                     "r"(raw.y),
                     "r"(raw.z),
                     "r"(raw.w));
    } else {
        if (pred) {
            *addr = val;
        }
    }
}

Muyang Li's avatar
Muyang Li committed
205
// Widens a packed pair of half values to a float2 (overload for half2).
__device__ __forceinline__ static float2 half22float2(half2 val) {
    const float2 widened = __half22float2(val);
    return widened;
}

fengzch-das's avatar
fengzch-das committed
209
// Widens a packed pair of bfloat16 values to a float2 (overload for __nv_bfloat162).
__device__ __forceinline__ static float2 half22float2(__nv_bfloat162 val) {
    const float2 widened = __bfloat1622float2(val);
    return widened;
}

// Narrows a float2 to the packed 16-bit pair type T.
// The primary template is deleted: only the explicit specializations
// (half2 and __nv_bfloat162) may be used.
template<typename T>
__device__ __forceinline__ static T
float22half2(float2 val) = delete;
Zhekai Zhang's avatar
Zhekai Zhang committed
215
216

// float2 -> half2 specialization, using the round-to-nearest (`_rn`) intrinsic.
template<>
__device__ __forceinline__ half2 float22half2<half2>(float2 val) {
    const half2 narrowed = __float22half2_rn(val);
    return narrowed;
}

// float2 -> __nv_bfloat162 specialization, using the round-to-nearest (`_rn`) intrinsic.
template<>
__device__ __forceinline__ __nv_bfloat162 float22half2<__nv_bfloat162>(float2 val) {
    const __nv_bfloat162 narrowed = __float22bfloat162_rn(val);
    return narrowed;
}

// Creates an artificial use of `val`: a volatile store through a null pointer
// that is executed only when `alwaysfalse` is true.
// As the parameter name says, callers are expected to pass a condition that
// is false at run time; if it were true the store would dereference nullptr.
template<typename T>
__device__ __forceinline__ static void unused_var(T &val, bool alwaysfalse) {
    if (!alwaysfalse) {
        return;
    }
    volatile T *sink = nullptr;
    *sink            = val;
}

Muyang Li's avatar
Muyang Li committed
234
// Issues `ldmatrix.sync.aligned.x4.m8n8.shared.b16` on `ptr` and writes the
// four resulting 32-bit registers to out.x, out.y, out.z and out.w.
// The instruction is `.sync.aligned`, i.e. warp-collective.
// NOTE(review): `ptr` is handed to the `.shared` instruction as-is ("l"
// constraint, no explicit generic-to-shared address conversion) -- confirm
// this is what the target toolchain expects.
__device__ __forceinline__ static void ldmatrix(const void *ptr, uint4 &out) {
    asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                 : "=r"(out.x),
                   "=r"(out.y),
                   "=r"(out.z),
                   "=r"(out.w)
                 : "l"((ptr)));
}

240
// Warp-collective transpose of an 8x8 matrix of 16-bit elements, i.e. the
// operation of `movmatrix.sync.aligned.m8n8.trans.b16`.
//
// The PTX instruction was commented out, which turned this function into a
// silent identity. It is re-implemented here with warp shuffles.
//
// Fragment layout (same as the PTX instruction): lane L holds, in one 32-bit
// register, the two elements of row L / 4 at columns 2 * (L % 4) (low 16 bits)
// and 2 * (L % 4) + 1 (high 16 bits). On return each lane holds the same
// positions of the transposed matrix.
//
// Requirements:
//   - all 32 lanes of the warp must call this together (full-mask shuffles);
//   - assumes 32-lane warps laid out in linear thread-index order
//     -- TODO confirm for the target device.
// For types that are not exactly 32 bits wide, `x` is returned unchanged
// (the previous behaviour).
template<typename T>
__device__ __forceinline__ static T movmatrix(T x) {
    if constexpr (sizeof(T) != sizeof(uint32_t)) {
        return x;
    } else {
        const unsigned int linear_tid = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
        const unsigned int lane       = linear_tid & 31u;

        const uint32_t mine = *reinterpret_cast<uint32_t *>(&x);

        // out[r][c] = in[c][r]. Input element (c, r) lives in lane c * 4 + r / 2,
        // in the low half when r is even and the high half when r is odd.
        // With r = lane / 4 and c0 = 2 * (lane % 4):
        const int src_lane       = static_cast<int>((lane & 3u) * 8u + (lane >> 3));
        const uint32_t from_col0 = __shfl_sync(0xffffffffu, mine, src_lane);     // holds in[c0][r]
        const uint32_t from_col1 = __shfl_sync(0xffffffffu, mine, src_lane + 4); // holds in[c0 + 1][r]

        const unsigned int shift = ((lane >> 2) & 1u) * 16u;
        uint32_t transposed      = ((from_col0 >> shift) & 0xFFFFu) | (((from_col1 >> shift) & 0xFFFFu) << 16);
        return *reinterpret_cast<T *>(&transposed);
    }
}

Zhekai Zhang's avatar
Zhekai Zhang committed
249
250
// Quantizes the two components of `value` to `bitwidth`-bit integers and packs
// them into one word: value.x in the low bits, value.y in the bits above it.
// The primary template is deleted; only the explicit specializations exist.
template<int bitwidth, bool use_unsigned>
__device__ __forceinline__ uint32_t
quantize_float2(float2 value) = delete;
Zhekai Zhang's avatar
Zhekai Zhang committed
252
253

// Signed 4-bit quantization of a float pair.
// Software replacement for the disabled PTX sequence
// `cvt.rni.s32.f32` + `cvt.pack.sat.s4.s32.b32`:
//   1. convert each component to int with __float2int_rn,
//   2. saturate to [-8, 7],
//   3. keep the low 4 bits (two's complement) and pack x into bits 0..3 and
//      y into bits 4..7.
template<>
__device__ __forceinline__ uint32_t quantize_float2<4, false>(float2 value) {
    const int lo = max(-8, min(7, __float2int_rn(value.x)));
    const int hi = max(-8, min(7, __float2int_rn(value.y)));

    const unsigned int lo_nibble = lo & 0xF;
    const unsigned int hi_nibble = hi & 0xF;
    return (hi_nibble << 4) | lo_nibble;
}

// Unsigned 4-bit quantization of a float pair.
// Software replacement for the disabled PTX sequence
// `cvt.rni.s32.f32` + `cvt.pack.sat.u4.s32.b32`:
//   1. convert each component to int with __float2int_rn,
//   2. saturate to [0, 15],
//   3. pack x into bits 0..3 and y into bits 4..7.
template<>
__device__ __forceinline__ uint32_t quantize_float2<4, true>(float2 value) {
    const int lo = max(0, min(15, __float2int_rn(value.x)));
    const int hi = max(0, min(15, __float2int_rn(value.y)));

    const unsigned int lo_nibble = static_cast<unsigned int>(lo);
    const unsigned int hi_nibble = static_cast<unsigned int>(hi);
    return (hi_nibble << 4) | lo_nibble;
}

// Signed 8-bit quantization of a float pair.
// Software replacement for the disabled PTX sequence
// `cvt.rni.s32.f32` + `cvt.pack.sat.s8.s32.b32`:
//   1. convert each component to int with __float2int_rn,
//   2. saturate to the signed 8-bit range [-128, 127],
//   3. keep the low 8 bits (two's complement bit pattern) and pack x into
//      bits 0..7 and y into bits 8..15.
template<>
__device__ __forceinline__ uint32_t quantize_float2<8, false>(float2 value) {
    const int lo = max(-128, min(127, __float2int_rn(value.x)));
    const int hi = max(-128, min(127, __float2int_rn(value.y)));

    const unsigned int lo_byte = lo & 0xFF; // low 8 bits only
    const unsigned int hi_byte = hi & 0xFF;
    return (hi_byte << 8) | lo_byte;
}

Muyang Li's avatar
Muyang Li committed
309
// Converts a float pair to two packed e2m1 (4-bit float) values with
// `cvt.rn.satfinite.e2m1x2.f32` and zero-extends the resulting byte to 32 bits.
// value.y is passed as the first source operand and value.x as the second.
__device__ __forceinline__ uint32_t quantize_float2_fp4(float2 value) {
    uint32_t packed;
    asm volatile("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
                 : "=r"(packed)
                 : "f"(value.y), "f"(value.x));
    return packed;
}

Muyang Li's avatar
Muyang Li committed
317
// Converts one float to an 8-bit FP8 E4M3 code (1 sign bit, 4 exponent bits
// with bias 7, 3 mantissa bits), rounding to nearest-even and saturating to
// the largest finite magnitude (448, code 0x7E). Returns the code in the low
// 8 bits.
// NaN inputs map to an E4M3 NaN code (0x7F with the input's sign bit)
// -- presumably matching `cvt.rn.satfinite.e4m3x2.f32`; TODO confirm.
__device__ __forceinline__ static uint32_t float_to_e4m3_satfinite(float value) {
    const uint32_t bits = __float_as_uint(value);
    const uint32_t sign = (bits >> 24) & 0x80u; // float sign bit moved to bit 7
    const uint32_t mag  = bits & 0x7FFFFFFFu;   // |value| as raw bits

    if (mag > 0x7F800000u) {
        return sign | 0x7Fu; // NaN
    }
    if (mag < 0x3C800000u) {
        // |value| < 2^-6: E4M3 subnormal range with a step of 2^-9. Scaling by
        // 2^9 is exact; the rounded result is in [0, 8], where 8 is exactly
        // the code of the smallest normal (2^-6).
        return sign | static_cast<uint32_t>(__float2int_rn(__uint_as_float(mag) * 512.0f));
    }
    // Round the 23-bit mantissa to 3 bits (ties to even). A mantissa carry
    // propagates into the exponent field, which is the desired behaviour.
    const uint32_t rounded = (mag + 0x7FFFFu + ((mag >> 20) & 1u)) >> 20;
    // Re-bias the exponent: float bias 127 -> E4M3 bias 7.
    uint32_t code = rounded - (120u << 3);
    if (code > 0x7Eu) {
        code = 0x7Eu; // saturate (also covers +/-inf)
    }
    return sign | code;
}

// Quantizes four floats to FP8 E4M3 and packs them into one 32-bit word:
// x -> bits 0..7, y -> bits 8..15, z -> bits 16..23, w -> bits 24..31.
//
// This is the packing of the two disabled
// `cvt.rn.satfinite.e4m3x2.f32` instructions (lo = {y, x}, hi = {w, z}).
// With those instructions commented out the function used to return
// uninitialized data; the conversion is now done in software.
__device__ __forceinline__ uint32_t quantize_float4_fp8(float4 value) {
    const uint32_t lo = float_to_e4m3_satfinite(value.x) | (float_to_e4m3_satfinite(value.y) << 8);
    const uint32_t hi = float_to_e4m3_satfinite(value.z) | (float_to_e4m3_satfinite(value.w) << 8);
    return lo | (hi << 16);
}

Muyang Li's avatar
Muyang Li committed
324
// Approximate hyperbolic tangent via the PTX `tanh.approx.f32` instruction.
__device__ __forceinline__ static float cuda_tanhf(float x) {
    float approx;
    asm("tanh.approx.f32 %0, %1;" : "=f"(approx) : "f"(x));
    return approx;
}

Muyang Li's avatar
Muyang Li committed
330
// Approximate reciprocal (1 / x) via the PTX `rcp.approx.ftz.f32` instruction.
__device__ __forceinline__ static float cuda_frcp(float x) {
    float approx;
    asm("rcp.approx.ftz.f32 %0, %1;" : "=f"(approx) : "f"(x));
    return approx;
}

Muyang Li's avatar
Muyang Li committed
336
// Approximate reciprocal square root via the PTX `rsqrt.approx.ftz.f32` instruction.
__device__ __forceinline__ static float cuda_frsqrt(float x) {
    float approx;
    asm("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(approx) : "f"(x));
    return approx;
}

Muyang Li's avatar
Muyang Li committed
342
// Approximate sine via the PTX `sin.approx.ftz.f32` instruction.
__device__ __forceinline__ static float cuda_sin(float x) {
    float approx;
    asm("sin.approx.ftz.f32 %0, %1;" : "=f"(approx) : "f"(x));
    return approx;
}

Muyang Li's avatar
Muyang Li committed
348
// Approximate cosine via the PTX `cos.approx.ftz.f32` instruction.
__device__ __forceinline__ static float cuda_cos(float x) {
    float approx;
    asm("cos.approx.ftz.f32 %0, %1;" : "=f"(approx) : "f"(x));
    return approx;
}

Muyang Li's avatar
Muyang Li committed
354
// Approximate base-2 exponential (2^x) via the PTX `ex2.approx.ftz.f32` instruction.
__device__ __forceinline__ static float cuda_exp2(float x) {
    float approx;
    asm("ex2.approx.ftz.f32 %0, %1;" : "=f"(approx) : "f"(x));
    return approx;
}

Zhekai Zhang's avatar
Zhekai Zhang committed
360
// https://forums.developer.nvidia.com/t/hardware-accelerated-computation-of-the-sigmoid-logistic-function/266206
Muyang Li's avatar
Muyang Li committed
361
// Fast logistic sigmoid, 1 / (1 + exp(-a)).
//
// With USE_TANH: evaluated as 0.5 * tanh(0.5 * a) + 0.5.
// Otherwise:     exp(-a) is computed as 2^(-a * log2(e)) with `ex2.approx`,
//                followed by an approximate reciprocal of (1 + exp(-a)).
__forceinline__ __device__ static float cuda_sigmoidf(float a) {
#if USE_TANH
    return fmaf(0.5, __tanhf(0.5f * a), 0.5f);
#else  // USE_TANH
    const float kLog2E = 1.442695041f; // log2(exp(1))
    const float scaled = -kLog2E * a;
    float pow2;
    asm("ex2.approx.ftz.f32 %0,%1;\n\t" : "=f"(pow2) : "f"(scaled));
    const float denom = pow2 + 1.0f;
    float inv;
    asm("rcp.approx.ftz.f32 %0,%1;\n\t" : "=f"(inv) : "f"(denom));
    return inv;
#endif // USE_TANH
}

// tanh-form GELU applied to both lanes of a packed 16-bit pair T.
// The input is widened to float2, each lane is multiplied by
// 0.5 + 0.5 * tanh(0.79788456 * (v + 0.044715 * v^3)), and the result is
// narrowed back to T.
template<typename T>
__device__ __forceinline__ static T gelu_half2(T x) {
    const float2 xf  = half22float2(x);
    const float2 x3f = xf * xf * xf;

    // Per-lane gating factor.
    auto gate = [](float v, float v_cubed) {
        return 0.5f + 0.5f * cuda_tanhf(0.79788456f * (v + (0.044715f * v_cubed)));
    };
    const float t1 = gate(xf.x, x3f.x);
    const float t2 = gate(xf.y, x3f.y);

    return float22half2<T>(xf * make_float2(t1, t2));
}

// tanh-form GELU for a single value: the input is converted to float,
// multiplied by 0.5 + 0.5 * tanh(0.79788456 * (x + 0.044715 * x^3)),
// and converted back to T.
template<typename T>
__device__ __forceinline__ static T gelu_half(T x) {
    const float xf    = float(x);
    const float cubed = xf * xf * xf;
    const float gate  = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf + (0.044715f * cubed)));
    return (T)(xf * gate);
}

Muyang Li's avatar
Muyang Li committed
392
393
394
395
396
// SiLU activation: x * sigmoid(x), evaluated in float and converted back to T.
template<typename T>
__device__ __forceinline__ static T silu(const T &x) {
    const float xf = (float)x;
    return (T)(xf * cuda_sigmoidf(xf));
}

Muyang Li's avatar
Muyang Li committed
399
// Element-wise a / b for half2. Both operands are widened to float, divided
// with the fast `__fdividef` intrinsic, and narrowed back to half2.
__device__ __forceinline__ static half2 h2div(half2 a, half2 b) {
    const float2 num = half22float2(a);
    const float2 den = half22float2(b);
    return float22half2<half2>(make_float2(__fdividef(num.x, den.x), __fdividef(num.y, den.y)));
}
fengzch-das's avatar
fengzch-das committed
407
// Element-wise a / b for __nv_bfloat162. Both operands are widened to float,
// divided with the fast `__fdividef` intrinsic, and narrowed back.
__device__ __forceinline__ static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bfloat162 b) {
    const float2 num = half22float2(a);
    const float2 den = half22float2(b);
    return float22half2<__nv_bfloat162>(make_float2(__fdividef(num.x, den.x), __fdividef(num.y, den.y)));
}

Muyang Li's avatar
Muyang Li committed
416
// Adds `val` to the float at `addr` with a relaxed, gpu-scope global-memory
// reduction (`red.relaxed.gpu.global.add.f32`). No value is returned.
__device__ __forceinline__ static void reduce_add(float *addr, float val) {
    asm volatile("red.relaxed.gpu.global.add.f32 [%0], %1;"
                 :
                 : "l"(addr), "f"(val));
}

Muyang Li's avatar
Muyang Li committed
420
// Predicated variant of reduce_add: the relaxed, gpu-scope global-memory
// reduction `*addr += val` is issued only when `pred` is true.
__device__ __forceinline__ static void reduce_add_pred(float *addr, float val, bool pred) {
    asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                 "@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
                 "}"
                 :
                 : "r"((int)pred), "l"(addr), "f"(val));
}

Zhekai Zhang's avatar
Zhekai Zhang committed
428
// Compile-time unrolled loop: invokes `lambda.template operator()<I>()` once
// for every I in 0 .. cnt-1, in increasing order (comma fold expression).
// `lambda` must therefore expose a call operator templated on an int.
template<int cnt, typename F>
__device__ __forceinline__ static void unrolled_loop(F &&lambda) {
    // Immediately-invoked helper that expands the index pack.
    [&]<int... Idx>(std::integer_sequence<int, Idx...>) {
        (lambda.template operator()<Idx>(), ...);
    }(std::make_integer_sequence<int, cnt>());
}

434
435
// Fast int -> float conversion (the regular int2float is slow on sm_80 and
// earlier). Valid for val in [-4194304, 4194303].
//
// The low 23 bits of `val` are XOR-ed into the bit pattern 0x4B400000
// (the float 12582912 = 2^23 + 2^22); reinterpreting the result as a float
// gives 12582912 + val, and subtracting the bias yields val.
// This plain C++ form replaces the original single `lop3.b32` instruction.
__device__ __forceinline__ static float int2float_fast(int val) {
    const unsigned int biased_bits = (val & 0x7FFFFF) ^ 0x4B400000;
    float biased;
    memcpy(&biased, &biased_bits, sizeof(float));
    return biased - 12582912.0f;
}

448
// Reinterprets the object representation of `input` as a value of type To.
// Both types must have the same size (checked at compile time).
// Implemented with a pointer reinterpret_cast, which is not strictly
// aliasing-safe but is what the rest of this file relies on.
template<typename To, typename From>
__device__ __forceinline__ static To bit_cast(const From &input) {
    static_assert(sizeof(To) == sizeof(From));
    const To *aliased = reinterpret_cast<const To *>(&input);
    return *aliased;
}

455
456
// Fast (int, int) -> half2 conversion; both int2float and float2half are slow
// on sm_75 and earlier.
// Valid for values in [-8192, 8191]; the result has a resolution of 16 and is
// rounded toward negative infinity.
//
// Steps:
//   1. `prmt` packs the low 16 bits of x (low half) and y (high half);
//   2. the packed word is shifted right by 4 (division by 16);
//   3. `lop3` computes (packed & 0x03FF03FF) ^ 0x76007600, i.e. it places the
//      10 low bits of each half into the mantissa of a half with bit pattern
//      0x7600;
//   4. the constant offset 24576 is removed with a half2 add.
__device__ __forceinline__ static half2 int2half2_fast_8192(int x, int y) {
    uint32_t packed;
    uint32_t half_bits;
    // packed.lo = x.lo; packed.hi = y.lo;
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(packed) : "r"(x), "r"(y), "n"(0x5410));
    packed = packed >> 4;
    // (packed & 0x03FF03FF) ^ 0x76007600
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(half_bits)
                 : "r"(packed), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
    return __hadd2(kernels::bit_cast<half2>(half_bits), half2(-24576.0f, -24576.0f));
}
// Fast (int, int) -> half2 conversion.
// Valid for values in [-4096, 4095]; the result has a resolution of 8 and is
// rounded to nearest. Inputs are NOT clamped here.
//
// Steps:
//   1. each input is scaled by 8192 and offset by 32768, so that its upper
//      16 bits hold value / 8 rounded to nearest;
//   2. `prmt` packs the upper 16 bits of x (low half) and y (high half);
//   3. `lop3` computes (packed & 0x03FF03FF) ^ 0x72007200;
//   4. the constant offset 12288 is removed with a half2 add.
__device__ __forceinline__ static half2 int2half2_fast_4096_rn(int x, int y) {
    // TODO: round to even?
    x = x * 8192 + 32768;
    y = y * 8192 + 32768;

    uint32_t packed;
    uint32_t half_bits;
    // packed.lo = x.hi; packed.hi = y.hi;
    // <=> divide x and y by 65536 and pack them
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(packed) : "r"(x), "r"(y), "n"(0x7632));
    // (packed & 0x03FF03FF) ^ 0x72007200
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(half_bits)
                 : "r"(packed), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
    return __hadd2(kernels::bit_cast<half2>(half_bits), half2(-12288.0f, -12288.0f));
}
// Fast (int, int) -> half2 conversion. Valid for values in [-512, 511].
//
// Steps:
//   1. `prmt` packs the low 16 bits of x (low half) and y (high half);
//   2. `lop3` computes (packed & 0x03FF03FF) ^ 0x66006600, i.e. it places the
//      10 low bits of each half into the mantissa of a half with bit pattern
//      0x6600;
//   3. the constant offset 1536 is removed with a half2 add.
__device__ __forceinline__ static half2 int2half2_fast_512(int x, int y) {
    uint32_t packed;
    uint32_t half_bits;
    // packed.lo = x.lo; packed.hi = y.lo;
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(packed) : "r"(x), "r"(y), "n"(0x5410));
    // (packed & 0x03FF03FF) ^ 0x66006600
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(half_bits)
                 : "r"(packed), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
    return __hadd2(kernels::bit_cast<half2>(half_bits), half2(-1536.0f, -1536.0f));
}

Muyang Li's avatar
Muyang Li committed
501
}; // namespace nunchaku::kernels