gemm_utils.cuh 14.9 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
#pragma once

#include <cstdint>
#include <cstring>
#include "common.h"

#include "../utils.cuh"

namespace nunchaku::kernels {
Zhekai Zhang's avatar
Zhekai Zhang committed
8
9

// Restricts `val` to the closed interval [min, max].
// The lower bound is tested first: if min > max, any val < min yields `min`
// and any other val > max yields `max` (the bounds are not validated).
static constexpr int clamp(int val, int min, int max) {
    return (val < min) ? min : ((val > max) ? max : val);
}

// Reads one T from `addr`.
//
// shmem == true : 8- and 16-byte T are fetched with a single vectorized
//                 `ld.shared` instruction; every other size is a plain dereference.
// shmem == false: 8- and 16-byte T go through __ldg as uint2 / uint4;
//                 every other size is a plain dereference.
//
// NOTE(review): the `ld.shared` paths hand the generic pointer `addr` directly to
// the instruction. On stock CUDA `ld.shared` expects a shared-window address
// (normally produced by __cvta_generic_to_shared) — confirm the target toolchain
// accepts a generic pointer here.
template<bool shmem = false, typename T>
__device__ __forceinline__ static T load(const T *addr) {
    if constexpr (shmem && sizeof(T) == 8) {
        uint2 v;
        asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];"
                     : "=r"(v.x), "=r"(v.y)
                     : "l"(addr));
        return *reinterpret_cast<T *>(&v);
    } else if constexpr (shmem && sizeof(T) == 16) {
        uint4 v;
        asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];"
                     : "=r"(v.x), "=r"(v.y), "=r"(v.z), "=r"(v.w)
                     : "l"(addr));
        return *reinterpret_cast<T *>(&v);
    } else if constexpr (!shmem && sizeof(T) == 8) {
        uint2 v = __ldg(reinterpret_cast<const uint2 *>(addr));
        return *reinterpret_cast<T *>(&v);
    } else if constexpr (!shmem && sizeof(T) == 16) {
        uint4 v = __ldg(reinterpret_cast<const uint4 *>(addr));
        return *reinterpret_cast<T *>(&v);
    } else {
        // Any other size (shared or global): ordinary dereference.
        return *addr;
    }
}

49
// Predicated global load of one T.
//
// For 4-, 8- and 16-byte T the load is a single (vector) `ld.global.nc`
// instruction guarded by a PTX predicate register, so no branch is emitted.
// Any other size falls back to an ordinary `if (pred)` dereference.
//
// When `pred` is false nothing is read and the returned value is indeterminate
// (the destination registers / local are never written); callers must ignore it.
// `.nc` selects the non-coherent read-only path, so the data presumably must not
// be written while the kernel runs — confirm against callers.
template<typename T>
__device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
    const int guard = static_cast<int>(pred);

    if constexpr (sizeof(T) == 4) {
        uint32_t raw;
        asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
                     "@loadpred ld.global.nc.b32 %0, [%1];"
                     "}"
                     : "=r"(raw)
                     : "l"(addr), "r"(guard));
        return *reinterpret_cast<T *>(&raw);
    } else if constexpr (sizeof(T) == 8) {
        uint2 raw;
        asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
                     "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
                     "}"
                     : "=r"(raw.x), "=r"(raw.y)
                     : "l"(addr), "r"(guard));
        return *reinterpret_cast<T *>(&raw);
    } else if constexpr (sizeof(T) == 16) {
        uint4 raw;
        asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
                     "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
                     "}"
                     : "=r"(raw.x), "=r"(raw.y), "=r"(raw.z), "=r"(raw.w)
                     : "l"(addr), "r"(guard));
        return *reinterpret_cast<T *>(&raw);
    } else {
        T fetched;
        if (pred) {
            fetched = *addr;
        }
        return fetched;
    }
}

