gemm_utils.cuh 14.9 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
#pragma once

#include <cstdint>
#include "common.h"
muyangli's avatar
muyangli committed
5
6
7
#include "../utils.cuh"

namespace nunchaku::kernels {
Zhekai Zhang's avatar
Zhekai Zhang committed
8
9

// Clamp `val` into the closed interval [min, max]; usable in constant expressions.
// The lower bound is tested first, so if min > max any val < min yields `min`.
static constexpr int clamp(int val, int min, int max) {
    return (val < min) ? min : ((val > max) ? max : val);
}

// Vectorized load helper.
//   shmem == true : `addr` must point into shared memory. 8- and 16-byte types
//                   are read with one ld.shared.v2/.v4.b32; other sizes fall back
//                   to a plain dereference.
//   shmem == false: 8- and 16-byte types are read through __ldg (read-only data
//                   path) as uint2/uint4; other sizes fall back to a plain dereference.
// NOTE(review): the vector paths presumably require `addr` aligned to sizeof(T).
template<bool shmem = false, typename T>
__device__ __forceinline__ static T load(const T *addr) {
    if constexpr (shmem) {
        if constexpr (sizeof(T) == 8) {
            uint2 words;
            asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];"
                         : "=r"(words.x), "=r"(words.y)
                         : "l"(__cvta_generic_to_shared(addr)));
            return *reinterpret_cast<T *>(&words);
        } else if constexpr (sizeof(T) == 16) {
            uint4 words;
            asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];"
                         : "=r"(words.x), "=r"(words.y), "=r"(words.z), "=r"(words.w)
                         : "l"(__cvta_generic_to_shared(addr)));
            return *reinterpret_cast<T *>(&words);
        } else {
            return *addr;
        }
    } else {
        if constexpr (sizeof(T) == 8) {
            uint2 words = __ldg(reinterpret_cast<const uint2 *>(addr));
            return *reinterpret_cast<T *>(&words);
        } else if constexpr (sizeof(T) == 16) {
            uint4 words = __ldg(reinterpret_cast<const uint4 *>(addr));
            return *reinterpret_cast<T *>(&words);
        } else {
            return *addr;
        }
    }
}

49
// Predicated load: reads *addr only when `pred` is true.
// 4/8/16-byte types use a single predicated ld.global.nc(.v2/.v4).b32.
// `addr` must therefore be a global-memory address, aligned to sizeof(T).
// It must point at data that is not written while the kernel runs, because
// .nc is the non-coherent read-only path.
// When `pred` is false, no memory access is issued:
//   - asm paths: the destination registers are never written, so the returned
//     value is unspecified and callers must not use it;
//   - generic fallback: returns a value-initialized T rather than reading an
//     indeterminate object.
template<typename T>
__device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
    if constexpr (sizeof(T) == 4) {
        uint32_t data;
        asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
                     "@loadpred ld.global.nc.b32 %0, [%1];"
                     "}"
                     : "=r"(data)
                     : "l"(addr), "r"((int)pred));
        return *reinterpret_cast<T *>(&data);
    }
    if constexpr (sizeof(T) == 8) {
        uint2 data;
        asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
                     "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
                     "}"
                     : "=r"(data.x), "=r"(data.y)
                     : "l"(addr), "r"((int)pred));
        return *reinterpret_cast<T *>(&data);
    }
    if constexpr (sizeof(T) == 16) {
        uint4 data;
        asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
                     "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
                     "}"
                     : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
                     : "l"(addr), "r"((int)pred));
        return *reinterpret_cast<T *>(&data);
    }

    // Value-initialize so the !pred path never returns an indeterminate object
    // (returning an uninitialized scalar is undefined behavior).
    T result{};
    if (pred) {
        result = *addr;
    }
    return result;
}

