gemm_utils.cuh 14.7 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
#pragma once

#include <cstdint>
#include "common.h"
muyangli's avatar
muyangli committed
5
6
7
#include "../utils.cuh"

namespace nunchaku::kernels {
Zhekai Zhang's avatar
Zhekai Zhang committed
8
9

// Clamp `val` into [min, max]; usable in constant expressions.
// The lower bound is tested first, so if min > max any val < min yields min.
static constexpr int clamp(int val, int min, int max) {
    return (val < min) ? min : ((val > max) ? max : val);
}

// Load one value of type T from `addr`.
//   shmem == true : `addr` is a generic pointer into shared memory; 8- and
//                   16-byte types are read with a single vectorized ld.shared.
//   shmem == false: 8- and 16-byte types are read through __ldg as uint2/uint4.
//                   NOTE(review): presumably global, read-only data — confirm at call sites.
// Every other size falls back to a plain dereference. The vector paths
// reinterpret raw bytes, so `addr` is assumed aligned to sizeof(T).
template<bool shmem = false, typename T>
__device__ __forceinline__ static T load(const T *addr) {
    constexpr auto kBytes = sizeof(T);

    if constexpr (shmem && kBytes == 8) {
        uint2 raw;
        asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];"
                     : "=r"(raw.x), "=r"(raw.y)
                     : "l"(__cvta_generic_to_shared(addr)));
        return *reinterpret_cast<T *>(&raw);
    } else if constexpr (shmem && kBytes == 16) {
        uint4 raw;
        asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];"
                     : "=r"(raw.x), "=r"(raw.y), "=r"(raw.z), "=r"(raw.w)
                     : "l"(__cvta_generic_to_shared(addr)));
        return *reinterpret_cast<T *>(&raw);
    } else if constexpr (!shmem && kBytes == 8) {
        uint2 raw = __ldg(reinterpret_cast<const uint2 *>(addr));
        return *reinterpret_cast<T *>(&raw);
    } else if constexpr (!shmem && kBytes == 16) {
        uint4 raw = __ldg(reinterpret_cast<const uint4 *>(addr));
        return *reinterpret_cast<T *>(&raw);
    } else {
        return *addr;
    }
}

49
// Predicated load of one T from `addr`.
// 4-, 8- and 16-byte types use a single predicated `ld.global.nc` (non-coherent
// global load); other sizes branch around a plain dereference.
// When `pred` is false no load is issued and the returned value is
// unspecified (registers / local left unwritten) — callers must ignore it.
template<typename T>
__device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
    constexpr auto kBytes = sizeof(T);

    if constexpr (kBytes == 4) {
        uint32_t bits;
        asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
                     "@loadpred ld.global.nc.b32 %0, [%1];"
                     "}"
                     : "=r"(bits)
                     : "l"(addr), "r"(static_cast<int>(pred)));
        return *reinterpret_cast<T *>(&bits);
    } else if constexpr (kBytes == 8) {
        uint2 bits;
        asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
                     "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
                     "}"
                     : "=r"(bits.x), "=r"(bits.y)
                     : "l"(addr), "r"(static_cast<int>(pred)));
        return *reinterpret_cast<T *>(&bits);
    } else if constexpr (kBytes == 16) {
        uint4 bits;
        asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
                     "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
                     "}"
                     : "=r"(bits.x), "=r"(bits.y), "=r"(bits.z), "=r"(bits.w)
                     : "l"(addr), "r"(static_cast<int>(pred)));
        return *reinterpret_cast<T *>(&bits);
    } else {
        T fallback;
        if (pred) {
            fallback = *addr;
        }
        return fallback;
    }
}