Zhekai Zhang's avatar
Zhekai Zhang committed
86
// Writes `val` to `addr`.
//
// shmem == true : 8- and 16-byte T are written with a single vectorized
//                 `st.shared` instruction; every other size is a plain assignment.
// shmem == false: 4-, 8- and 16-byte T are written through unsigned int / uint2 /
//                 uint4 lvalues; every other size is a plain assignment. The
//                 commented-out __stcg lines are the cache-global store variants
//                 that these plain stores replaced.
//
// NOTE(review): as in `load`, the `st.shared` paths pass the generic pointer
// straight to the instruction; on stock CUDA a shared-window address
// (__cvta_generic_to_shared) is expected — confirm for the target toolchain.
template<bool shmem = false, typename T>
__device__ __forceinline__ static void store(T *addr, T val) {
    if constexpr (shmem && sizeof(T) == 8) {
        uint2 bits = *reinterpret_cast<uint2 *>(&val);
        asm volatile("st.shared.v2.b32 [%0], {%1, %2};" ::"l"(addr), "r"(bits.x), "r"(bits.y));
    } else if constexpr (shmem && sizeof(T) == 16) {
        uint4 bits = *reinterpret_cast<uint4 *>(&val);
        asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"(addr),
                     "r"(bits.x),
                     "r"(bits.y),
                     "r"(bits.z),
                     "r"(bits.w));
    } else if constexpr (!shmem && sizeof(T) == 4) {
        // __stcg(reinterpret_cast<unsigned int *>(addr), *reinterpret_cast<unsigned int *>(&val));
        *reinterpret_cast<unsigned int *>(addr) = *reinterpret_cast<unsigned int *>(&val);
    } else if constexpr (!shmem && sizeof(T) == 8) {
        // __stcg(reinterpret_cast<uint2 *>(addr), *reinterpret_cast<uint2 *>(&val));
        *reinterpret_cast<uint2 *>(addr) = *reinterpret_cast<uint2 *>(&val);
    } else if constexpr (!shmem && sizeof(T) == 16) {
        // __stcg(reinterpret_cast<uint4 *>(addr), *reinterpret_cast<uint4 *>(&val));
        *reinterpret_cast<uint4 *>(addr) = *reinterpret_cast<uint4 *>(&val);
    } else {
        *addr = val;
    }
}

126
// Predicated global store of one T.
//
// For 4-, 8- and 16-byte T the store is a single (vector) `st.global.cg`
// instruction (cache at L2, bypassing L1) guarded by a PTX predicate register,
// so no branch is emitted. Any other size falls back to `if (pred) *addr = val`.
// Nothing is written when `pred` is false.
template<typename T>
__device__ __forceinline__ static void store_pred(T *addr, T val, bool pred) {
    const int guard = static_cast<int>(pred);

    if constexpr (sizeof(T) == 4) {
        const uint32_t bits = *reinterpret_cast<uint32_t *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.b32 [%1], %2;"
                     "}" ::"r"(guard),
                     "l"(addr),
                     "r"(bits));
    } else if constexpr (sizeof(T) == 8) {
        const uint2 bits = *reinterpret_cast<uint2 *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
                     "}" ::"r"(guard),
                     "l"(addr),
                     "r"(bits.x),
                     "r"(bits.y));
    } else if constexpr (sizeof(T) == 16) {
        const uint4 bits = *reinterpret_cast<uint4 *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
                     "}" ::"r"(guard),
                     "l"(addr),
                     "r"(bits.x),
                     "r"(bits.y),
                     "r"(bits.z),
                     "r"(bits.w));
    } else if (pred) {
        *addr = val;
    }
}

Muyang Li's avatar
Muyang Li committed
165
// Widens both lanes of a half2 to float (thin wrapper around __half22float2).
__device__ __forceinline__ static float2 half22float2(half2 val) {
    const float2 widened = __half22float2(val);
    return widened;
}

fengzch-das's avatar
fengzch-das committed
169
// bfloat16 overload: widens both lanes of a __nv_bfloat162 to float
// (thin wrapper around __bfloat1622float2).
__device__ __forceinline__ static float2 half22float2(__nv_bfloat162 val) {
    const float2 widened = __bfloat1622float2(val);
    return widened;
}

// Narrows a float2 to a packed 16-bit pair of type T. The primary template is
// deleted, so only types with an explicit specialization can be used; any other
// T is rejected at compile time.
template<class T>
__device__ __forceinline__ static T float22half2(float2 in) = delete;
Zhekai Zhang's avatar
Zhekai Zhang committed
175
176

// float2 -> half2, each lane rounded to nearest-even (__float22half2_rn).
template<>
__device__ __forceinline__ half2 float22half2<half2>(float2 val) {
    const half2 narrowed = __float22half2_rn(val);
    return narrowed;
}

