// gemm_utils.cuh
#pragma once

#include <cstdint>
#include "common.h"
#include "../utils.cuh"

namespace nunchaku::kernels {
// Clamp `val` into [min, max].
// The lower bound is tested first: val < min yields min; otherwise val > max yields max;
// otherwise val is returned. (If min > max, val < min gives min and any other val gives max.)
static constexpr int clamp(int val, int min, int max) {
    return val < min ? min : (val > max ? max : val);
}

// Load one element of type T from `addr`.
//
//   shmem == true : `addr` points into shared memory. 8- and 16-byte types are read with a
//                   single uint2 / uint4 access; other sizes use a plain `*addr`.
//   shmem == false: 8- and 16-byte types are read through __ldg (read-only data path) as
//                   uint2 / uint4; other sizes use a plain `*addr`.
//
// For the 8/16-byte paths `addr` must be aligned to sizeof(T) (vector access).
//
// The shared-memory path previously used inline PTX `ld.shared.*` with the raw generic
// pointer as its address operand. A `.shared` access expects a shared-window address, not
// a generic one (the generic->shared conversion appears to have been dropped in the port).
// A plain vectorized dereference is correct for a generic pointer on every backend.
template<bool shmem = false, typename T>
__device__ __forceinline__ static T load(const T *addr) {
    if constexpr (sizeof(T) == 8) {
        uint2 data;
        if constexpr (shmem) {
            data = *reinterpret_cast<const uint2 *>(addr);
        } else {
            data = __ldg(reinterpret_cast<const uint2 *>(addr));
        }
        return *reinterpret_cast<T *>(&data);
    } else if constexpr (sizeof(T) == 16) {
        uint4 data;
        if constexpr (shmem) {
            data = *reinterpret_cast<const uint4 *>(addr);
        } else {
            data = __ldg(reinterpret_cast<const uint4 *>(addr));
        }
        return *reinterpret_cast<T *>(&data);
    } else {
        return *addr;
    }
}

// Predicated load: returns *addr when `pred` is true.
// When `pred` is false, `addr` is never dereferenced and the returned value is
// indeterminate; callers must not use it. This matches the predicated PTX load it replaces.
//
// 4/8/16-byte types are read through __ldg as unsigned int / uint2 / uint4, the same
// read-only path used by load<false>. `addr` must be sizeof(T)-aligned for those sizes.
// Other sizes use a plain `*addr`.
//
// Previously the 4/8/16-byte branches had their PTX commented out and returned an
// uninitialized register even when `pred` was true.
template<typename T>
__device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
    T result;
    if (pred) {
        if constexpr (sizeof(T) == 4) {
            const uint32_t data = __ldg(reinterpret_cast<const unsigned int *>(addr));
            result              = *reinterpret_cast<const T *>(&data);
        } else if constexpr (sizeof(T) == 8) {
            const uint2 data = __ldg(reinterpret_cast<const uint2 *>(addr));
            result           = *reinterpret_cast<const T *>(&data);
        } else if constexpr (sizeof(T) == 16) {
            const uint4 data = __ldg(reinterpret_cast<const uint4 *>(addr));
            result           = *reinterpret_cast<const T *>(&data);
        } else {
            result = *addr;
        }
    }
    return result;
}

// Store `val` to `addr`.
//
// 8- and 16-byte types are written as a single uint2 / uint4 access. 4-byte types are
// written as one unsigned int for global memory (shmem == false). All other cases use
// T's assignment. For the vector paths `addr` must be aligned to sizeof(T).
//
// Global stores are plain stores. The `__stcg` (.cg cache hint) of the CUDA version is
// not used in this port.
//
// Previously the shared-memory (shmem == true) 8/16-byte branches had their PTX
// commented out and returned without writing, silently dropping those stores.
template<bool shmem = false, typename T>
__device__ __forceinline__ static void store(T *addr, T val) {
    if constexpr (sizeof(T) == 8) {
        *reinterpret_cast<uint2 *>(addr) = *reinterpret_cast<uint2 *>(&val);
    } else if constexpr (sizeof(T) == 16) {
        *reinterpret_cast<uint4 *>(addr) = *reinterpret_cast<uint4 *>(&val);
    } else if constexpr (!shmem && sizeof(T) == 4) {
        *reinterpret_cast<unsigned int *>(addr) = *reinterpret_cast<unsigned int *>(&val);
    } else {
        *addr = val;
    }
}