Zhekai Zhang's avatar
Zhekai Zhang committed
86
// Vectorized store helper (counterpart of load()).
//   shmem == true : `addr` must point into shared memory. 8- and 16-byte types
//                   are written with one st.shared.v2/.v4.b32.
//   shmem == false: 4-, 8- and 16-byte types are written through an
//                   unsigned int / uint2 / uint4 lvalue so a single vector store
//                   is emitted.
// Any other size falls back to `*addr = val`.
// NOTE(review): the vector paths presumably require `addr` aligned to sizeof(T).
template<bool shmem = false, typename T>
__device__ __forceinline__ static void store(T *addr, T val) {
    if constexpr (shmem) {
        if constexpr (sizeof(T) == 8) {
            const uint2 words = *reinterpret_cast<uint2 *>(&val);
            asm volatile(
                "st.shared.v2.b32 [%0], {%1, %2};" ::"l"(__cvta_generic_to_shared(addr)), "r"(words.x), "r"(words.y));
        } else if constexpr (sizeof(T) == 16) {
            const uint4 words = *reinterpret_cast<uint4 *>(&val);
            asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"(__cvta_generic_to_shared(addr)),
                         "r"(words.x),
                         "r"(words.y),
                         "r"(words.z),
                         "r"(words.w));
        } else {
            *addr = val;
        }
    } else if constexpr (sizeof(T) == 4) {
        // Cache-global alternative kept for reference:
        // __stcg(reinterpret_cast<unsigned int *>(addr), *reinterpret_cast<unsigned int *>(&val));
        *reinterpret_cast<unsigned int *>(addr) = *reinterpret_cast<unsigned int *>(&val);
    } else if constexpr (sizeof(T) == 8) {
        // __stcg(reinterpret_cast<uint2 *>(addr), *reinterpret_cast<uint2 *>(&val));
        *reinterpret_cast<uint2 *>(addr) = *reinterpret_cast<uint2 *>(&val);
    } else if constexpr (sizeof(T) == 16) {
        // __stcg(reinterpret_cast<uint4 *>(addr), *reinterpret_cast<uint4 *>(&val));
        *reinterpret_cast<uint4 *>(addr) = *reinterpret_cast<uint4 *>(&val);
    } else {
        *addr = val;
    }
}

126
// Predicated store: writes `val` to *addr only when `pred` is true.
// 4/8/16-byte types use a single predicated st.global.cg(.v2/.v4).b32.
// `addr` must therefore be a global-memory address, presumably aligned to sizeof(T).
// The .cg qualifier caches the write in L2, not L1.
// Other sizes branch around a plain store.
template<typename T>
__device__ __forceinline__ static void store_pred(T *addr, T val, bool pred) {
    if constexpr (sizeof(T) == 4) {
        const uint32_t bits = *reinterpret_cast<uint32_t *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.b32 [%1], %2;"
                     "}" ::"r"((int)pred),
                     "l"(addr),
                     "r"(bits));
    } else if constexpr (sizeof(T) == 8) {
        const uint2 bits = *reinterpret_cast<uint2 *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
                     "}" ::"r"((int)pred),
                     "l"(addr),
                     "r"(bits.x),
                     "r"(bits.y));
    } else if constexpr (sizeof(T) == 16) {
        const uint4 bits = *reinterpret_cast<uint4 *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
                     "}" ::"r"((int)pred),
                     "l"(addr),
                     "r"(bits.x),
                     "r"(bits.y),
                     "r"(bits.z),
                     "r"(bits.w));
    } else {
        if (!pred) {
            return;
        }
        *addr = val;
    }
}

Muyang Li's avatar
Muyang Li committed
165
// Widen a packed fp16 pair to fp32 (.x <- low half, .y <- high half).
__device__ __forceinline__ static float2 half22float2(half2 val) {
    const float2 widened = __half22float2(val);
    return widened;
}

fengzch-das's avatar
fengzch-das committed
169
// Widen a packed bf16 pair to fp32 (.x <- low half, .y <- high half).
// Overload of the half2 version so templated code can call half22float2 on either type.
__device__ __forceinline__ static float2 half22float2(__nv_bfloat162 val) {
    const float2 widened = __bfloat1622float2(val);
    return widened;
}

// Narrow a float2 into a packed 16-bit pair of type T.
// Only the explicit specializations for half2 and __nv_bfloat162 are usable.
// The primary template is deleted, so any other T fails at compile time.
template<typename T>
__device__ __forceinline__ static T
float22half2(float2 val) = delete;
Zhekai Zhang's avatar
Zhekai Zhang committed
175
176

// fp32 pair -> fp16 pair, round-to-nearest-even (.x -> low half, .y -> high half).
template<>
__device__ __forceinline__ half2 float22half2<half2>(float2 val) {
    const half2 narrowed = __float22half2_rn(val);
    return narrowed;
}