// float2 -> __nv_bfloat162, each lane rounded to nearest-even (__float22bfloat162_rn).
template<>
__device__ __forceinline__ __nv_bfloat162 float22half2<__nv_bfloat162>(float2 val) {
    const __nv_bfloat162 narrowed = __float22bfloat162_rn(val);
    return narrowed;
}

// Creates a conditional use of `val` through a volatile null pointer.
// Presumably this keeps the compiler from treating `val` as dead (verify with
// callers). The caller must pass `alwaysfalse == false` at runtime; if it is
// true, the function writes through a null pointer.
template<typename T>
__device__ __forceinline__ static void unused_var(T &val, bool alwaysfalse) {
    if (!alwaysfalse) {
        return;
    }
    volatile T *sink = nullptr;
    *sink = val;
}

Muyang Li's avatar
Muyang Li committed
194
// Warp-collective `ldmatrix.sync.aligned.x4.m8n8.shared.b16`: loads four 8x8
// b16 matrices from shared memory. Each lane supplies a row address in `ptr`
// and receives one 32-bit fragment register per matrix in out.x..out.w. All
// lanes of the warp must execute it together (.sync.aligned).
//
// NOTE(review): `ptr` is passed as a generic 64-bit pointer. Stock CUDA expects a
// shared-window address here (__cvta_generic_to_shared); this port deliberately
// dropped that conversion — confirm the target toolchain handles it.
__device__ __forceinline__ static void ldmatrix(const void *ptr, uint4 &out) {
    uint32_t m0, m1, m2, m3;
    asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                 : "=r"(m0), "=r"(m1), "=r"(m2), "=r"(m3)
                 : "l"(ptr));
    out.x = m0;
    out.y = m1;
    out.z = m2;
    out.w = m3;
}

200
// Transposes an 8x8 matrix of 16-bit elements held in registers across the 32
// lanes of a warp, reproducing PTX `movmatrix.sync.aligned.m8n8.trans.b16`.
// That instruction is not available on this port (the original asm is kept
// below for reference), so the transpose is done with two warp shuffles.
//
// Fragment layout (PTX ISA, movmatrix): lane l holds row l / 4, columns
// 2 * (l % 4) in its low 16 bits and 2 * (l % 4) + 1 in its high 16 bits.
// On return lane l holds, in the same layout, A[2 * (l % 4)][l / 4] (low) and
// A[2 * (l % 4) + 1][l / 4] (high) of the input matrix A.
//
// Requirements: T is exactly 32 bits (one packed b16 pair per lane), and all
// 32 lanes of the warp call this together (full-mask shuffles).
template<typename T>
__device__ __forceinline__ static T movmatrix(T x) {
    static_assert(sizeof(T) == sizeof(uint32_t), "movmatrix operates on one 32-bit register per lane");
    // asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
    //              : "=r"(*reinterpret_cast<uint32_t *>(&x))
    //              : "r"(*reinterpret_cast<uint32_t *>(&x)));

    // Warps are formed from consecutive linear thread IDs, so derive the lane from
    // the linear index (valid for any block shape).
    const unsigned linear_tid = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x;
    const int lane            = static_cast<int>(linear_tid % 32u);

    const uint32_t in = *reinterpret_cast<uint32_t *>(&x);

    // A[r][c] lives in lane 4 * r + c / 2, half c % 2. The two source rows are
    // 2 * (lane % 4) and 2 * (lane % 4) + 1, and the source column is lane / 4.
    const int src_lo = 8 * (lane & 3) + (lane >> 3);
    const int src_hi = src_lo + 4;
    const uint32_t from_lo = __shfl_sync(0xffffffffu, in, src_lo);
    const uint32_t from_hi = __shfl_sync(0xffffffffu, in, src_hi);

    // Even source column: take the low halves (bytes 1:0 of each source).
    // Odd source column: take the high halves (bytes 3:2).
    const unsigned selector = ((lane >> 2) & 1) ? 0x7632u : 0x5410u;
    uint32_t out            = __byte_perm(from_lo, from_hi, selector);
    return *reinterpret_cast<T *>(&out);
}