Zhekai Zhang's avatar
Zhekai Zhang committed
86
// Store `val` to `addr`.
//   shmem == true : `addr` is a generic pointer into shared memory; 8- and
//                   16-byte types use a single vectorized st.shared.
//   shmem == false: 4-, 8- and 16-byte types use __stcg (cache-global store hint).
//                   NOTE(review): presumably global memory — confirm at call sites.
// Every remaining case (including 4-byte types on the shared path) is a plain
// assignment. The vector paths assume `addr` is aligned to sizeof(T).
template<bool shmem = false, typename T>
__device__ __forceinline__ static void store(T *addr, T val) {
    constexpr auto kBytes = sizeof(T);

    if constexpr (shmem && kBytes == 8) {
        uint2 bits = *reinterpret_cast<uint2 *>(&val);
        asm volatile("st.shared.v2.b32 [%0], {%1, %2};"
                     :
                     : "l"(__cvta_generic_to_shared(addr)), "r"(bits.x), "r"(bits.y));
    } else if constexpr (shmem && kBytes == 16) {
        uint4 bits = *reinterpret_cast<uint4 *>(&val);
        asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};"
                     :
                     : "l"(__cvta_generic_to_shared(addr)),
                       "r"(bits.x),
                       "r"(bits.y),
                       "r"(bits.z),
                       "r"(bits.w));
    } else if constexpr (!shmem && kBytes == 4) {
        __stcg(reinterpret_cast<unsigned int *>(addr), *reinterpret_cast<unsigned int *>(&val));
    } else if constexpr (!shmem && kBytes == 8) {
        __stcg(reinterpret_cast<uint2 *>(addr), *reinterpret_cast<uint2 *>(&val));
    } else if constexpr (!shmem && kBytes == 16) {
        __stcg(reinterpret_cast<uint4 *>(addr), *reinterpret_cast<uint4 *>(&val));
    } else {
        *addr = val;
    }
}

123
// Predicated store of `val` to `addr`; nothing is written when `pred` is false.
// 4-, 8- and 16-byte types use a single predicated `st.global.cg`; other
// sizes branch around a plain assignment.
template<typename T>
__device__ __forceinline__ static void store_pred(T *addr, T val, bool pred) {
    constexpr auto kBytes = sizeof(T);

    if constexpr (kBytes == 4) {
        uint32_t bits = *reinterpret_cast<uint32_t *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.b32 [%1], %2;"
                     "}"
                     :
                     : "r"(static_cast<int>(pred)), "l"(addr), "r"(bits));
    } else if constexpr (kBytes == 8) {
        uint2 bits = *reinterpret_cast<uint2 *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
                     "}"
                     :
                     : "r"(static_cast<int>(pred)), "l"(addr), "r"(bits.x), "r"(bits.y));
    } else if constexpr (kBytes == 16) {
        uint4 bits = *reinterpret_cast<uint4 *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
                     "}"
                     :
                     : "r"(static_cast<int>(pred)),
                       "l"(addr),
                       "r"(bits.x),
                       "r"(bits.y),
                       "r"(bits.z),
                       "r"(bits.w));
    } else {
        if (pred) {
            *addr = val;
        }
    }
}

Muyang Li's avatar
Muyang Li committed
162
// Widen a half2 to float2 via __half22float2. Paired with the __nv_bfloat162
// overload so templated code can call half22float2 for either packed type.
__device__ __forceinline__ static float2 half22float2(half2 val) {
    const float2 widened = __half22float2(val);
    return widened;
}

Muyang Li's avatar
Muyang Li committed
166
// Widen a __nv_bfloat162 to float2 via __bfloat1622float2 (bf16 counterpart
// of the half2 overload, sharing the same name for generic callers).
__device__ __forceinline__ static float2 half22float2(__nv_bfloat162 val) {
    const float2 widened = __bfloat1622float2(val);
    return widened;
}

// Narrow a float2 to a packed 16-bit pair of type T.
// The primary template is deleted: only the explicit specializations for
// half2 and __nv_bfloat162 can be called; any other T fails to compile.
template<typename T>
__device__ __forceinline__ static T float22half2(float2 /* val */) = delete;
Zhekai Zhang's avatar
Zhekai Zhang committed
172
173

