gemm_utils.cuh 17.6 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
#pragma once

#include <cstdint>
#include "common.h"
muyangli's avatar
muyangli committed
5
6
7
#include "../utils.cuh"

namespace nunchaku::kernels {
// Clamp `val` into [min, max]. The lower bound is tested first, so if min > max the result is
// `min` for val < min and `max` for any other val > max. constexpr so it can appear in constant
// expressions such as template arguments.
static constexpr int clamp(int val, int min, int max) {
    return (val < min) ? min : ((val > max) ? max : val);
}

// Load a value of type T from `addr`, using one vectorized instruction for 8- and 16-byte types.
//
// shmem == true : `addr` must point into shared memory. For 8/16-byte types the generic pointer
//                 is first converted to a shared-window address (cvta.to.shared) inside the asm,
//                 because ld.shared does not accept generic addresses (undefined per the PTX
//                 ISA). Other sizes use a plain dereference.
// shmem == false: read-only load through __ldg (non-coherent path) for 8/16-byte types; other
//                 sizes use a plain dereference. NOTE(review): __ldg assumes the data is not
//                 written while the kernel reads it — confirm at call sites.
// Vector paths assume `addr` is aligned to sizeof(T).
template<bool shmem = false, typename T>
__device__ __forceinline__ static T load(const T *addr) {
    if constexpr (shmem) {
        if constexpr (sizeof(T) == 8) {
            uint2 data;
            asm volatile("{ .reg .u64 smem_addr; cvta.to.shared.u64 smem_addr, %2; "
                         "ld.shared.v2.b32 {%0, %1}, [smem_addr]; }"
                         : "=r"(data.x), "=r"(data.y)
                         : "l"(addr));
            return *reinterpret_cast<T *>(&data);
        }
        if constexpr (sizeof(T) == 16) {
            uint4 data;
            asm volatile("{ .reg .u64 smem_addr; cvta.to.shared.u64 smem_addr, %4; "
                         "ld.shared.v4.b32 {%0, %1, %2, %3}, [smem_addr]; }"
                         : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
                         : "l"(addr));
            return *reinterpret_cast<T *>(&data);
        }
        return *addr;
    }

    if constexpr (sizeof(T) == 8) {
        uint2 data = __ldg(reinterpret_cast<const uint2 *>(addr));
        return *reinterpret_cast<T *>(&data);
    }
    if constexpr (sizeof(T) == 16) {
        uint4 data = __ldg(reinterpret_cast<const uint4 *>(addr));
        return *reinterpret_cast<T *>(&data);
    }

    return *addr;
}

// Predicated read-only load from global memory.
//
// For 4-, 8- and 16-byte types: when `pred` is true the value is fetched with a single __ldg
// access (the same read-only path `load` uses for non-shared memory). When `pred` is false,
// `addr` is never dereferenced and an all-zero bit pattern is returned, so callers never see
// uninitialized data. The vector accesses assume `addr` is aligned to sizeof(T).
//
// Any other size falls back to a plain guarded dereference. In that case the result is left
// default-initialized when `pred` is false.
template<typename T>
__device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
    if constexpr (sizeof(T) == 4) {
        uint32_t data = 0u;
        if (pred) {
            data = __ldg(reinterpret_cast<const uint32_t *>(addr));
        }
        return *reinterpret_cast<T *>(&data);
    }
    if constexpr (sizeof(T) == 8) {
        uint2 data = make_uint2(0u, 0u);
        if (pred) {
            data = __ldg(reinterpret_cast<const uint2 *>(addr));
        }
        return *reinterpret_cast<T *>(&data);
    }
    if constexpr (sizeof(T) == 16) {
        uint4 data = make_uint4(0u, 0u, 0u, 0u);
        if (pred) {
            data = __ldg(reinterpret_cast<const uint4 *>(addr));
        }
        return *reinterpret_cast<T *>(&data);
    }

    T result;
    if (pred) {
        result = *addr;
    }
    return result;
}