Zhekai Zhang's avatar
Zhekai Zhang committed
209
210
// Rounds both components of a float2 to integers and packs them, saturated,
// into `bitwidth`-bit fields (signed, or unsigned when `use_unsigned`).
// x goes into the low field and y into the field directly above it.
// The primary template is deleted: only explicitly specialized
// (bitwidth, use_unsigned) combinations compile.
template<int bitwidth, bool use_unsigned>
__device__ __forceinline__ uint32_t quantize_float2(float2 v) = delete;
Zhekai Zhang's avatar
Zhekai Zhang committed
212
213

// Signed 4-bit packing. Each component is rounded to the nearest integer
// (cvt.rni, ties to even), then saturated to [-8, 7] by cvt.pack.sat.s4.
// x occupies the low nibble and y the nibble above it. The remaining upper bits
// come from the constant 0 third operand.
template<>
__device__ __forceinline__ uint32_t quantize_float2<4, false>(float2 value) {
    int lo_int, hi_int;
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(lo_int) : "f"(value.x));
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(hi_int) : "f"(value.y));

    uint32_t packed;
    asm volatile("cvt.pack.sat.s4.s32.b32 %0, %1, %2, 0;" : "=r"(packed) : "r"(hi_int), "r"(lo_int));
    return packed;
}

// Unsigned 4-bit packing. Each component is rounded to the nearest integer
// (cvt.rni, ties to even), then saturated to [0, 15] by cvt.pack.sat.u4.
// x occupies the low nibble and y the nibble above it. The remaining upper bits
// come from the constant 0 third operand.
template<>
__device__ __forceinline__ uint32_t quantize_float2<4, true>(float2 value) {
    int lo_int, hi_int;
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(lo_int) : "f"(value.x));
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(hi_int) : "f"(value.y));

    uint32_t packed;
    asm volatile("cvt.pack.sat.u4.s32.b32 %0, %1, %2, 0;" : "=r"(packed) : "r"(hi_int), "r"(lo_int));
    return packed;
}

// Signed 8-bit packing. Each component is rounded to the nearest integer
// (cvt.rni, ties to even), then saturated to [-128, 127] by cvt.pack.sat.s8.
// x occupies the low byte and y the byte above it. The remaining upper bits
// come from the constant 0 third operand.
template<>
__device__ __forceinline__ uint32_t quantize_float2<8, false>(float2 value) {
    int lo_int, hi_int;
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(lo_int) : "f"(value.x));
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(hi_int) : "f"(value.y));

    uint32_t packed;
    asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;" : "=r"(packed) : "r"(hi_int), "r"(lo_int));
    return packed;
}

Muyang Li's avatar
Muyang Li committed
243
// Converts two floats to FP4 (e2m1) with round-to-nearest and satfinite, packed
// in one byte, and returns that byte zero-extended to 32 bits.
// value.y is the first source of the e2m1x2 cvt and value.x the second, so
// x ends up in the low nibble. NOTE(review): e2m1x2 cvt needs a recent
// architecture/PTX version — confirm the target supports it.
__device__ __forceinline__ uint32_t quantize_float2_fp4(float2 value) {
    uint32_t packed_byte;
    asm volatile("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
                 : "=r"(packed_byte)
                 : "f"(value.y), "f"(value.x));
    return packed_byte;
}

Muyang Li's avatar
Muyang Li committed
251
// Converts four floats to FP8 e4m3 (round-to-nearest, satfinite) and packs them
// into 32 bits.
// Bytes 1:0 hold (x, y), produced by the first cvt with y as its first source.
// Bytes 3:2 hold (z, w) in the same way, so x lands in the lowest byte.
// e4m3x2 cvt requires sm_89+ per the PTX ISA.
__device__ __forceinline__ uint32_t quantize_float4_fp8(float4 value) {
    uint16_t pair_xy, pair_zw;
    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(pair_xy) : "f"(value.y), "f"(value.x));
    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(pair_zw) : "f"(value.w), "f"(value.z));
    return (static_cast<uint32_t>(pair_zw) << 16) | static_cast<uint32_t>(pair_xy);
}