// float2 -> half2 using the round-to-nearest intrinsic __float22half2_rn.
template<>
__device__ __forceinline__ half2 float22half2<half2>(float2 val) {
    const half2 narrowed = __float22half2_rn(val);
    return narrowed;
}

// float2 -> __nv_bfloat162 using the round-to-nearest intrinsic __float22bfloat162_rn.
template<>
__device__ __forceinline__ __nv_bfloat162 float22half2<__nv_bfloat162>(float2 val) {
    const __nv_bfloat162 narrowed = __float22bfloat162_rn(val);
    return narrowed;
}

// References `val` through a volatile store that the compiler cannot prove
// dead, because it is guarded by a runtime flag. Presumably used to keep `val`
// from being optimized away or flagged as unused — NOTE(review): confirm intent.
// `alwaysfalse` MUST be false at runtime: if true, this writes through nullptr.
template<typename T>
__device__ __forceinline__ static void unused_var(T &val, bool alwaysfalse) {
    if (!alwaysfalse) {
        return;
    }
    volatile T *sink = nullptr;
    *sink = val;
}

Muyang Li's avatar
Muyang Li committed
191
192
193
194
// Warp-collective `ldmatrix.sync.aligned.x4.m8n8.shared.b16`: loads four 8x8
// matrices of 16-bit elements from shared memory into out.x .. out.w.
// `ptr` is a generic pointer into shared memory, converted to the shared
// window. NOTE(review): per the PTX ISA this needs sm_75+ and all warp lanes.
__device__ __forceinline__ static void ldmatrix(const void *ptr, uint4 &out) {
    const auto smem_addr = __cvta_generic_to_shared(ptr);
    asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                 : "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w)
                 : "l"(smem_addr));
}

197
// Warp-collective transpose of an 8x8 matrix of 16-bit elements
// (`movmatrix.sync.aligned.m8n8.trans.b16`). Each lane passes its 32-bit
// fragment in `x` and gets the transposed fragment back.
// Only the first 32 bits of `x` are read and replaced; T is assumed to be
// 4 bytes (e.g. half2) — TODO confirm at call sites.
template<typename T>
__device__ __forceinline__ static T movmatrix(T x) {
    uint32_t *frag = reinterpret_cast<uint32_t *>(&x);
    const uint32_t before = *frag;
    uint32_t after;
    asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" : "=r"(after) : "r"(before));
    *frag = after;
    return x;
}

Zhekai Zhang's avatar
Zhekai Zhang committed
205
206
// Round both lanes of a float2 to integers and pack them with saturation into
// `bitwidth`-bit fields of a uint32_t: x goes in the low field, y in the field
// directly above it. `use_unsigned` selects unsigned rather than signed
// saturation. The primary template is deleted; only the explicit
// specializations are callable.
template<int bitwidth, bool use_unsigned>
__device__ __forceinline__ uint32_t quantize_float2(float2 /* value */) = delete;
Zhekai Zhang's avatar
Zhekai Zhang committed
208
209

// Signed 4-bit: round each lane with cvt.rni (nearest integer), then
// saturate-pack as s4. x lands in the low nibble, y in the next nibble
// (y is passed as the first cvt.pack source); the `0` third operand
// leaves the upper bits clear.
template<>
__device__ __forceinline__ uint32_t quantize_float2<4, false>(float2 value) {
    int qx;
    int qy;
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(qx) : "f"(value.x));
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(qy) : "f"(value.y));

    uint32_t packed;
    asm volatile("cvt.pack.sat.s4.s32.b32 %0, %1, %2, 0;" : "=r"(packed) : "r"(qy), "r"(qx));
    return packed;
}

// Unsigned 4-bit: round each lane with cvt.rni, then saturate-pack as u4.
// x lands in the low nibble, y in the next nibble; upper bits clear.
template<>
__device__ __forceinline__ uint32_t quantize_float2<4, true>(float2 value) {
    int qx;
    int qy;
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(qx) : "f"(value.x));
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(qy) : "f"(value.y));

    uint32_t packed;
    asm volatile("cvt.pack.sat.u4.s32.b32 %0, %1, %2, 0;" : "=r"(packed) : "r"(qy), "r"(qx));
    return packed;
}