// Predicated store: writes `val` to *addr only when `pred` is true.
//
// 4/8/16-byte types are written as a single uint32_t / uint2 / uint4 access; `addr` must
// be sizeof(T)-aligned for those sizes. Other sizes use T's assignment. Plain stores are
// used, as in store<false>; the `.cg` cache hint of the CUDA version is not applied.
//
// Previously the 4/8/16-byte branches had their PTX commented out and returned without
// writing, silently dropping every store of those sizes.
template<typename T>
__device__ __forceinline__ static void store_pred(T *addr, T val, bool pred) {
    if (!pred) {
        return;
    }
    if constexpr (sizeof(T) == 4) {
        *reinterpret_cast<uint32_t *>(addr) = *reinterpret_cast<uint32_t *>(&val);
    } else if constexpr (sizeof(T) == 8) {
        *reinterpret_cast<uint2 *>(addr) = *reinterpret_cast<uint2 *>(&val);
    } else if constexpr (sizeof(T) == 16) {
        *reinterpret_cast<uint4 *>(addr) = *reinterpret_cast<uint4 *>(&val);
    } else {
        *addr = val;
    }
}

// Widen both lanes of a half2 to float.
__device__ __forceinline__ static float2 half22float2(half2 val) {
    const float2 widened = __half22float2(val);
    return widened;
}

// Widen both lanes of a __nv_bfloat162 to float (overload of the half2 version).
__device__ __forceinline__ static float2 half22float2(__nv_bfloat162 val) {
    const float2 widened = __bfloat1622float2(val);
    return widened;
}

// Narrow a float2 to the packed two-lane type T. The generic template is deleted, so only
// explicitly specialized T (half2 and __nv_bfloat162 in this file) can be used.
template<typename T> __device__ __forceinline__ static T float22half2(float2 val) = delete;

// half2 specialization: converts both lanes with round-to-nearest (__float22half2_rn).
template<>
__device__ __forceinline__ half2 float22half2<half2>(float2 val) {
    const half2 narrowed = __float22half2_rn(val);
    return narrowed;
}

// __nv_bfloat162 specialization: converts both lanes with round-to-nearest
// (__float22bfloat162_rn).
template<>
__device__ __forceinline__ __nv_bfloat162 float22half2<__nv_bfloat162>(float2 val) {
    const __nv_bfloat162 narrowed = __float22bfloat162_rn(val);
    return narrowed;
}

// Performs a volatile write of `val` through a null pointer, but only when `alwaysfalse`
// is true. This appears intended to make `val` look used to the compiler; callers
// presumably pass a runtime value that is never true (confirm at call sites), since a
// true value would dereference nullptr.
template<typename T>
__device__ __forceinline__ static void unused_var(T &val, bool alwaysfalse) {
    if (!alwaysfalse) {
        return;
    }
    volatile T *sink = nullptr;
    *sink            = val;
}

// Wrapper for the warp-collective PTX `ldmatrix.sync.aligned.x4.m8n8.shared.b16`.
// Writes this thread's four 32-bit fragments into out.x .. out.w; see the PTX ISA for the
// per-thread row-address / fragment layout.
//
// NOTE(review): the instruction addresses the `.shared` state space, but `ptr` is passed as
// an unconverted generic pointer. On NVIDIA hardware a `.shared` operand is expected to be
// a shared-window address (e.g. from __cvta_generic_to_shared). Confirm the target
// toolchain accepts a generic pointer here.
__device__ __forceinline__ static void ldmatrix(const void *ptr, uint4 &out) {
    asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                 : "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w)
                 : "l"(ptr));
}