Muyang Li's avatar
Muyang Li committed
258
// Hardware-approximate tanh via PTX `tanh.approx.f32` (sm_75+).
__device__ __forceinline__ static float cuda_tanhf(float x) {
    float y;
    asm("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
    return y;
}

Muyang Li's avatar
Muyang Li committed
264
// Approximate reciprocal 1/x via PTX `rcp.approx.ftz.f32` (subnormals flushed to zero).
__device__ __forceinline__ static float cuda_frcp(float x) {
    float inv;
    asm("rcp.approx.ftz.f32 %0, %1;" : "=f"(inv) : "f"(x));
    return inv;
}

Muyang Li's avatar
Muyang Li committed
270
// Approximate reciprocal square root 1/sqrt(x) via PTX `rsqrt.approx.ftz.f32`
// (subnormals flushed to zero).
__device__ __forceinline__ static float cuda_frsqrt(float x) {
    float inv_root;
    asm("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(inv_root) : "f"(x));
    return inv_root;
}

Muyang Li's avatar
Muyang Li committed
276
// Approximate sine (radians) via PTX `sin.approx.ftz.f32` (subnormals flushed to zero).
__device__ __forceinline__ static float cuda_sin(float x) {
    float s;
    asm("sin.approx.ftz.f32 %0, %1;" : "=f"(s) : "f"(x));
    return s;
}

Muyang Li's avatar
Muyang Li committed
282
// Approximate cosine (radians) via PTX `cos.approx.ftz.f32` (subnormals flushed to zero).
__device__ __forceinline__ static float cuda_cos(float x) {
    float c;
    asm("cos.approx.ftz.f32 %0, %1;" : "=f"(c) : "f"(x));
    return c;
}

Muyang Li's avatar
Muyang Li committed
288
// Approximate 2^x via PTX `ex2.approx.ftz.f32` (subnormals flushed to zero).
__device__ __forceinline__ static float cuda_exp2(float x) {
    float p;
    asm("ex2.approx.ftz.f32 %0, %1;" : "=f"(p) : "f"(x));
    return p;
}

Zhekai Zhang's avatar
Zhekai Zhang committed
294
// https://forums.developer.nvidia.com/t/hardware-accelerated-computation-of-the-sigmoid-logistic-function/266206
// Approximate logistic sigmoid 1 / (1 + exp(-a)) built from hardware
// approximation instructions.
//
// USE_TANH nonzero : sigmoid(a) = 0.5 * tanh(0.5 * a) + 0.5, using a single
//                    tanh.approx through cuda_tanhf.
// otherwise (including USE_TANH undefined): exp(-a) is evaluated as
//                    2^(-a * log2(e)) with cuda_exp2 (ex2.approx.ftz), and the
//                    result is 1 / (1 + that) with cuda_frcp (rcp.approx.ftz).
__forceinline__ __device__ static float cuda_sigmoidf(float a) {
#if USE_TANH
    // Float literal avoids a double constant. cuda_tanhf emits tanh.approx.f32
    // directly instead of relying on a separate __tanhf intrinsic.
    return fmaf(0.5f, cuda_tanhf(0.5f * a), 0.5f);
#else  // USE_TANH
    const float L2E = 1.442695041f; // log2(exp(1))
    const float e   = cuda_exp2(-L2E * a);
    return cuda_frcp(e + 1.0f);
#endif // USE_TANH
}

// Tanh-approximated GELU on a packed pair (half2 or __nv_bfloat162), computed in float:
//   gelu(v) = v * (0.5 + 0.5 * tanh(0.79788456 * (v + 0.044715 * v^3))),
// where 0.79788456 ~= sqrt(2 / pi). The result is rounded back to T.
template<typename T>
__device__ __forceinline__ static T gelu_half2(T x) {
    const float2 v    = half22float2(x);
    const float2 cube = v * v * v;
    const float2 gate = make_float2(0.5f + 0.5f * cuda_tanhf(0.79788456f * (v.x + (0.044715f * cube.x))),
                                    0.5f + 0.5f * cuda_tanhf(0.79788456f * (v.y + (0.044715f * cube.y))));
    return float22half2<T>(v * gate);
}

// Scalar tanh-approximated GELU, computed in float and cast back to T:
//   gelu(v) = v * (0.5 + 0.5 * tanh(0.79788456 * (v + 0.044715 * v^3))).
template<typename T>
__device__ __forceinline__ static T gelu_half(T x) {
    const float v    = float(x);
    const float cube = v * v * v;
    const float gate = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (v + (0.044715f * cube)));
    return (T)(v * gate);
}

Muyang Li's avatar
Muyang Li committed
326
327
328
329
330
// SiLU / swish: x * sigmoid(x), computed in float via cuda_sigmoidf and cast back to T.
template<typename T>
__device__ __forceinline__ static T silu(const T &x) {
    const float xf = (float)x;
    const float s  = cuda_sigmoidf(xf);
    return (T)(xf * s);
    // Division-based alternative kept for reference:
    // return (T)__fdividef((float)x, 1.0f + __expf((float)-x));
}

Muyang Li's avatar
Muyang Li committed
333
// Lane-wise a / b for half2. Both operands are widened to float, divided with
// the fast approximate __fdividef, and rounded back with float22half2<half2>.
__device__ __forceinline__ static half2 h2div(half2 a, half2 b) {
    const float2 num = half22float2(a);
    const float2 den = half22float2(b);
    const float2 quo = make_float2(__fdividef(num.x, den.x), __fdividef(num.y, den.y));
    return float22half2<half2>(quo);
}
fengzch-das's avatar
fengzch-das committed
341
// Lane-wise a / b for __nv_bfloat162. Both operands are widened to float,
// divided with the fast approximate __fdividef, and rounded back with
// float22half2<__nv_bfloat162>.
__device__ __forceinline__ static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bfloat162 b) {
    const float2 num = half22float2(a);
    const float2 den = half22float2(b);
    const float2 quo = make_float2(__fdividef(num.x, den.x), __fdividef(num.y, den.y));
    return float22half2<__nv_bfloat162>(quo);
}