// Signed 8-bit: round each lane with cvt.rni, then saturate-pack as s8.
// x lands in the low byte, y in the next byte; upper bits clear.
template<>
__device__ __forceinline__ uint32_t quantize_float2<8, false>(float2 value) {
    int qx;
    int qy;
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(qx) : "f"(value.x));
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(qy) : "f"(value.y));

    uint32_t packed;
    asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;" : "=r"(packed) : "r"(qy), "r"(qx));
    return packed;
}

Muyang Li's avatar
Muyang Li committed
239
// Convert a float2 to two FP4 (E2M1) values with round-to-nearest and
// saturation to the finite range, packed into one byte and zero-extended
// to 32 bits. y is passed as the first cvt source, x as the second, matching
// the file's "x in low bits" packing convention.
// NOTE(review): e2m1x2 conversions need a recent (Blackwell-class) target.
__device__ __forceinline__ uint32_t quantize_float2_fp4(float2 value) {
    uint32_t packed_byte;
    asm volatile("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
                 : "=r"(packed_byte)
                 : "f"(value.y), "f"(value.x));
    return packed_byte;
}

Muyang Li's avatar
Muyang Li committed
247
// Convert a float4 to four FP8 (E4M3) values, round-to-nearest with
// saturation to the finite range. Two cvt...e4m3x2 ops produce 16-bit pairs
// (x,y) and (z,w); the (x,y) pair fills the low 16 bits of the result and
// (z,w) the high 16 bits. Within each pair the second-listed lane (x / z)
// is passed last, following the file's "x in low bits" convention.
// NOTE(review): e4m3x2 conversions need sm_89+.
__device__ __forceinline__ uint32_t quantize_float4_fp8(float4 value) {
    uint16_t pair_xy;
    uint16_t pair_zw;
    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(pair_xy) : "f"(value.y), "f"(value.x));
    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(pair_zw) : "f"(value.w), "f"(value.z));
    const uint32_t high = uint32_t(pair_zw) << 16;
    return high | uint32_t(pair_xy);
}