// Intended wrapper for the warp-collective PTX `movmatrix.sync.aligned.m8n8.trans.b16`
// (transpose of an 8x8 b16 matrix spread over the warp, one 32-bit register per thread).
//
// NOTE(review): the PTX is disabled in this port, so this currently returns `x` UNCHANGED.
// No transpose is performed, and callers relying on the transposed layout will read wrong
// data. A portable replacement needs warp shuffles whose form depends on the target's warp
// width, so this is left as an identity stub pending that decision. Original PTX:
//   asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
//                : "=r"(*reinterpret_cast<uint32_t *>(&x))
//                : "r"(*reinterpret_cast<uint32_t *>(&x)));
template<typename T>
__device__ __forceinline__ static T movmatrix(T x) {
    return x;
}

// Round a float2 to `bitwidth`-bit integers (unsigned when `use_unsigned`) and pack them
// into a uint32_t: x in the low bits, y in the bits directly above it.
// The generic template is deleted; only the explicit specializations are defined.
template<int bitwidth, bool use_unsigned> __device__ __forceinline__ uint32_t quantize_float2(float2 value) = delete;

// Signed 4-bit: round each lane to the nearest integer (ties to even), saturate to [-8, 7],
// and pack x into bits [3:0] and y into bits [7:4]. Bits [31:8] are zero.
// Portable equivalent of `cvt.rni.s32.f32` followed by `cvt.pack.sat.s4.s32.b32 d, y, x, 0`.
// The packing step was commented out in this port, which left `result` uninitialized.
template<>
__device__ __forceinline__ uint32_t quantize_float2<4, false>(float2 value) {
    auto sat = [](int v) { return v < -8 ? -8 : (v > 7 ? 7 : v); };
    const int v1 = sat(__float2int_rn(value.x));
    const int v2 = sat(__float2int_rn(value.y));
    return (static_cast<uint32_t>(v1) & 0xFu) | ((static_cast<uint32_t>(v2) & 0xFu) << 4);
}

// Unsigned 4-bit: round each lane to the nearest integer (ties to even), saturate to
// [0, 15], and pack x into bits [3:0] and y into bits [7:4]. Bits [31:8] are zero.
// Portable equivalent of `cvt.rni.s32.f32` followed by `cvt.pack.sat.u4.s32.b32 d, y, x, 0`.
// The packing step was commented out in this port, which left `result` uninitialized.
template<>
__device__ __forceinline__ uint32_t quantize_float2<4, true>(float2 value) {
    auto sat = [](int v) { return v < 0 ? 0 : (v > 15 ? 15 : v); };
    const int v1 = sat(__float2int_rn(value.x));
    const int v2 = sat(__float2int_rn(value.y));
    return static_cast<uint32_t>(v1) | (static_cast<uint32_t>(v2) << 4);
}

// Signed 8-bit: round each lane to the nearest integer (ties to even), saturate to
// [-128, 127], and pack x into bits [7:0] and y into bits [15:8]. Bits [31:16] are zero.
// Portable equivalent of `cvt.rni.s32.f32` followed by `cvt.pack.sat.s8.s32.b32 d, y, x, 0`.
// The packing step was commented out in this port, which left `result` uninitialized.
template<>
__device__ __forceinline__ uint32_t quantize_float2<8, false>(float2 value) {
    auto sat = [](int v) { return v < -128 ? -128 : (v > 127 ? 127 : v); };
    const int v1 = sat(__float2int_rn(value.x));
    const int v2 = sat(__float2int_rn(value.y));
    return (static_cast<uint32_t>(v1) & 0xFFu) | ((static_cast<uint32_t>(v2) & 0xFFu) << 8);
}