// Store `val` to `addr`, using one vectorized access for 8- and 16-byte types.
//
// shmem == true : `addr` must point into shared memory. For 8/16-byte types the generic pointer
//                 is converted to a shared-window address (cvta.to.shared) inside the asm before
//                 st.shared.v2/.v4; .shared instructions do not accept generic addresses
//                 (undefined per the PTX ISA). Other sizes use a plain assignment.
// shmem == false: plain stores reinterpreted as unsigned int / uint2 / uint4 for 4/8/16-byte
//                 types, with no cache hint (e.g. __stcg) applied. Other sizes use a plain
//                 assignment.
// Vector paths assume `addr` is aligned to sizeof(T).
template<bool shmem = false, typename T>
__device__ __forceinline__ static void store(T *addr, T val) {
    if constexpr (shmem) {
        if constexpr (sizeof(T) == 8) {
            uint2 data = *reinterpret_cast<uint2 *>(&val);
            asm volatile("{ .reg .u64 smem_addr; cvta.to.shared.u64 smem_addr, %0; "
                         "st.shared.v2.b32 [smem_addr], {%1, %2}; }" ::"l"(addr),
                         "r"(data.x),
                         "r"(data.y));
            return;
        }
        if constexpr (sizeof(T) == 16) {
            uint4 data = *reinterpret_cast<uint4 *>(&val);
            asm volatile("{ .reg .u64 smem_addr; cvta.to.shared.u64 smem_addr, %0; "
                         "st.shared.v4.b32 [smem_addr], {%1, %2, %3, %4}; }" ::"l"(addr),
                         "r"(data.x),
                         "r"(data.y),
                         "r"(data.z),
                         "r"(data.w));
            return;
        }
        *addr = val;
        return;
    }

    if constexpr (sizeof(T) == 4) {
        *reinterpret_cast<unsigned int *>(addr) = *reinterpret_cast<unsigned int *>(&val);
        return;
    }
    if constexpr (sizeof(T) == 8) {
        *reinterpret_cast<uint2 *>(addr) = *reinterpret_cast<uint2 *>(&val);
        return;
    }
    if constexpr (sizeof(T) == 16) {
        *reinterpret_cast<uint4 *>(addr) = *reinterpret_cast<uint4 *>(&val);
        return;
    }

    *addr = val;
}

// Predicated store to global memory: `val` is written to `addr` only when `pred` is true.
// 4-, 8- and 16-byte types use a single predicated `st.global.cg` instruction (the predicate is
// materialized from `pred` inside the asm), so `addr` is expected to be a global pointer aligned
// to sizeof(T). Any other size uses a guarded plain assignment.
template<typename T>
__device__ __forceinline__ static void store_pred(T *addr, T val, bool pred) {
    const int enabled = static_cast<int>(pred);

    if constexpr (sizeof(T) == 4) {
        const uint32_t bits = *reinterpret_cast<uint32_t *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.b32 [%1], %2;"
                     "}" ::"r"(enabled),
                     "l"(addr),
                     "r"(bits));
    } else if constexpr (sizeof(T) == 8) {
        const uint2 bits = *reinterpret_cast<uint2 *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
                     "}" ::"r"(enabled),
                     "l"(addr),
                     "r"(bits.x),
                     "r"(bits.y));
    } else if constexpr (sizeof(T) == 16) {
        const uint4 bits = *reinterpret_cast<uint4 *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
                     "}" ::"r"(enabled),
                     "l"(addr),
                     "r"(bits.x),
                     "r"(bits.y),
                     "r"(bits.z),
                     "r"(bits.w));
    } else {
        if (enabled) {
            *addr = val;
        }
    }
}

// Widen a packed pair of fp16 values to float2, component-wise.
__device__ __forceinline__ static float2 half22float2(half2 val) {
    const float2 widened = __half22float2(val);
    return widened;
}

// Widen a packed pair of bf16 values to float2, component-wise (bf16 overload of the above).
__device__ __forceinline__ static float2 half22float2(__nv_bfloat162 val) {
    const float2 widened = __bfloat1622float2(val);
    return widened;
}