Muyang Li's avatar
Muyang Li committed
254
// Hardware-approximate hyperbolic tangent (PTX `tanh.approx.f32`).
// NOTE(review): this instruction requires sm_75+.
__device__ __forceinline__ static float cuda_tanhf(float x) {
    float y;
    asm("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

Muyang Li's avatar
Muyang Li committed
260
// Approximate reciprocal 1/x (PTX `rcp.approx.ftz.f32`; subnormals flushed to zero).
__device__ __forceinline__ static float cuda_frcp(float x) {
    float inv;
    asm("rcp.approx.ftz.f32 %0, %1;" : "=f"(inv) : "f"(x));
    return inv;
}

Muyang Li's avatar
Muyang Li committed
266
// Approximate reciprocal square root 1/sqrt(x) (PTX `rsqrt.approx.ftz.f32`; ftz).
__device__ __forceinline__ static float cuda_frsqrt(float x) {
    float inv_root;
    asm("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(inv_root) : "f"(x));
    return inv_root;
}

Muyang Li's avatar
Muyang Li committed
272
// Approximate sine of x in radians (PTX `sin.approx.ftz.f32`; ftz).
__device__ __forceinline__ static float cuda_sin(float x) {
    float s;
    asm("sin.approx.ftz.f32 %0, %1;" : "=f"(s) : "f"(x));
    return s;
}

Muyang Li's avatar
Muyang Li committed
278
// Approximate cosine of x in radians (PTX `cos.approx.ftz.f32`; ftz).
__device__ __forceinline__ static float cuda_cos(float x) {
    float c;
    asm("cos.approx.ftz.f32 %0, %1;" : "=f"(c) : "f"(x));
    return c;
}

Muyang Li's avatar
Muyang Li committed
284
// Approximate base-2 exponential 2^x (PTX `ex2.approx.ftz.f32`; ftz).
__device__ __forceinline__ static float cuda_exp2(float x) {
    float p;
    asm("ex2.approx.ftz.f32 %0, %1;" : "=f"(p) : "f"(x));
    return p;
}

Zhekai Zhang's avatar
Zhekai Zhang committed
290
// Fast approximate logistic sigmoid, 1 / (1 + exp(-a)).
// https://forums.developer.nvidia.com/t/hardware-accelerated-computation-of-the-sigmoid-logistic-function/266206
//   USE_TANH != 0 : sigmoid(a) = 0.5 * tanh(0.5 * a) + 0.5, using the file's
//                   cuda_tanhf (tanh.approx.f32).
//   otherwise     : 1 / (1 + 2^(-a * log2(e))), using cuda_exp2 (ex2.approx.ftz)
//                   and cuda_frcp (rcp.approx.ftz) — the same instructions the
//                   hand-inlined PTX used before.
// Large negative a: 2^(...) -> +inf, reciprocal -> 0. Large positive a -> 1.
__forceinline__ __device__ static float cuda_sigmoidf(float a) {
#if USE_TANH
    // float literal (was a double `0.5`) and the file's own tanh helper.
    return fmaf(0.5f, cuda_tanhf(0.5f * a), 0.5f);
#else  // USE_TANH
    const float L2E = 1.442695041f; // log2(exp(1))
    const float e   = cuda_exp2(-L2E * a);
    return cuda_frcp(e + 1.0f);
#endif // USE_TANH
}

// Tanh-approximation GELU on a packed 16-bit pair (T = half2 or __nv_bfloat162):
//   gelu(v) = v * (0.5 + 0.5 * tanh(sqrt(2/pi) * (v + 0.044715 * v^3)))
// Computed in float. Relies on a float2 operator* defined outside this file
// (NOTE(review): presumably component-wise — confirm in common.h / utils.cuh).
template<typename T>
__device__ __forceinline__ static T gelu_half2(T x) {
    constexpr float kSqrt2OverPi = 0.79788456f; // sqrt(2 / pi)
    constexpr float kCubicCoeff  = 0.044715f;

    float2 v     = half22float2(x);
    float2 vcube = v * v * v;

    const float gate_x = 0.5f + 0.5f * cuda_tanhf(kSqrt2OverPi * (v.x + (kCubicCoeff * vcube.x)));
    const float gate_y = 0.5f + 0.5f * cuda_tanhf(kSqrt2OverPi * (v.y + (kCubicCoeff * vcube.y)));

    return float22half2<T>(v * make_float2(gate_x, gate_y));
}

// Scalar tanh-approximation GELU:
//   gelu(v) = v * (0.5 + 0.5 * tanh(sqrt(2/pi) * (v + 0.044715 * v^3)))
// T is any type convertible to and from float (e.g. half); math is in float.
template<typename T>
__device__ __forceinline__ static T gelu_half(T x) {
    constexpr float kSqrt2OverPi = 0.79788456f; // sqrt(2 / pi)
    constexpr float kCubicCoeff  = 0.044715f;

    const float v     = static_cast<float>(x);
    const float vcube = v * v * v;
    const float gate  = 0.5f + 0.5f * cuda_tanhf(kSqrt2OverPi * (v + (kCubicCoeff * vcube)));
    return static_cast<T>(v * gate);
}

Muyang Li's avatar
Muyang Li committed
322
323
324
325
326
// SiLU / swish: x * sigmoid(x), computed in float with cuda_sigmoidf.
// Alternative form kept for reference:
//   return (T)__fdividef((float)x, 1.0f + __expf((float)-x));
template<typename T>
__device__ __forceinline__ static T silu(const T &x) {
    const float xf = static_cast<float>(x);
    return static_cast<T>(xf * cuda_sigmoidf(xf));
}

Muyang Li's avatar
Muyang Li committed
329
// Lane-wise a / b for half2: widen to float, divide with the fast
// approximate __fdividef, round back to half2.
__device__ __forceinline__ static half2 h2div(half2 a, half2 b) {
    const float2 num = half22float2(a);
    const float2 den = half22float2(b);
    float2 quot;
    quot.x = __fdividef(num.x, den.x);
    quot.y = __fdividef(num.y, den.y);
    return float22half2<half2>(quot);
}
Muyang Li's avatar
Muyang Li committed
337
// Lane-wise a / b for __nv_bfloat162: widen to float, divide with the fast
// approximate __fdividef, round back to bf16x2.
__device__ __forceinline__ static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bfloat162 b) {
    const float2 num = half22float2(a);
    const float2 den = half22float2(b);
    float2 quot;
    quot.x = __fdividef(num.x, den.x);
    quot.y = __fdividef(num.y, den.y);
    return float22half2<__nv_bfloat162>(quot);
}

Muyang Li's avatar
Muyang Li committed
346
347
// Fire-and-forget atomic float add of `val` into global `*addr`
// (`red.relaxed.gpu`: relaxed ordering, GPU scope; no old value returned).
__device__ __forceinline__ static void reduce_add(float *addr, float val) {
    asm volatile("red.relaxed.gpu.global.add.f32 [%0], %1;"
                 :
                 : "l"(addr), "f"(val));
}

Muyang Li's avatar
Muyang Li committed
350
351
352
353
354
355
// Predicated variant of reduce_add: the relaxed, GPU-scope atomic add into
// global `*addr` is issued only when `pred` is true.
__device__ __forceinline__ static void reduce_add_pred(float *addr, float val, bool pred) {
    const int guard = static_cast<int>(pred);
    asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                 "@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
                 "}"
                 :
                 : "r"(guard), "l"(addr), "f"(val));
}

Zhekai Zhang's avatar
Zhekai Zhang committed
358
// Compile-time unrolled loop: invokes `lambda.template operator()<I>()` for
// I = 0, 1, ..., cnt - 1, in order (a comma fold evaluates left to right).
// `lambda` must be a generic lambda with an int template parameter, e.g.
//   unrolled_loop<4>([&]<int i>() { ... });
template<int cnt, typename F>
__device__ __forceinline__ static void unrolled_loop(F &&lambda) {
    auto expand = [&]<int... I>(std::integer_sequence<int, I...>) {
        (lambda.template operator()<I>(), ...);
    };
    expand(std::make_integer_sequence<int, cnt>{});
}

364
365
// int -> float without I2F (which is slow on sm_80 and earlier): one LOP3 plus one FADD.
// Exact for val in [-4194304, 4194303], i.e. [-2^22, 2^22 - 1].
// 0x4B400000 is 12582912.0f = 1.5 * 2^23. Keeping val's low 23 bits and
// XOR-ing that pattern flips mantissa bit 22, so the mantissa holds
// val + 2^22 and the float equals 12582912 + val exactly.
__device__ __forceinline__ static float int2float_fast(int val) {
    // lop3 immLut (0xF0 & 0xCC) ^ 0xAA selects (a & b) ^ c.
    float biased;
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=f"(biased)
                 : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
    return biased - 12582912.0f;
}

375
// Reinterpret the bytes of `input` as a To of identical size.
// NOTE(review): implemented by pointer type punning, which ISO C++ strict
// aliasing does not guarantee; memcpy-based punning would be the portable form.
template<typename To, typename From>
__device__ __forceinline__ static To bit_cast(const From &input) {
    static_assert(sizeof(To) == sizeof(From));
    const To *punned = reinterpret_cast<const To *>(&input);
    return *punned;
}

382
383
// Packed int -> half2 without I2F / F2F, both of which are slow on sm_75 and earlier.
// Valid for x, y in [-8192, 8191]. Each lane becomes 16 * floor(v / 16):
// steps of 16, rounding toward negative infinity.
// Result lanes: low half <- x, high half <- y.
__device__ __forceinline__ static half2 int2half2_fast_8192(int x, int y) {
    // packed.lo = x.lo; packed.hi = y.lo   (prmt selector 0x5410)
    uint32_t packed;
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(packed) : "r"(x), "r"(y), "n"(0x5410));

    // Drop the 4 bits below the step. y bits shifted into the top of the
    // low half are removed by the 0x03FF mask below.
    packed >>= 4;

    // bits = (packed & 0x03FF03FF) ^ 0x76007600.
    // 0x7600 is half(24576) = 1.5 * 2^14, whose ulp is 16. Flipping the top
    // mantissa bit turns the 10-bit two's-complement field (v >> 4) into
    // (v >> 4) + 512, so each lane equals 24576 + 16 * (v >> 4).
    uint32_t bits;
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(bits)
                 : "r"(packed), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));

    // Remove the 24576 bias from both lanes.
    return __hadd2(kernels::bit_cast<half2>(bits), half2(-24576.0f, -24576.0f));
}
// Packed int -> half2 with round-to-nearest to a step of 8 (no I2F / F2F).
// Each lane becomes 8 * floor(v / 8 + 0.5); ties round up, not to even.
// Valid input range is [-4096, 4091] — NOT [-4096, 4095]. Inputs 4092..4095
// round to r = 512, which does not fit the 10-bit field and wraps to -4096.
// Inputs are not clamped; callers must stay in range.
// Result lanes: low half <- x, high half <- y.
__device__ __forceinline__ static half2 int2half2_fast_4096_rn(int x, int y) {
    // Optional saturation, if callers cannot guarantee the range:
    //   x = max(min(x, 4091), -4096);
    //   y = max(min(y, 4091), -4096);
    // TODO: round half to even?

    // v * 8192 + 32768 == (v / 8 + 0.5) * 65536, so the high 16 bits of the
    // result hold floor(v / 8 + 0.5).
    const int xs = x * 8192 + 32768;
    const int ys = y * 8192 + 32768;

    // packed.lo = xs.hi; packed.hi = ys.hi   (prmt selector 0x7632)
    uint32_t packed;
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(packed) : "r"(xs), "r"(ys), "n"(0x7632));

    // bits = (packed & 0x03FF03FF) ^ 0x72007200.
    // 0x7200 is half(12288) = 1.5 * 2^13, whose ulp is 8. Each lane becomes
    // 12288 + 8 * r for r = floor(v / 8 + 0.5) in [-512, 511].
    uint32_t bits;
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(bits)
                 : "r"(packed), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));

    // Remove the 12288 bias from both lanes.
    return __hadd2(kernels::bit_cast<half2>(bits), half2(-12288.0f, -12288.0f));
}
// Packed int -> half2 without I2F / F2F, exact for x, y in [-512, 511].
// Result lanes: low half <- x, high half <- y.
__device__ __forceinline__ static half2 int2half2_fast_512(int x, int y) {
    // packed.lo = x.lo; packed.hi = y.lo   (prmt selector 0x5410).
    // The low 16 bits are taken unscaled; no division is involved here.
    uint32_t packed;
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(packed) : "r"(x), "r"(y), "n"(0x5410));

    // bits = (packed & 0x03FF03FF) ^ 0x66006600.
    // 0x6600 is half(1536) = 1.5 * 2^10, whose ulp is 1. Flipping the top
    // mantissa bit maps the 10-bit two's-complement v to v + 512, so each
    // lane equals 1536 + v.
    uint32_t bits;
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(bits)
                 : "r"(packed), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));

    // Remove the 1536 bias from both lanes.
    return __hadd2(kernels::bit_cast<half2>(bits), half2(-1536.0f, -1536.0f));
}

Muyang Li's avatar
Muyang Li committed
428
}; // namespace nunchaku::kernels