// Convert two floats to FP4 E2M1 and pack them: x in bits [3:0], y in bits [7:4], bits
// [31:8] zero. E2M1 has a sign bit and magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6} (codes 0..7).
// Rounding is to nearest with ties to the even code, and magnitudes above 6 saturate to 6.
// Portable replacement for `cvt.rn.satfinite.e2m1x2.f32 d, y, x`, which was commented
// out in this port and left `result` uninitialized.
// NOTE(review): NaN maps to the max-magnitude code (0x7 / 0xF by sign) here; confirm this
// against the hardware instruction if NaNs can reach this path.
__device__ __forceinline__ uint32_t quantize_float2_fp4(float2 value) {
    auto to_e2m1 = [](float f) -> uint32_t {
        const uint32_t sign = (__float_as_uint(f) >> 28) & 0x8u; // bit 31 -> bit 3
        const float a       = fabsf(f);
        uint32_t code;
        if (a <= 0.25f) {
            code = 0; // 0.0 (tie at 0.25 -> even code 0)
        } else if (a < 0.75f) {
            code = 1; // 0.5
        } else if (a <= 1.25f) {
            code = 2; // 1.0 (ties 0.75 and 1.25 -> even code 2)
        } else if (a < 1.75f) {
            code = 3; // 1.5
        } else if (a <= 2.5f) {
            code = 4; // 2.0 (ties 1.75 and 2.5 -> even code 4)
        } else if (a < 3.5f) {
            code = 5; // 3.0
        } else if (a <= 5.0f) {
            code = 6; // 4.0 (ties 3.5 and 5.0 -> even code 6)
        } else {
            code = 7; // 6.0, saturated (NaN also lands here)
        }
        return sign | code;
    };
    return to_e2m1(value.x) | (to_e2m1(value.y) << 4);
}

// Convert four floats to FP8 E4M3 and pack them little-endian: x in bits [7:0],
// y in [15:8], z in [23:16], w in [31:24].
// E4M3 has exponent bias 7, max finite 448 (0x7E), and 0x7F/0xFF as NaN. Rounding is to
// nearest-even; finite overflow and +-inf saturate to +-448.
// Portable replacement for the two `cvt.rn.satfinite.e4m3x2.f32` instructions, which were
// commented out in this port and left `lo`/`hi` uninitialized.
// NOTE(review): a negative NaN is emitted as 0xFF here; confirm whether the hardware
// canonicalizes NaN sign.
__device__ __forceinline__ uint32_t quantize_float4_fp8(float4 value) {
    auto to_e4m3 = [](float f) -> uint32_t {
        const uint32_t bits = __float_as_uint(f);
        const uint32_t sign = (bits >> 24) & 0x80u;
        const uint32_t mag  = bits & 0x7FFFFFFFu;
        if (mag > 0x7F800000u) {
            return sign | 0x7Fu; // NaN stays NaN
        }
        uint32_t code;
        if (mag < 0x3C800000u) {
            // |f| < 2^-6 (smallest E4M3 normal): subnormal grid with step 2^-9.
            // Rounds to 0..8; 8 is exactly the encoding of 2^-6, so no special case needed.
            code = static_cast<uint32_t>(__float2int_rn(__uint_as_float(mag) * 512.0f));
        } else {
            // Keep 3 of the 23 mantissa bits with round-to-nearest-even (a carry into the
            // exponent is handled naturally), then rebias the exponent from 127 to 7.
            code = ((mag + 0x7FFFFu + ((mag >> 20) & 1u)) >> 20) - (120u << 3);
        }
        // satfinite: anything rounding above 448 (including +-inf) becomes 448.
        return sign | (code > 0x7Eu ? 0x7Eu : code);
    };
    return to_e4m3(value.x) | (to_e4m3(value.y) << 8) | (to_e4m3(value.z) << 16) |
           (to_e4m3(value.w) << 24);
}