Muyang Li's avatar
Muyang Li committed
350
351
// Atomically adds `val` to the global float at `addr` without returning the old
// value (`red`). It uses relaxed ordering at GPU scope; the .relaxed/.gpu
// qualifiers need sm_70+.
__device__ __forceinline__ static void reduce_add(float *addr, float val) {
    asm volatile("red.relaxed.gpu.global.add.f32 [%0], %1;"
                 :
                 : "l"(addr), "f"(val));
}

Muyang Li's avatar
Muyang Li committed
354
355
356
357
358
359
// Predicated form of reduce_add: performs the relaxed, GPU-scope atomic
// `red.global.add.f32` only when `pred` is true. It uses a PTX predicate
// register instead of a branch.
__device__ __forceinline__ static void reduce_add_pred(float *addr, float val, bool pred) {
    const int guard = static_cast<int>(pred);
    asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                 "@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
                 "}"
                 :
                 : "r"(guard), "l"(addr), "f"(val));
}

Zhekai Zhang's avatar
Zhekai Zhang committed
362
// Compile-time unrolled loop: calls `lambda.template operator()<I>()` for
// I = 0, 1, ..., cnt - 1. The comma fold guarantees the calls happen in
// increasing order, and cnt == 0 makes no calls. The callable needs a call
// operator templated on an int, e.g. `[&]<int i>() { ... }`.
template<int cnt, typename F>
__device__ __forceinline__ static void unrolled_loop(F &&lambda) {
    [&]<int... Is>(std::integer_sequence<int, Is...>) {
        (lambda.template operator()<Is>(), ...);
    }(std::make_integer_sequence<int, cnt>{});
}

368
369
// int -> float conversion without an int->float cvt instruction, because
// int2float is slow on sm_80 and before.
// Precondition: val in [-4194304, 4194303], i.e. [-2^22, 2^22 - 1].
//
// 0x4B400000 is the float 12582912.0f = 1.5 * 2^23, whose ulp is 1. XOR-ing the
// low 23 bits of val into its mantissa produces exactly 12582912 + val for
// in-range val, so subtracting the bias recovers val.
__device__ __forceinline__ static float int2float_fast(int val) {
    constexpr float kBias = 12582912.0f;
    float biased;
    // biased = (val & 0x7FFFFF) ^ 0x4B400000 as one LOP3; LUT (0xF0 & 0xCC) ^ 0xAA encodes (a & b) ^ c.
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=f"(biased)
                 : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
    return biased - kBias;
}