// fp32 pair -> bf16 pair, round-to-nearest-even (.x -> low half, .y -> high half).
template<>
__device__ __forceinline__ __nv_bfloat162 float22half2<__nv_bfloat162>(float2 val) {
    const __nv_bfloat162 narrowed = __float22bfloat162_rn(val);
    return narrowed;
}

// Presumably used to keep `val` "used" from the compiler's point of view.
// Because `alwaysfalse` is only known at runtime, the compiler must keep the
// guarded volatile store and therefore `val`.
// `alwaysfalse` MUST be false at runtime: the store targets a null pointer.
template<typename T>
__device__ __forceinline__ static void unused_var(T &val, bool alwaysfalse) {
    if (!alwaysfalse) {
        return;
    }
    volatile T *sink = nullptr;
    *sink = val;
}

Muyang Li's avatar
Muyang Li committed
194
// NOTE(review): this function is currently a NO-OP.
// The ldmatrix PTX is commented out, presumably because the current target
// toolchain does not support it. As a result `ptr` is never read and `out`
// is left unmodified. Any caller expecting an ldmatrix.x4 fragment load from
// shared memory will see stale register contents.
// TODO: restore the asm behind an architecture guard on targets that support
// ldmatrix (sm_75+ per the PTX ISA), or provide an emulation.
//
// Disabled implementation:
//   asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
//                : "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w)
//                : "l"(__cvta_generic_to_shared(ptr)));  // limengmeng
__device__ __forceinline__ static void ldmatrix(const void *ptr, uint4 &out) {
    // Intentionally empty (see note above); the casts only silence unused-parameter warnings.
    (void)ptr;
    (void)out;
}

200
// Warp-cooperative transpose of an 8x8 matrix of 16-bit elements
// (movmatrix.sync.aligned.m8n8.trans.b16). Each lane holds one 32-bit fragment.
// T is therefore reinterpreted as a uint32_t and must be 4 bytes wide (e.g. half2).
// All lanes of the warp must execute it together (.sync.aligned).
template<typename T>
__device__ __forceinline__ static T movmatrix(T x) {
    uint32_t &frag = *reinterpret_cast<uint32_t *>(&x);
    asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" : "=r"(frag) : "r"(frag));
    return x;
}

Zhekai Zhang's avatar
Zhekai Zhang committed
208
209
// quantize_float2<bitwidth, use_unsigned>(v) rounds v.x and v.y to integers,
// saturates them to `bitwidth` bits and packs them into the low bits of the result.
// x goes into the low field and y into the high field.
// Only the explicit specializations are usable; the primary template is deleted.
template<int bitwidth, bool use_unsigned>
__device__ __forceinline__ uint32_t
quantize_float2(float2 value) = delete;
Zhekai Zhang's avatar
Zhekai Zhang committed
211
212

// Round x and y to nearest-even, saturate to signed 4 bits and pack:
// bits [3:0] <- x, bits [7:4] <- y; the remaining bits come from the literal 0.
template<>
__device__ __forceinline__ uint32_t quantize_float2<4, false>(float2 value) {
    int lo, hi;
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(lo) : "f"(value.x));
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(hi) : "f"(value.y));
    uint32_t packed;
    // cvt.pack places its first source operand in the higher field.
    asm volatile("cvt.pack.sat.s4.s32.b32 %0, %1, %2, 0;" : "=r"(packed) : "r"(hi), "r"(lo));
    return packed;
}

// Round x and y to nearest-even, saturate to unsigned 4 bits and pack:
// bits [3:0] <- x, bits [7:4] <- y; the remaining bits come from the literal 0.
template<>
__device__ __forceinline__ uint32_t quantize_float2<4, true>(float2 value) {
    int lo, hi;
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(lo) : "f"(value.x));
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(hi) : "f"(value.y));
    uint32_t packed;
    // cvt.pack places its first source operand in the higher field.
    asm volatile("cvt.pack.sat.u4.s32.b32 %0, %1, %2, 0;" : "=r"(packed) : "r"(hi), "r"(lo));
    return packed;
}

// Round x and y to nearest-even, saturate to signed 8 bits and pack:
// bits [7:0] <- x, bits [15:8] <- y; the remaining bits come from the literal 0.
template<>
__device__ __forceinline__ uint32_t quantize_float2<8, false>(float2 value) {
    int lo, hi;
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(lo) : "f"(value.x));
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(hi) : "f"(value.y));
    uint32_t packed;
    // cvt.pack places its first source operand in the higher field.
    asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;" : "=r"(packed) : "r"(hi), "r"(lo));
    return packed;
}