// Narrow a float2 to a packed 16-bit pair of type T with round-to-nearest.
// Only the half2 and __nv_bfloat162 specializations exist; any other T is a compile error.
template<typename T>
__device__ __forceinline__ static T float22half2(float2 val) = delete;

// float2 -> half2, round-to-nearest (__float22half2_rn).
template<>
__device__ __forceinline__ half2 float22half2<half2>(float2 val) {
    const half2 narrowed = __float22half2_rn(val);
    return narrowed;
}

// float2 -> __nv_bfloat162, round-to-nearest (__float22bfloat162_rn).
template<>
__device__ __forceinline__ __nv_bfloat162 float22half2<__nv_bfloat162>(float2 val) {
    const __nv_bfloat162 narrowed = __float22bfloat162_rn(val);
    return narrowed;
}

// Make `val` appear used to the compiler without emitting a real access in practice.
// A store of `val` through a volatile null pointer happens only when `alwaysfalse` is true.
// The flag is a runtime value, so the compiler must keep `val` alive.
// NOTE(review): callers are presumably expected to pass false; true would dereference nullptr.
template<typename T>
__device__ __forceinline__ static void unused_var(T &val, bool alwaysfalse) {
    volatile T *sink = nullptr;
    if (!alwaysfalse) {
        return;
    }
    *sink = val;
}

// Warp-collective load of four 8x8 b16 matrices from shared memory.
// Uses ldmatrix.sync.aligned.x4.m8n8.shared.b16, which requires sm_75+ and must be executed by
// all 32 lanes. Each lane receives its four 32-bit fragment registers in out.x..out.w.
// `ptr` is a generic pointer into shared memory. It is converted to a shared-window address
// (cvta.to.shared) inside the asm, because the .shared address operand of ldmatrix does not
// accept generic addresses (undefined per the PTX ISA).
__device__ __forceinline__ static void ldmatrix(const void *ptr, uint4 &out) {
    asm volatile("{ .reg .u64 smem_addr; cvta.to.shared.u64 smem_addr, %4; "
                 "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [smem_addr]; }\n"
                 : "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w)
                 : "l"(ptr));
}

// Warp-collective transpose of an 8x8 matrix of b16 elements. Each lane holds one 32-bit
// register (two b16 values) in `x`, reinterpreted as uint32_t.
//
// When compiling device code for CUDA sm_75+, this issues movmatrix.sync.aligned.m8n8.trans.b16,
// which must be executed by all 32 lanes of the warp.
// On every other target (e.g. HIP, where this inline PTX is unavailable) NO transpose is
// performed and `x` is returned unchanged.
// TODO(review): callers that depend on the transpose are silently wrong on those targets;
// add a shuffle-based fallback or guard the call sites.
template<typename T>
__device__ __forceinline__ static T movmatrix(T x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
                 : "=r"(*reinterpret_cast<uint32_t *>(&x))
                 : "r"(*reinterpret_cast<uint32_t *>(&x)));
#endif
    return x;
}

// Quantize a pair of floats and pack the integer codes into the low bits of a uint32_t:
// value.x goes in the lowest field, value.y in the field above it ("x in low bit, y in high bit").
// Only the explicit specializations are defined; any other <bitwidth, use_unsigned> is deleted.
template<int bitwidth, bool use_unsigned>
__device__ __forceinline__ uint32_t quantize_float2(float2 value) = delete;

// Signed 4-bit quantization.
// Each component is rounded to the nearest integer (ties to even, __float2int_rn) and saturated
// to [-8, 7]. The two's-complement nibbles are packed as (y << 4) | x, leaving bits 8..31 zero.
// This is the software form of the cvt.rni.s32.f32 + cvt.pack.sat.s4.s32.b32 PTX sequence.
template<>
__device__ __forceinline__ uint32_t quantize_float2<4, false>(float2 value) {
    const int rx = __float2int_rn(value.x);
    const int ry = __float2int_rn(value.y);
    const int sx = max(-8, min(7, rx));
    const int sy = max(-8, min(7, ry));
    // Keep only the 4-bit two's-complement encoding of each saturated value.
    const uint32_t nibble_x = static_cast<uint32_t>(sx) & 0xFu;
    const uint32_t nibble_y = static_cast<uint32_t>(sy) & 0xFu;
    return nibble_x | (nibble_y << 4);
}