379
// Reinterprets the object representation of `input` as a `To` of identical size.
// It copies bytes with memcpy instead of dereferencing a reinterpret_cast
// pointer, because that pun violates strict aliasing (undefined behaviour). For
// small trivially-copyable types the copy is expected to compile to a plain
// register move.
// Requires `To` to be default-constructible.
template<typename To, typename From>
__device__ __forceinline__ static To bit_cast(const From &input) {
    static_assert(sizeof(To) == sizeof(From), "bit_cast requires types of identical size");
    To output;
    memcpy(&output, &input, sizeof(To));
    return output;
}

386
387
// Converts two ints to a half2 without int->float or float->half instructions
// (both are slow on sm_75 and before).
// Precondition: x, y in [-8192, 8191]. Each result is the input rounded down
// (toward -inf) to a multiple of 16.
//
// 0x7600 is the half 24576 = 1.5 * 2^14, whose ulp is 16. XOR-ing a 10-bit
// two's-complement value q = floor(v / 16) into its mantissa gives 24576 + 16 * q.
__device__ __forceinline__ static half2 int2half2_fast_8192(int x, int y) {
    uint32_t packed;
    // packed.lo = x.lo; packed.hi = y.lo
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(packed) : "r"(x), "r"(y), "n"(0x5410));
    // Divide each 16-bit half by 16 (rounding toward -inf). Bits shifted in from y
    // land above bit 9 of the low half and are masked off below.
    packed >>= 4;

    uint32_t bits;
    // bits = (packed & 0x03FF03FF) ^ 0x76007600
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(bits)
                 : "r"(packed), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));

    const half2 bias = half2(-24576.0f, -24576.0f);
    return __hadd2(kernels::bit_cast<half2>(bits), bias);
}
// Converts two ints to a half2, each rounded to a multiple of 8, without
// int->float or float->half instructions.
// Rounding is to nearest with ties toward +inf (not to even): the +32768 bias
// adds half a step before the flooring divide by 65536.
//
// Valid input range: [-4100, 4091].
// Inputs in [4092, 4095] (the top of the nominal [-4096, 4095] range) round to
// 512 steps = 4096, which does not fit the 10-bit offset field and wraps to
// -4096. Inputs below -4100 wrap the other way. Callers that may see such values
// must clamp first; note that the commented-out clamps below use an upper bound
// of 4095, which is not sufficient — 4091 is required.
__device__ __forceinline__ static half2 int2half2_fast_4096_rn(int x, int y) {
    // x = max(min(x, 4095), -4096);
    // y = max(min(y, 4095), -4096);
    // TODO: round to even?
    // Scale by 2^13 (= 65536 / 8) and add 2^15 (half an output step), so the high
    // 16 bits become floor(v / 8 + 0.5).
    x = x * 8192 + 32768;
    y = y * 8192 + 32768;
    uint32_t ival;
    uint32_t hval;
    // ival.lo = x.hi; ival.hi = y.hi;
    // <=> divide x and y by 65536 (rounding toward -inf) and pack them
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632));
    // (val & 0x03FF03FF) ^ 0x72007200
    // 0x7200 is the half 12288 = 1.5 * 2^13 (ulp 8). XOR-ing a 10-bit two's-complement
    // r into its mantissa gives 12288 + 8 * r for r in [-512, 511].
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(hval)
                 : "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
    return __hadd2(kernels::bit_cast<half2>(hval), half2(-12288.0f, -12288.0f));
}
// Converts two ints to a half2 exactly, without int->float or float->half
// instructions.
// Precondition: x, y in [-512, 511]; only the low 10 bits of each input are used.
//
// 0x6600 is the half 1536 = 1.5 * 2^10 (ulp 1). XOR-ing a 10-bit two's-complement
// v into its mantissa gives 1024 + (v + 512) = 1536 + v, and the bias is then removed.
__device__ __forceinline__ static half2 int2half2_fast_512(int x, int y) {
    uint32_t ival;
    uint32_t hval;
    // ival.lo = x.lo; ival.hi = y.lo;
    // i.e. pack the low 16 bits of x and y; no scaling is applied.
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
    // (val & 0x03FF03FF) ^ 0x66006600
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(hval)
                 : "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
    return __hadd2(kernels::bit_cast<half2>(hval), half2(-1536.0f, -1536.0f));
}

Muyang Li's avatar
Muyang Li committed
432
}; // namespace nunchaku::kernels