Muyang Li's avatar
Muyang Li committed
242
// Convert x and y to FP4 E2M1 (round-to-nearest, saturating to finite).
// The two 4-bit codes are packed into the low byte: bits [3:0] <- x, bits [7:4] <- y.
// The byte is zero-extended to 32 bits.
__device__ __forceinline__ uint32_t quantize_float2_fp4(float2 value) {
    uint32_t packed;
    asm volatile("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
                 : "=r"(packed)
                 : "f"(value.y), "f"(value.x));
    return packed;
}

Muyang Li's avatar
Muyang Li committed
250
// Convert four floats to FP8 E4M3 (round-to-nearest, saturating to finite) and pack them:
// byte0 <- x, byte1 <- y, byte2 <- z, byte3 <- w.
__device__ __forceinline__ uint32_t quantize_float4_fp8(float4 value) {
    uint16_t xy, zw;
    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(xy) : "f"(value.y), "f"(value.x));
    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(zw) : "f"(value.w), "f"(value.z));
    return (uint32_t(zw) << 16) | uint32_t(xy);
}

Muyang Li's avatar
Muyang Li committed
257
// Hardware tanh approximation (tanh.approx.f32).
// Non-volatile asm, so the compiler may CSE or hoist it.
__device__ __forceinline__ static float cuda_tanhf(float x) {
    float y;
    asm("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

Muyang Li's avatar
Muyang Li committed
263
// Approximate reciprocal 1/x (rcp.approx.ftz.f32; denormals flushed to zero).
__device__ __forceinline__ static float cuda_frcp(float x) {
    float y;
    asm("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

Muyang Li's avatar
Muyang Li committed
269
// Approximate reciprocal square root 1/sqrt(x) (rsqrt.approx.ftz.f32; denormals flushed to zero).
__device__ __forceinline__ static float cuda_frsqrt(float x) {
    float y;
    asm("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

Muyang Li's avatar
Muyang Li committed
275
// Approximate sine (sin.approx.ftz.f32).
// Accuracy is specified for a bounded input range (see the PTX ISA), so range-reduce large |x| first.
__device__ __forceinline__ static float cuda_sin(float x) {
    float y;
    asm("sin.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

Muyang Li's avatar
Muyang Li committed
281
// Approximate cosine (cos.approx.ftz.f32).
// Accuracy is specified for a bounded input range (see the PTX ISA), so range-reduce large |x| first.
__device__ __forceinline__ static float cuda_cos(float x) {
    float y;
    asm("cos.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

Muyang Li's avatar
Muyang Li committed
287
// Approximate 2^x (ex2.approx.ftz.f32; denormals flushed to zero).
__device__ __forceinline__ static float cuda_exp2(float x) {
    float y;
    asm("ex2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

Zhekai Zhang's avatar
Zhekai Zhang committed
293
// https://forums.developer.nvidia.com/t/hardware-accelerated-computation-of-the-sigmoid-logistic-function/266206
// Fast approximate logistic sigmoid, 1 / (1 + exp(-a)).
//   USE_TANH: 0.5 + 0.5 * tanh(a / 2), evaluated with tanh.approx via cuda_tanhf.
//   default : rcp(1 + 2^(-a * log2(e))), using ex2.approx.ftz and rcp.approx.ftz
//             through cuda_exp2 / cuda_frcp (identical PTX to the former inline asm).
__forceinline__ __device__ static float cuda_sigmoidf(float a) {
#if USE_TANH
    // Use the file's own tanh.approx wrapper instead of __tanhf, which is not
    // available on every toolkit. Use 0.5f so no double literal appears.
    return fmaf(0.5f, cuda_tanhf(0.5f * a), 0.5f);
#else  // USE_TANH
    const float L2E = 1.442695041f; // log2(exp(1))
    // exp(-a) == 2^(-a * log2(e))
    const float e = cuda_exp2(-L2E * a);
    return cuda_frcp(e + 1.0f);
#endif // USE_TANH
}

// Tanh-approximation GELU on a packed 16-bit pair (half2 or __nv_bfloat162):
//   gelu(x) = x * (0.5 + 0.5 * tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// Evaluated in fp32 with cuda_tanhf, then narrowed back to T.
template<typename T>
__device__ __forceinline__ static T gelu_half2(T x) {
    const float2 xf   = half22float2(x);
    const float2 cube = xf * xf * xf;
    // 0.79788456f ~= sqrt(2 / pi)
    const float gate_x = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.x + (0.044715f * cube.x)));
    const float gate_y = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.y + (0.044715f * cube.y)));
    return float22half2<T>(xf * make_float2(gate_x, gate_y));
}

// Scalar tanh-approximation GELU:
//   gelu(x) = x * (0.5 + 0.5 * tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// Computed in fp32 and cast back to T.
template<typename T>
__device__ __forceinline__ static T gelu_half(T x) {
    const float xf   = float(x);
    const float cube = xf * xf * xf;
    // 0.79788456f ~= sqrt(2 / pi)
    const float gate = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf + (0.044715f * cube)));
    return (T)(xf * gate);
}

Muyang Li's avatar
Muyang Li committed
325
326
327
328
329
// SiLU / swish: x * sigmoid(x), computed in fp32 via cuda_sigmoidf and cast back to T.
// (An alternative formulation, x / (1 + __expf(-x)) with __fdividef, was left commented out here.)
template<typename T>
__device__ __forceinline__ static T silu(const T &x) {
    const float xf = (float)x;
    return (T)(xf * cuda_sigmoidf(xf));
}

Muyang Li's avatar
Muyang Li committed
332
// Element-wise a / b for a packed fp16 pair.
// Widened to fp32, divided with the fast approximate __fdividef,
// then narrowed back with round-to-nearest.
__device__ __forceinline__ static half2 h2div(half2 a, half2 b) {
    const float2 num = half22float2(a);
    const float2 den = half22float2(b);
    const float qx   = __fdividef(num.x, den.x);
    const float qy   = __fdividef(num.y, den.y);
    return float22half2<half2>(make_float2(qx, qy));
}
fengzch-das's avatar
fengzch-das committed
340
// Element-wise a / b for a packed bf16 pair.
// Widened to fp32, divided with the fast approximate __fdividef,
// then narrowed back with round-to-nearest.
__device__ __forceinline__ static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bfloat162 b) {
    const float2 num = half22float2(a);
    const float2 den = half22float2(b);
    const float qx   = __fdividef(num.x, den.x);
    const float qy   = __fdividef(num.y, den.y);
    return float22half2<__nv_bfloat162>(make_float2(qx, qy));
}

Muyang Li's avatar
Muyang Li committed
349
350
// Atomic *addr += val on global memory without returning the old value.
// Uses red.relaxed.gpu: relaxed ordering, device (gpu) scope.
__device__ __forceinline__ static void reduce_add(float *addr, float val) {
    asm volatile("red.relaxed.gpu.global.add.f32 [%0], %1;"
                 :
                 : "l"(addr), "f"(val));
}

Muyang Li's avatar
Muyang Li committed
353
354
355
356
357
358
// Predicated form of reduce_add: performs the relaxed, gpu-scoped global
// atomic add only when `pred` is true; otherwise nothing is issued.
__device__ __forceinline__ static void reduce_add_pred(float *addr, float val, bool pred) {
    const int enable = (int)pred;
    asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                 "@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
                 "}"
                 :
                 : "r"(enable), "l"(addr), "f"(val));
}

Zhekai Zhang's avatar
Zhekai Zhang committed
361
// Compile-time unrolled loop.
// Invokes `lambda.template operator()<I>()` for I = 0, 1, ..., cnt - 1, in order
// (a comma fold evaluates left to right).
// `lambda` must be a generic lambda / functor with a template operator() taking an int.
template<int cnt, typename F>
__device__ __forceinline__ static void unrolled_loop(F &&lambda) {
    [&]<int... Is>(std::integer_sequence<int, Is...>) {
        (lambda.template operator()<Is>(), ...);
    }(std::make_integer_sequence<int, cnt>());
}

367
368
// int2float is slow on sm_80 and before
// val in [-4194304, 4194303]  (i.e. a signed 23-bit value)
//
// Bit trick:
//  1. 0x4B400000 is 12582912.0f = 1.5 * 2^23, whose mantissa step is exactly 1.
//  2. Keeping the low 23 bits of val and XOR-ing them into that constant's
//     mantissa yields the float 12582912 + val.
//  3. Subtracting 12582912 recovers val exactly.
__device__ __forceinline__ static float int2float_fast(int val) {
    constexpr float kBias = 12582912.0f;
    float biased;
    // biased = (val & 0x7FFFFF) ^ 0x4B400000   (lop3 LUT encodes (a & b) ^ c)
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=f"(biased)
                 : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
    return biased - kBias;
}

378
// Reinterpret the object representation of `input` as a `To` of the same size.
// NOTE(review): this type-puns through a reinterpret_cast'd pointer, which
// formally violates C++ strict aliasing. It is relied on here for register-level
// casts such as uint32_t -> half2.
template<typename To, typename From>
__device__ __forceinline__ static To bit_cast(const From &input) {
    static_assert(sizeof(To) == sizeof(From));
    const To *view = reinterpret_cast<const To *>(&input);
    return *view;
}

385
386
// both int2float and float2half are slow on sm_75 and before
// val in [-8192, 8191], steps of 16, round to negative inf
//
// Packs (x, y) into a half2 using only integer bit operations:
//   1. The low 16 bits of x and y are placed side by side.
//   2. Both are shifted right by 4, which floors them to multiples of 16.
//   3. The resulting 10-bit two's-complement field is spliced into the mantissa
//      of 24576.0 (fp16 0x7600, mantissa step 16).
//   4. Subtracting 24576 leaves 16 * floor(v / 16).
__device__ __forceinline__ static half2 int2half2_fast_8192(int x, int y) {
    uint32_t packed;
    // packed.lo = x.lo; packed.hi = y.lo;
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(packed) : "r"(x), "r"(y), "n"(0x5410));
    // The 32-bit shift lets y's low bits spill into bits [15:12] of the low half.
    // The 0x03FF mask below discards them.
    packed >>= 4;
    uint32_t bits;
    // bits = (packed & 0x03FF03FF) ^ 0x76007600
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(bits)
                 : "r"(packed), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
    return __hadd2(kernels::bit_cast<half2>(bits), half2(-24576.0f, -24576.0f));
}
// val in [-4096, 4091] is converted exactly to the nearest multiple of 8
// (ties round toward +inf). Values above 4091 saturate to 4088, the largest
// multiple of 8 this encoding can represent. Values below -4096 are not
// handled and must be kept in range by the caller.
//
// Packs (x, y) into a half2 using only integer bit operations:
//   1. round(v / 8) is placed in the high 16 bits of v * 8192 + 32768.
//   2. That 10-bit two's-complement field is spliced into the mantissa of
//      12288.0 (fp16 0x7200, mantissa step 8).
//   3. Subtracting 12288 leaves 8 * round(v / 8).
__device__ __forceinline__ static half2 int2half2_fast_4096_rn(int x, int y) {
    // Upper saturation: round(v / 8) must stay <= 511, i.e. v <= 4091.
    // Without this, v in [4092, 4095] rounds to 512, which overflows the 10-bit
    // field and decodes as -4096.
    x = min(x, 4091);
    y = min(y, 4091);
    // TODO: round to even?
    x = x * 8192 + 32768;
    y = y * 8192 + 32768;
    uint32_t ival;
    uint32_t hval;
    // ival.lo = x.hi; ival.hi = y.hi;
    // <=> divide x and y by 65536 (flooring) and pack them
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632));
    // (val & 0x03FF03FF) ^ 0x72007200
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(hval)
                 : "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
    return __hadd2(kernels::bit_cast<half2>(hval), half2(-12288.0f, -12288.0f));
}
// val in [-512, 511], exact
//
// Packs (x, y) into a half2 using only integer bit operations:
//   1. The low 16 bits of x and y are placed side by side (no division).
//   2. Their 10-bit two's-complement fields are spliced into the mantissa of
//      1536.0 (fp16 0x6600, mantissa step 1).
//   3. Subtracting 1536 leaves the original values.
__device__ __forceinline__ static half2 int2half2_fast_512(int x, int y) {
    uint32_t packed;
    // packed.lo = x.lo; packed.hi = y.lo;
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(packed) : "r"(x), "r"(y), "n"(0x5410));
    uint32_t bits;
    // bits = (packed & 0x03FF03FF) ^ 0x66006600
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(bits)
                 : "r"(packed), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
    return __hadd2(kernels::bit_cast<half2>(bits), half2(-1536.0f, -1536.0f));
}

Muyang Li's avatar
Muyang Li committed
431
}; // namespace nunchaku::kernels