// Hardware-approximate hyperbolic tangent via PTX `tanh.approx.f32` (reduced precision).
__device__ __forceinline__ static float cuda_tanhf(float x) {
    float y;
    asm("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

// Approximate reciprocal 1/x via PTX `rcp.approx.ftz.f32` (denormals flushed to zero).
__device__ __forceinline__ static float cuda_frcp(float x) {
    float y;
    asm("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

// Approximate reciprocal square root 1/sqrt(x) via PTX `rsqrt.approx.ftz.f32`
// (denormals flushed to zero).
__device__ __forceinline__ static float cuda_frsqrt(float x) {
    float y;
    asm("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

// Approximate sine via PTX `sin.approx.ftz.f32` (denormals flushed to zero).
__device__ __forceinline__ static float cuda_sin(float x) {
    float y;
    asm("sin.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

// Approximate cosine via PTX `cos.approx.ftz.f32` (denormals flushed to zero).
__device__ __forceinline__ static float cuda_cos(float x) {
    float y;
    asm("cos.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

// Approximate base-2 exponential 2^x via PTX `ex2.approx.ftz.f32` (denormals flushed to zero).
__device__ __forceinline__ static float cuda_exp2(float x) {
    float y;
    asm("ex2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

// https://forums.developer.nvidia.com/t/hardware-accelerated-computation-of-the-sigmoid-logistic-function/266206
// Fast logistic sigmoid 1 / (1 + exp(-a)).
//   USE_TANH : 0.5 + 0.5 * tanh(0.5 * a)
//   otherwise: 1 / (1 + 2^(-a * log2(e))), using the same approximate flush-to-zero
//              `ex2` / `rcp` instructions wrapped by cuda_exp2 / cuda_frcp, instead of
//              repeating that inline asm here.
__forceinline__ __device__ static float cuda_sigmoidf(float a) {
#if USE_TANH
    return fmaf(0.5f, __tanhf(0.5f * a), 0.5f);
#else  // USE_TANH
    const float L2E = 1.442695041f; // log2(exp(1))
    return cuda_frcp(cuda_exp2(-L2E * a) + 1.0f);
#endif // USE_TANH
}

// Tanh-approximated GELU on a packed two-lane value T (any type accepted by half22float2 /
// float22half2, i.e. half2 or __nv_bfloat162 here), evaluated in float per lane:
//   gelu(v) = v * (0.5 + 0.5 * tanh(0.79788456 * (v + 0.044715 * v^3)))   [0.79788456 ~ sqrt(2/pi)]
template<typename T>
__device__ __forceinline__ static T gelu_half2(T x) {
    const float2 in = half22float2(x);
    auto gate       = [](float v) {
        const float cube = v * v * v;
        return 0.5f + 0.5f * cuda_tanhf(0.79788456f * (v + (0.044715f * cube)));
    };
    return float22half2<T>(make_float2(in.x * gate(in.x), in.y * gate(in.y)));
}

// Scalar tanh-approximated GELU. Converts to float, computes
//   v * (0.5 + 0.5 * tanh(0.79788456 * (v + 0.044715 * v^3)))
// and casts back to T.
template<typename T>
__device__ __forceinline__ static T gelu_half(T x) {
    const float v    = float(x);
    const float gate = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (v + (0.044715f * (v * v * v))));
    return (T)(v * gate);
}

// SiLU / swish activation: x * sigmoid(x), computed in float and cast back to T.
template<typename T>
__device__ __forceinline__ static T silu(const T &x) {
    const float xf    = (float)x;
    const float gated = xf * cuda_sigmoidf(xf);
    return (T)gated;
    // Alternative formulation kept for reference:
    // return (T)__fdividef((float)x, 1.0f + __expf((float)-x));
}

// Lane-wise a / b for half2: widen to float, divide with the fast approximate __fdividef,
// and narrow back to half2.
__device__ __forceinline__ static half2 h2div(half2 a, half2 b) {
    const float2 num = half22float2(a);
    const float2 den = half22float2(b);
    return float22half2<half2>(make_float2(__fdividef(num.x, den.x), __fdividef(num.y, den.y)));
}
// Lane-wise a / b for __nv_bfloat162: widen to float, divide with the fast approximate
// __fdividef, and narrow back to __nv_bfloat162.
__device__ __forceinline__ static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bfloat162 b) {
    const float2 num = half22float2(a);
    const float2 den = half22float2(b);
    return float22half2<__nv_bfloat162>(make_float2(__fdividef(num.x, den.x), __fdividef(num.y, den.y)));
}

// Atomically add `val` to the float at `addr`; the previous value is discarded.
// atomicAdd is a relaxed, device-scope atomic, matching the `red.relaxed.gpu.global.add.f32`
// PTX of the CUDA version. That PTX was commented out in this port, which turned every
// reduction into a silent no-op.
__device__ __forceinline__ static void reduce_add(float *addr, float val) {
    atomicAdd(addr, val);
}

// Predicated atomic add: when `pred` is true, atomically adds `val` to the float at `addr`
// (relaxed, device-scope, result discarded). When `pred` is false, `addr` is not touched.
// The predicated `red` PTX of the CUDA version was commented out in this port, which made
// this a no-op.
__device__ __forceinline__ static void reduce_add_pred(float *addr, float val, bool pred) {
    if (pred) {
        atomicAdd(addr, val);
    }
}

// Compile-time unrolled loop: calls `lambda.template operator()<I>()` for
// I = 0, 1, ..., cnt - 1 in increasing order. It expands a comma fold over
// std::integer_sequence, so I is a constant expression inside each call.
template<int cnt, typename F>
__device__ __forceinline__ static void unrolled_loop(F &&lambda) {
    [&]<int... Idx>(std::integer_sequence<int, Idx...>) {
        (lambda.template operator()<Idx>(), ...);
    }(std::make_integer_sequence<int, cnt>{});
}

// int2float is slow on sm_80 and before.
// Fast int -> float for val in [-4194304, 4194303] = [-2^22, 2^22 - 1]; results outside
// this range are wrong.
// (val & 0x7FFFFF) ^ 0x4B400000 places (val + 2^22) mod 2^23 in the mantissa of
// 2^23, which yields the float 12582912 + val exactly (12582912 = 1.5 * 2^23 = 0x4B400000).
// Subtracting 12582912 then gives val.
// This bit operation (formerly `lop3.b32`) was commented out in this port, which left
// `fval` uninitialized.
__device__ __forceinline__ static float int2float_fast(int val) {
    const uint32_t bits = (static_cast<uint32_t>(val) & 0x7FFFFFu) ^ 0x4B400000u;
    return __uint_as_float(bits) - 12582912.0f;
}

// Reinterpret the bytes of `input` as type `To`; the sizes must match (checked).
// The bytes are copied with memcpy rather than by dereferencing a reinterpret_cast pointer.
// The pointer form violates strict aliasing. It also loads with `To`'s alignment from an
// object only guaranteed `From`'s alignment, which is a misaligned vector access when
// alignof(To) > alignof(From). A fixed-size memcpy like this is typically lowered to
// register moves.
// Requires To to be default-constructible; both types are expected to be trivially copyable.
template<typename To, typename From>
__device__ __forceinline__ static To bit_cast(const From &input) {
    static_assert(sizeof(To) == sizeof(From));
    To output;
    memcpy(&output, &input, sizeof(To));
    return output;
}

// Both int2float and float2half are slow on sm_75 and before.
// Fast (x, y) -> half2 for x, y in [-8192, 8191], in steps of 16, rounding toward negative
// infinity: each lane becomes 16 * floor(v / 16).
//   1. Pack the low 16 bits of x (low half) and y (high half). [was `prmt.b32 ..., 0x5410`]
//   2. Shift right by 4. Each half's low 10 bits now hold (v >> 4) mod 1024; the y bits that
//      slide into the top of the low half are masked off in step 3.
//   3. (ival & 0x03FF03FF) ^ 0x76007600 puts (v >> 4) + 512 into the mantissa of the half
//      0x7600 = 24576 (mantissa step 16), giving 24576 + 16 * (v >> 4).  [was `lop3.b32`]
//   4. Subtract 24576 in half precision (exact).
// Steps 1 and 3 were commented out in this port, which left `ival`/`hval` uninitialized.
__device__ __forceinline__ static half2 int2half2_fast_8192(int x, int y) {
    uint32_t ival = (static_cast<uint32_t>(x) & 0xFFFFu) | (static_cast<uint32_t>(y) << 16);
    ival >>= 4;
    const uint32_t hval = (ival & 0x03FF03FFu) ^ 0x76007600u;
    return __hadd2(kernels::bit_cast<half2>(hval), half2(-24576.0f, -24576.0f));
}
// Fast (x, y) -> half2 for x, y in [-4096, 4095], in steps of 8, rounding to nearest with
// ties rounded up (TODO: round to even?).
//   1. v * 8192 + 32768: the high 16 bits become floor(v / 8 + 0.5), i.e. v / 8 rounded.
//      This is computed in uint32_t so the bit pattern is the same without signed overflow.
//   2. Pack the high 16 bits of x (low half) and y (high half). [was `prmt.b32 ..., 0x7632`]
//   3. (ival & 0x03FF03FF) ^ 0x72007200 puts r + 512 into the mantissa of the half
//      0x7200 = 12288 (mantissa step 8), giving 12288 + 8 * r.    [was `lop3.b32`]
//   4. Subtract 12288 in half precision (exact).
// Steps 2 and 3 were commented out in this port, which left `ival`/`hval` uninitialized.
// NOTE(review): v in [4092, 4095] rounds to r = 512, which wraps to -4096 in step 3; callers
// presumably keep v <= 4091 (a clamp was considered and left commented out). Confirm.
__device__ __forceinline__ static half2 int2half2_fast_4096_rn(int x, int y) {
    // x = max(min(x, 4095), -4096);
    // y = max(min(y, 4095), -4096);
    const uint32_t xs   = static_cast<uint32_t>(x) * 8192u + 32768u;
    const uint32_t ys   = static_cast<uint32_t>(y) * 8192u + 32768u;
    const uint32_t ival = (xs >> 16) | (ys & 0xFFFF0000u);
    const uint32_t hval = (ival & 0x03FF03FFu) ^ 0x72007200u;
    return __hadd2(kernels::bit_cast<half2>(hval), half2(-12288.0f, -12288.0f));
}
// Fast exact (x, y) -> half2 for x, y in [-512, 511].
//   1. Pack the low 16 bits of x (low half) and y (high half). [was `prmt.b32 ..., 0x5410`]
//   2. (ival & 0x03FF03FF) ^ 0x66006600 puts v + 512 into the mantissa of the half
//      0x6600 = 1536 (mantissa step 1), giving 1536 + v.          [was `lop3.b32`]
//   3. Subtract 1536 in half precision (exact).
// Steps 1 and 2 were commented out in this port, which left `ival`/`hval` uninitialized.
__device__ __forceinline__ static half2 int2half2_fast_512(int x, int y) {
    const uint32_t ival = (static_cast<uint32_t>(x) & 0xFFFFu) | (static_cast<uint32_t>(y) << 16);
    const uint32_t hval = (ival & 0x03FF03FFu) ^ 0x66006600u;
    return __hadd2(kernels::bit_cast<half2>(hval), half2(-1536.0f, -1536.0f));
}

}; // namespace nunchaku::kernels