// Unsigned 4-bit quantization.
// Each component is rounded to the nearest integer (ties to even, __float2int_rn) and saturated
// to [0, 15]. The results are packed as (y << 4) | x, leaving bits 8..31 zero.
// This is the software form of the cvt.rni.s32.f32 + cvt.pack.sat.u4.s32.b32 PTX sequence.
template<>
__device__ __forceinline__ uint32_t quantize_float2<4, true>(float2 value) {
    const int rx = __float2int_rn(value.x);
    const int ry = __float2int_rn(value.y);
    const uint32_t nibble_x = static_cast<uint32_t>(max(0, min(15, rx)));
    const uint32_t nibble_y = static_cast<uint32_t>(max(0, min(15, ry)));
    return nibble_x | (nibble_y << 4);
}

// Signed 8-bit quantization.
// Each component is rounded to the nearest integer (ties to even, __float2int_rn) and saturated
// to [-128, 127]. The two's-complement bytes are packed as (y << 8) | x, leaving bits 16..31 zero.
// This is the software form of the cvt.rni.s32.f32 + cvt.pack.sat.s8.s32.b32 PTX sequence.
template<>
__device__ __forceinline__ uint32_t quantize_float2<8, false>(float2 value) {
    // Step 1: round to the nearest integer.
    const int rx = __float2int_rn(value.x);
    const int ry = __float2int_rn(value.y);
    // Step 2: saturate to the signed 8-bit range [-128, 127].
    const int sx = max(-128, min(127, rx));
    const int sy = max(-128, min(127, ry));
    // Step 3: keep only the low 8 bits, i.e. the 8-bit two's-complement encoding.
    const uint32_t byte_x = static_cast<uint32_t>(sx) & 0xFFu;
    const uint32_t byte_y = static_cast<uint32_t>(sy) & 0xFFu;
    return byte_x | (byte_y << 8);
}

// Convert two floats to FP4 E2M1 with round-to-nearest and saturation to the largest finite value
// (cvt.rn.satfinite.e2m1x2.f32). The packed byte is zero-extended to 32 bits, with value.x in
// the low nibble and value.y in the high nibble (the asm passes y first, x second).
// NOTE(review): this PTX needs an FP4-capable target (sm_100a/sm_120a class) — confirm build flags.
__device__ __forceinline__ uint32_t quantize_float2_fp4(float2 value) {
    uint32_t packed;
    asm volatile("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
                 : "=r"(packed)
                 : "f"(value.y), "f"(value.x));
    return packed;
}

// Software float -> FP8 E4M3 conversion (OCP "FN" variant: bias 7, no infinities, max finite
// 448 = 0x7E). Rounds to nearest with ties to even and saturates out-of-range magnitudes
// (including +-inf) to +-448, mirroring cvt.rn.satfinite.e4m3x2.f32. NaN maps to 0x7F.
// Returns the 8-bit code in the low byte.
__device__ __forceinline__ static uint32_t float_to_e4m3_satfinite(float f) {
    const uint32_t bits = __float_as_uint(f);
    const uint32_t sign = (bits >> 24) & 0x80u;
    const uint32_t mag  = bits & 0x7FFFFFFFu;
    if (mag > 0x7F800000u) {
        return 0x7Fu; // NaN
    }
    const float a = __uint_as_float(mag);
    if (a >= 448.0f) {
        return sign | 0x7Eu; // saturate to the largest finite magnitude
    }
    if (a < 0.015625f) {
        // Below 2^-6 (smallest normal): subnormal codes are multiples of 2^-9. A result of 8
        // is exactly the code of the smallest normal (0x08), so no special case is needed.
        return sign | __float2uint_rn(a * 512.0f);
    }
    // Normal range: round the 23-bit float mantissa to 3 bits (ties to even). A carry out of the
    // mantissa bumps the exponent, which is the correct result.
    const uint32_t rounded = mag + 0x7FFFFu + ((mag >> 20) & 1u);
    const uint32_t exp8    = (rounded >> 23) - 120u; // rebias: -127 + 7
    const uint32_t man8    = (rounded >> 20) & 0x7u;
    return sign | (exp8 << 3) | man8;
}

// Quantize four floats to FP8 E4M3 (round-to-nearest, saturate-to-finite) and pack them into a
// uint32_t with value.x in byte 0, value.y in byte 1, value.z in byte 2 and value.w in byte 3.
// On CUDA sm_89+ this uses the hardware cvt.rn.satfinite.e4m3x2.f32. Everywhere else (e.g. HIP,
// where that PTX is unavailable) it uses the software conversion above; previously this branch
// returned uninitialized values.
__device__ __forceinline__ uint32_t quantize_float4_fp8(float4 value) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890
    uint16_t lo, hi;
    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(lo) : "f"(value.y), "f"(value.x));
    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(hi) : "f"(value.w), "f"(value.z));
    return uint32_t(lo) | (uint32_t(hi) << 16);
#else
    return float_to_e4m3_satfinite(value.x) | (float_to_e4m3_satfinite(value.y) << 8) |
           (float_to_e4m3_satfinite(value.z) << 16) | (float_to_e4m3_satfinite(value.w) << 24);
#endif
}

// Approximate hyperbolic tangent via the PTX tanh.approx.f32 instruction.
// NOTE(review): tanh.approx requires sm_75+.
__device__ __forceinline__ static float cuda_tanhf(float x) {
    float y;
    asm("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

// Approximate reciprocal 1/x via PTX rcp.approx.ftz.f32 (subnormals flushed to zero).
__device__ __forceinline__ static float cuda_frcp(float x) {
    float y;
    asm("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

// Approximate reciprocal square root 1/sqrt(x) via PTX rsqrt.approx.ftz.f32 (subnormals flushed).
__device__ __forceinline__ static float cuda_frsqrt(float x) {
    float y;
    asm("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

// Approximate sine (radians) via PTX sin.approx.ftz.f32 (subnormals flushed to zero).
__device__ __forceinline__ static float cuda_sin(float x) {
    float y;
    asm("sin.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

// Approximate cosine (radians) via PTX cos.approx.ftz.f32 (subnormals flushed to zero).
__device__ __forceinline__ static float cuda_cos(float x) {
    float y;
    asm("cos.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

// Approximate 2^x via PTX ex2.approx.ftz.f32 (subnormals flushed to zero).
__device__ __forceinline__ static float cuda_exp2(float x) {
    float y;
    asm("ex2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

// Fast approximate logistic sigmoid, 1 / (1 + exp(-a)).
// Reference: https://forums.developer.nvidia.com/t/hardware-accelerated-computation-of-the-sigmoid-logistic-function/266206
// With USE_TANH enabled it evaluates 0.5 * tanh(0.5 * a) + 0.5.
// Otherwise it computes exp(-a) as 2^(-a * log2(e)) with ex2.approx.ftz, then takes
// rcp.approx.ftz of (1 + exp(-a)).
__forceinline__ __device__ static float cuda_sigmoidf(float a) {
#if USE_TANH
    return fmaf(0.5f, __tanhf(0.5f * a), 0.5f);
#else  // USE_TANH
    const float L2E = 1.442695041f; // log2(e)
    const float exponent = -L2E * a;
    float e;
    asm("ex2.approx.ftz.f32 %0,%1;\n\t" : "=f"(e) : "f"(exponent));
    const float denom = e + 1.0f;
    float r;
    asm("rcp.approx.ftz.f32 %0,%1;\n\t" : "=f"(r) : "f"(denom));
    return r;
#endif // USE_TANH
}

// tanh-approximation GELU on a packed 16-bit pair (T is half2 or __nv_bfloat162, the types with
// half22float2 / float22half2 overloads):
//   gelu(x) = x * (0.5 + 0.5 * tanh(0.79788456 * (x + 0.044715 * x^3))),  0.79788456 ~ sqrt(2/pi)
// Evaluated in float using cuda_tanhf, then rounded back to T.
template<typename T>
__device__ __forceinline__ static T gelu_half2(T x) {
    float2 v    = half22float2(x);
    float2 cube = v * v * v;
    const float gate_x = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (v.x + (0.044715f * cube.x)));
    const float gate_y = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (v.y + (0.044715f * cube.y)));
    return float22half2<T>(v * make_float2(gate_x, gate_y));
}

// Scalar tanh-approximation GELU:
//   gelu(x) = x * (0.5 + 0.5 * tanh(0.79788456 * (x + 0.044715 * x^3)))
// `x` is converted to float, evaluated with cuda_tanhf, and converted back to T.
template<typename T>
__device__ __forceinline__ static T gelu_half(T x) {
    const float v    = float(x);
    const float cube = v * v * v;
    const float gate = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (v + (0.044715f * cube)));
    return (T)(v * gate);
}

// SiLU (swish): x * sigmoid(x), evaluated in float with cuda_sigmoidf and converted back to T.
// (An exact-division variant, x / (1 + exp(-x)) via __fdividef/__expf, is not used.)
template<typename T>
__device__ __forceinline__ static T silu(const T &x) {
    const float xf = (float)x;
    return (T)(xf * cuda_sigmoidf(xf));
}

// Element-wise a / b for a packed fp16 pair. Computed in float with the fast approximate
// __fdividef, then rounded back with float22half2.
__device__ __forceinline__ static half2 h2div(half2 a, half2 b) {
    const float2 num = half22float2(a);
    const float2 den = half22float2(b);
    return float22half2<half2>(make_float2(__fdividef(num.x, den.x), __fdividef(num.y, den.y)));
}

// Element-wise a / b for a packed bf16 pair (same scheme as the half2 overload).
__device__ __forceinline__ static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bfloat162 b) {
    const float2 num = half22float2(a);
    const float2 den = half22float2(b);
    return float22half2<__nv_bfloat162>(make_float2(__fdividef(num.x, den.x), __fdividef(num.y, den.y)));
}

// Atomically add `val` to the float at global address `addr` without fetching the old value.
// Uses red.relaxed.gpu.global.add.f32: relaxed memory ordering at GPU scope.
__device__ __forceinline__ static void reduce_add(float *addr, float val) {
    asm volatile("red.relaxed.gpu.global.add.f32 [%0], %1;"
                 :
                 : "l"(addr), "f"(val));
}

// Predicated version of reduce_add: the relaxed GPU-scope global float reduction is performed
// only when `pred` is true. The predicate is materialized from `pred` inside the asm.
__device__ __forceinline__ static void reduce_add_pred(float *addr, float val, bool pred) {
    const int enabled = pred ? 1 : 0;
    asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                 "@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
                 "}"
                 :
                 : "r"(enabled), "l"(addr), "f"(val));
}

// Compile-time unrolled loop: calls lambda.template operator()<I>() for I = 0, 1, ..., cnt - 1,
// in increasing order. The calls are generated by a comma-fold over
// std::make_integer_sequence<int, cnt>. `lambda` must have an int template parameter,
// e.g. []<int I>() { ... }.
template<int cnt, typename F>
__device__ __forceinline__ static void unrolled_loop(F &&lambda) {
    [&]<int... Is>(std::integer_sequence<int, Is...>) {
        (lambda.template operator()<Is>(), ...);
    }(std::make_integer_sequence<int, cnt>());
}

// Fast int -> float conversion for val in [-4194304, 4194303] (i.e. [-2^22, 2^22 - 1]); the
// native int2float is slow on sm_80 and earlier.
// Method: the bit pattern 0x4B400000 is 1.5 * 2^23 = 12582912.0f. Replacing its low 23 mantissa
// bits with those of val, with bit 22 flipped, yields exactly 12582912.0f + val. Subtracting
// 12582912.0f therefore recovers val. The pattern (val & 0x7FFFFF) ^ 0x4B400000 is what a single
// lop3.b32 would compute; here it uses plain integer operations.
__device__ __forceinline__ static float int2float_fast(int val) {
    const unsigned int pattern = static_cast<unsigned int>((val & 0x7FFFFF) ^ 0x4B400000);
    float biased;
    memcpy(&biased, &pattern, sizeof(biased));
    return biased - 12582912.0f;
}

// Reinterpret the object representation of `input` as a `To` of the same size.
// Copies the bytes with a fixed-size memcpy instead of dereferencing a reinterpret_cast pointer.
// The pointer form violates strict aliasing (undefined behavior). The memcpy form is
// well-defined, and the compiler lowers it to register moves. Assumes `To` is
// default-constructible and both types are trivially copyable.
template<typename To, typename From>
__device__ __forceinline__ static To bit_cast(const From &input) {
    static_assert(sizeof(To) == sizeof(From), "bit_cast requires types of identical size");
    To output;
    memcpy(&output, &input, sizeof(To));
    return output;
}

// Fast conversion of two ints in [-8192, 8191] to a half2 (x -> low half, y -> high half).
// Results are multiples of 16, rounded toward negative infinity. (Both int2float and float2half
// are slow on sm_75 and earlier.)
// Method:
//   1. prmt (selector 0x5410) packs the low 16 bits of x and y.
//   2. A right shift by 4 divides by 16. Bits that spill from y into x's half are masked off.
//   3. lop3 computes (v & 0x03FF03FF) ^ 0x76007600. This places each 10-bit two's-complement
//      field, with its sign bit flipped, into the mantissa of fp16 24576 (ulp 16).
//   4. Subtracting 24576 leaves 16 * floor(val / 16).
__device__ __forceinline__ static half2 int2half2_fast_8192(int x, int y) {
    uint32_t packed;
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(packed) : "r"(x), "r"(y), "n"(0x5410));
    packed >>= 4;
    uint32_t halves;
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(halves)
                 : "r"(packed), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
    const half2 bias = half2(-24576.0f, -24576.0f);
    return __hadd2(kernels::bit_cast<half2>(halves), bias);
}
// Fast conversion of two ints in [-4096, 4095] to a half2 (x -> low half, y -> high half).
// Results are multiples of 8, rounded to nearest; ties round up rather than to even (TODO).
// Inputs are not clamped to the supported range.
// Method:
//   1. Scaling by 8192 and adding 32768 makes the high 16 bits of each int equal round(val / 8).
//   2. prmt (selector 0x7632) packs the two high halves.
//   3. lop3 computes (v & 0x03FF03FF) ^ 0x72007200, giving fp16 12288 + 8 * round(val / 8).
//   4. Subtracting 12288 leaves the rounded value.
__device__ __forceinline__ static half2 int2half2_fast_4096_rn(int x, int y) {
    const int scaled_x = x * 8192 + 32768;
    const int scaled_y = y * 8192 + 32768;
    uint32_t packed;
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(packed) : "r"(scaled_x), "r"(scaled_y), "n"(0x7632));
    uint32_t halves;
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(halves)
                 : "r"(packed), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
    const half2 bias = half2(-12288.0f, -12288.0f);
    return __hadd2(kernels::bit_cast<half2>(halves), bias);
}
// Exact fast conversion of two ints in [-512, 511] to a half2 (x -> low half, y -> high half).
// Method:
//   1. prmt (selector 0x5410) packs the LOW 16 bits of x and y. No division is involved.
//   2. lop3 computes (v & 0x03FF03FF) ^ 0x66006600. This places each 10-bit two's-complement
//      field, with its sign bit flipped, into the mantissa of fp16 1536 (ulp 1).
//   3. Subtracting 1536 recovers the value.
__device__ __forceinline__ static half2 int2half2_fast_512(int x, int y) {
    uint32_t packed;
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(packed) : "r"(x), "r"(y), "n"(0x5410));
    uint32_t halves;
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(halves)
                 : "r"(packed), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
    const half2 bias = half2(-1536.0f, -1536.0f);
    return __hadd2(kernels::bit_cast<half2>(halves), bias);
}

Muyang Li's avatar
Muyang Li committed
503
}; // namespace nunchaku::kernels