gemm_utils.cuh 11.7 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
#pragma once

#include <cstdint>
#include <cstring>
#include <utility>
#include "common.h"
muyangli's avatar
muyangli committed
5
6
7
#include "../utils.cuh"

namespace nunchaku::kernels {
Zhekai Zhang's avatar
Zhekai Zhang committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45

// Clamp `val` into [min, max]. The lower bound is tested first, so when
// min > max any val below min yields min and anything else yields max or val.
static constexpr int clamp(int val, int min, int max) {
    return val < min ? min : (val > max ? max : val);
}

// Load one T from `addr`, using vectorized instructions for 8- and 16-byte types.
//
// shmem = true : `addr` must point into shared memory (it is converted with
//                __cvta_generic_to_shared); 8/16-byte T use ld.shared.v2/v4.b32.
// shmem = false: 8/16-byte T are loaded through __ldg (read-only data path), so the
//                data should not be modified while the kernel reads it this way.
// Any other size falls back to a plain dereference in both modes.
// NOTE(review): vector loads require `addr` aligned to sizeof(T). The shared-memory
// asm has no "memory" clobber, so ordering against earlier shared writes relies on
// barriers such as __syncthreads() — confirm at call sites.
template<bool shmem = false, typename T>
__device__ __forceinline__
static T load(const T *addr) {
    if constexpr (shmem) {
        if constexpr (sizeof(T) == 8) {
            uint2 data;
            asm volatile ("ld.shared.v2.b32 {%0, %1}, [%2];" : "=r"(data.x), "=r"(data.y) : "l"(__cvta_generic_to_shared(addr)));
            return *reinterpret_cast<T *>(&data);
        }
        if constexpr (sizeof(T) == 16) {
            uint4 data;
            asm volatile ("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];" : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) : "l"(__cvta_generic_to_shared(addr)));
            return *reinterpret_cast<T *>(&data);
        }
        return *addr;
    }

    // Global path: 8/16-byte values go through the read-only cache as uint2/uint4.
    if constexpr (sizeof(T) == 8) {
        uint2 data = __ldg(reinterpret_cast<const uint2 *>(addr));
        return *reinterpret_cast<T *>(&data);
    }
    if constexpr (sizeof(T) == 16) {
        uint4 data = __ldg(reinterpret_cast<const uint4 *>(addr));
        return *reinterpret_cast<T *>(&data);
    }

    return *addr;
}

46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
// Predicated global load: reads *addr only when `pred` is true.
//
// 4/8/16-byte T use a single predicated `ld.global.nc` (non-coherent, read-only
// path; v2/v4 forms for 8/16 bytes), so the data must not be written by the kernel
// while it is read this way. When `pred` is false, no memory is accessed. For these
// sizes, the returned bits are unspecified because the asm output registers are
// never written, so callers must not use them.
// Any other size falls back to an ordinary guarded dereference, and returns a
// value-initialized T when `pred` is false.
// NOTE(review): vector loads require `addr` aligned to sizeof(T).
template<typename T>
__device__ __forceinline__
static T load_pred(const T *addr, bool pred) {
    if constexpr (sizeof(T) == 4) {
        uint32_t data;
        asm volatile (
            "{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
            "@loadpred ld.global.nc.b32 %0, [%1];"
            "}" : "=r"(data) : "l"(addr), "r"((int)pred));
        return *reinterpret_cast<T *>(&data);
    }
    if constexpr (sizeof(T) == 8) {
        uint2 data;
        asm volatile (
            "{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
            "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
            "}" : "=r"(data.x), "=r"(data.y) : "l"(addr), "r"((int)pred));
        return *reinterpret_cast<T *>(&data);
    }
    if constexpr (sizeof(T) == 16) {
        uint4 data;
        asm volatile (
            "{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
            "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
            "}" : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) : "l"(addr), "r"((int)pred));
        return *reinterpret_cast<T *>(&data);
    }

    // Value-initialize so the pred == false path never returns (copies) an
    // indeterminate object, which is undefined behavior in C++.
    T result{};
    if (pred) {
        result = *addr;
    }
    return result;
}

Zhekai Zhang's avatar
Zhekai Zhang committed
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
// Store `val` to *addr.
//
// shmem = true : `addr` must point into shared memory (converted with
//                __cvta_generic_to_shared); 8/16-byte T use st.shared.v2/v4.b32,
//                any other size a plain assignment.
// shmem = false: 4/8/16-byte T use __stcg (cache-global hint: cache in L2, not L1),
//                any other size a plain assignment.
// NOTE(review): vector stores require `addr` aligned to sizeof(T). The shared-memory
// asm has no "memory" clobber, so visibility to later reads relies on barriers such
// as __syncthreads() — confirm at call sites.
template<bool shmem = false, typename T>
__device__ __forceinline__
static void store(T *addr, T val) {
    if constexpr (shmem) {
        if constexpr (sizeof(T) == 8) {
            uint2 data = *reinterpret_cast<uint2 *>(&val);
            asm volatile ("st.shared.v2.b32 [%0], {%1, %2};" ::  "l"(__cvta_generic_to_shared(addr)), "r"(data.x), "r"(data.y));
            return;
        }
        if constexpr (sizeof(T) == 16) {
            uint4 data = *reinterpret_cast<uint4 *>(&val);
            asm volatile ("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(addr)), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
            return;
        }
        *addr = val;
        return;
    }

    // Global path: reinterpret as an unsigned int / uint2 / uint4 of the same size
    // and store with the L2-only (.cg) cache hint.
    if constexpr (sizeof(T) == 4) {
        __stcg(reinterpret_cast<unsigned int *>(addr), *reinterpret_cast<unsigned int *>(&val));
        return;
    }
    if constexpr (sizeof(T) == 8) {
        __stcg(reinterpret_cast<uint2 *>(addr), *reinterpret_cast<uint2 *>(&val));
        return;
    }
    if constexpr (sizeof(T) == 16) {
        __stcg(reinterpret_cast<uint4 *>(addr), *reinterpret_cast<uint4 *>(&val));
        return;
    }
    *addr = val;
}

114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
// Predicated global store: writes `val` to *addr only when `pred` is true. Otherwise
// no memory is touched.
// 4/8/16-byte T use a single predicated st.global.cg (cache in L2, not L1; v2/v4
// forms for 8/16 bytes). Any other size falls back to an ordinary guarded assignment.
// NOTE(review): `addr` must be a global-memory address, and vector stores require it
// to be aligned to sizeof(T).
template<typename T>
__device__ __forceinline__
static void store_pred(T *addr, T val, bool pred) {
    if constexpr (sizeof(T) == 4) {
        uint32_t data = *reinterpret_cast<uint32_t *>(&val);
        asm volatile (
            "{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
            "@storepred st.global.cg.b32 [%1], %2;"
            "}" :: "r"((int)pred), "l"(addr), "r"(data));
        return;
    }
    if constexpr (sizeof(T) == 8) {
        uint2 data = *reinterpret_cast<uint2 *>(&val);
        asm volatile (
            "{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
            "@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
            "}" :: "r"((int)pred), "l"(addr), "r"(data.x), "r"(data.y));
        return;
    }
    if constexpr (sizeof(T) == 16) {
        uint4 data = *reinterpret_cast<uint4 *>(&val);
        asm volatile (
            "{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
            "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
            "}" :: "r"((int)pred), "l"(addr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
        return;
    }

    if (pred) {
        *addr = val;
    }
}

Zhekai Zhang's avatar
Zhekai Zhang committed
147
// Widen both fp16 lanes of a half2 to fp32 (thin wrapper over __half22float2).
// Overloaded for __nv_bfloat162 so templated code can call half22float2 uniformly.
__device__ __forceinline__
static float2 half22float2(half2 val) {
    return __half22float2(val);
}

// Widen both bf16 lanes of a __nv_bfloat162 to fp32 (thin wrapper over
// __bfloat1622float2). This is the bf16 overload of half22float2.
__device__ __forceinline__
static float2 half22float2(__nv_bfloat162 val) {
    return __bfloat1622float2(val);
}

// Narrow a float2 to a packed pair of 16-bit floats of type T.
// The primary template is deleted: only explicitly specialized T (half2 and
// __nv_bfloat162 in this file) may be used; any other T is a compile-time error.
template<typename T>
__device__ __forceinline__
static T float22half2(float2 val) = delete;
Zhekai Zhang's avatar
Zhekai Zhang committed
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174

template<>
__device__ __forceinline__
half2 float22half2<half2>(float2 val) {
    // Both lanes are rounded to nearest-even; val.x becomes the low half, val.y the high half.
    return __floats2half2_rn(val.x, val.y);
}

template<>
__device__ __forceinline__
__nv_bfloat162 float22half2<__nv_bfloat162>(float2 val) {
    // Both lanes are rounded to nearest-even; val.x becomes the low half, val.y the high half.
    return __floats2bfloat162_rn(val.x, val.y);
}

// Make `val` look used to the compiler. The presumed intent is to suppress
// unused-variable warnings and keep `val` from being optimized away — TODO confirm
// against callers.
// `alwaysfalse` must be false at run time. The compiler cannot prove that, so it must
// keep `val` available for the volatile store below. If `alwaysfalse` were ever true,
// this would write through a null pointer.
template<typename T>
__device__ __forceinline__
static void unused_var(T &val, bool alwaysfalse) {
    volatile T *ptr = nullptr;
    if (alwaysfalse) {
        *ptr = val;
    }
}

// Warp-collective ldmatrix (.x4, .m8n8, .b16): loads four 8x8 matrices of 16-bit
// elements from shared memory. Each thread passes the address of one matrix row in
// `ptr`. The address is converted with __cvta_generic_to_shared, so it must point into
// shared memory. Each thread receives one 32-bit fragment of each of the four matrices
// in out.x / out.y / out.z / out.w.
// .sync.aligned: every thread of the warp must execute this instruction.
// Requires sm_75+ (per the PTX ISA); row addresses must be 16-byte aligned.
__device__ __forceinline__ 
static void ldmatrix(const void *ptr, uint4 &out) {
    asm volatile(
        "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
        : "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w)
        : "l"(__cvta_generic_to_shared(ptr))
    );
}

191
192
193
194
195
196
197
// Warp-collective transpose of one 8x8 matrix of 16-bit elements held in registers
// (movmatrix .m8n8 .trans .b16). Each thread contributes, and gets back, one 32-bit
// register (two b16 elements). T must therefore be exactly 32 bits wide (e.g. half2,
// __nv_bfloat162, uint32_t). Before the static_assert below, a wider T would silently
// have had only its first 4 bytes transposed.
// .sync.aligned: every thread of the warp must execute this. Requires sm_75+ (per the PTX ISA).
template<typename T>
__device__ __forceinline__
static T movmatrix(T x) {
    static_assert(sizeof(T) == sizeof(uint32_t), "movmatrix operates on exactly one 32-bit register (2 x b16) per thread");
    asm volatile ("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" : "=r"(*reinterpret_cast<uint32_t *>(&x)) : "r"(*reinterpret_cast<uint32_t *>(&x)));
    return x;
}

Zhekai Zhang's avatar
Zhekai Zhang committed
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236

// Round and saturate the two components of `value` to `bitwidth`-bit integers
// (signed, or unsigned when use_unsigned is true) and pack them into a uint32_t:
// x in the lowest `bitwidth` bits, y in the next `bitwidth` bits above it.
// The primary template is deleted; only explicitly specialized
// <bitwidth, use_unsigned> combinations can be called.
template<int bitwidth, bool use_unsigned>
__device__ __forceinline__
uint32_t quantize_float2(float2 value) = delete;

// Signed 4-bit: each component is rounded to the nearest integer (ties to even,
// cvt.rni) and saturated to [-8, 7] by cvt.pack.sat.s4.
// cvt.pack places its second source operand in the lowest bits, so passing (v2, v1)
// puts x in bits [0, 4) and y in bits [4, 8). The remaining bits come from the
// third operand, which is 0. cvt.pack.sat requires sm_72+ (per the PTX ISA).
template<>
__device__ __forceinline__
uint32_t quantize_float2<4, false>(float2 value) {
    int v1, v2;
    uint32_t result;
    asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
    asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
    asm volatile ("cvt.pack.sat.s4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
    return result;
}

// Unsigned 4-bit: each component is rounded to the nearest integer (ties to even,
// cvt.rni) and saturated to [0, 15] by cvt.pack.sat.u4.
// Layout: x in bits [0, 4), y in bits [4, 8). The remaining bits come from the third
// operand, which is 0. cvt.pack.sat requires sm_72+ (per the PTX ISA).
template<>
__device__ __forceinline__
uint32_t quantize_float2<4, true>(float2 value) {
    int v1, v2;
    uint32_t result;
    asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
    asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
    asm volatile ("cvt.pack.sat.u4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
    return result;
}

// Signed 8-bit: each component is rounded to the nearest integer (ties to even,
// cvt.rni) and saturated to [-128, 127] by cvt.pack.sat.s8.
// Layout: x in bits [0, 8), y in bits [8, 16). The upper bits come from the third
// operand, which is 0. cvt.pack.sat requires sm_72+ (per the PTX ISA).
template<>
__device__ __forceinline__
uint32_t quantize_float2<8, false>(float2 value) {
    int v1, v2;
    uint32_t result;
    asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
    asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
    asm volatile ("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
    return result;
}

237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
// Convert (x, y) to FP4 E2M1 (round-to-nearest, .satfinite clamps to the largest finite
// value) and return them zero-extended in a uint32_t.
// cvt ... e2m1x2 puts its first source in the upper nibble, so passing (y, x) gives
// x in bits [0, 4) and y in bits [4, 8). cvt.u32.u8 zero-extends the byte.
// NOTE(review): e2m1x2 conversion is only available on recent architecture-specific
// targets (Blackwell-class, e.g. sm_100a / sm_120a) — verify against the PTX ISA for
// your toolkit.
__device__ __forceinline__
uint32_t quantize_float2_fp4(float2 value) {
    uint32_t result;
    asm volatile ("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }" : "=r"(result) : "f"(value.y), "f"(value.x));
    return result;
}

// Convert the four components of `value` to FP8 E4M3 (round-to-nearest, .satfinite)
// and pack them into one uint32_t with x in byte 0, y in byte 1, z in byte 2 and
// w in byte 3.
// cvt ... e4m3x2 puts its first source in the upper byte, hence the (y, x) and (w, z)
// operand order. Requires sm_89+ (per the PTX ISA).
__device__ __forceinline__
uint32_t quantize_float4_fp8(float4 value) {
    uint16_t lo, hi;
    asm volatile ("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(lo) : "f"(value.y), "f"(value.x));
    asm volatile ("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(hi) : "f"(value.w), "f"(value.z));
    return uint32_t(lo) | (uint32_t(hi) << 16);
}

Zhekai Zhang's avatar
Zhekai Zhang committed
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
// Fast approximate tanh using the hardware instruction tanh.approx.f32 (sm_75+).
// Not correctly rounded; use tanhf where accuracy matters.
__device__ __forceinline__
static float cuda_tanhf(float x) {
    float y;
    asm("tanh.approx.f32 %0, %1;"
        : "=f"(y)
        : "f"(x));
    return y;
}

// Fast approximate reciprocal 1/x (rcp.approx.ftz.f32). Subnormal inputs and outputs
// are flushed to zero (.ftz).
__device__ __forceinline__
static float cuda_frcp(float x) {
    float inv;
    asm("rcp.approx.ftz.f32 %0, %1;"
        : "=f"(inv)
        : "f"(x));
    return inv;
}

// Fast approximate reciprocal square root 1/sqrt(x) (rsqrt.approx.ftz.f32), with
// subnormals flushed to zero.
__device__ __forceinline__
static float cuda_frsqrt(float x) {
    float inv_root;
    asm("rsqrt.approx.ftz.f32 %0, %1;"
        : "=f"(inv_root)
        : "f"(x));
    return inv_root;
}

// Fast approximate sine (sin.approx.ftz.f32, argument in radians, subnormals flushed
// to zero). Per the PTX ISA, accuracy degrades for large |x|.
__device__ __forceinline__
static float cuda_sin(float x) {
    float s;
    asm("sin.approx.ftz.f32 %0, %1;"
        : "=f"(s)
        : "f"(x));
    return s;
}

// Fast approximate cosine (cos.approx.ftz.f32, argument in radians, subnormals flushed
// to zero). Per the PTX ISA, accuracy degrades for large |x|.
__device__ __forceinline__
static float cuda_cos(float x) {
    float c;
    asm("cos.approx.ftz.f32 %0, %1;"
        : "=f"(c)
        : "f"(x));
    return c;
}

287
288
289
290
291
292
293
// Fast approximate 2^x (ex2.approx.ftz.f32), with subnormals flushed to zero.
__device__ __forceinline__
static float cuda_exp2(float x) {
    float p;
    asm("ex2.approx.ftz.f32 %0, %1;"
        : "=f"(p)
        : "f"(x));
    return p;
}

Zhekai Zhang's avatar
Zhekai Zhang committed
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
// Fast approximate logistic sigmoid: 1 / (1 + exp(-a)).
// Based on https://forums.developer.nvidia.com/t/hardware-accelerated-computation-of-the-sigmoid-logistic-function/266206
// Default path: exp(-a) = 2^(-a * log2(e)) via ex2.approx, followed by rcp.approx
// (both .ftz). Define USE_TANH to a non-zero value to use
// sigmoid(a) = 0.5 + 0.5 * tanh(a / 2) instead.
__forceinline__ __device__ 
static float cuda_sigmoidf (float a)
{
#if USE_TANH
    // Use this file's tanh.approx wrapper (also used by gelu) rather than __tanhf,
    // which older CUDA toolkits do not provide. Float literals keep this fp32-only.
    return fmaf(0.5f, cuda_tanhf(0.5f * a), 0.5f);
#else // USE_TANH
    const float L2E = 1.442695041f; // log2(e)
    float t, d, e, r;
    t = -L2E * a;
    asm ("ex2.approx.ftz.f32 %0,%1;\n\t" : "=f"(e) : "f"(t));
    d = e + 1.0f;
    asm ("rcp.approx.ftz.f32 %0,%1;\n\t" : "=f"(r) : "f"(d));
    return r;
#endif // USE_TANH
}

// Tanh-approximated GELU applied to both lanes of a packed 16-bit pair
// (half2 or __nv_bfloat162):
//   gelu(x) = x * (0.5 + 0.5 * tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// The math is done in fp32 with the hardware tanh approximation (cuda_tanhf), and the
// result is rounded back to T. 0.79788456f is sqrt(2/pi).
// NOTE(review): relies on a float2 operator* provided by an included header.
template<typename T>
__device__ __forceinline__ 
static T gelu_half2(T x) {
    float2 xf  = half22float2(x);
    float2 x3f = xf * xf * xf;
    float t1 = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.x + (0.044715f * x3f.x)));
    float t2 = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.y + (0.044715f * x3f.y)));
    return float22half2<T>(xf * make_float2(t1, t2));
}

// Scalar counterpart of gelu_half2: tanh-approximated GELU evaluated in fp32 and
// converted back to T. 0.79788456f is sqrt(2/pi).
template<typename T>
__device__ __forceinline__ 
static T gelu_half(T x) {
    const float v    = float(x);
    const float cube = v * v * v;
    const float gate = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (v + (0.044715f * cube)));
    return (T)(v * gate);
}

// SiLU (swish): x * sigmoid(x), computed in fp32 with the fast cuda_sigmoidf and
// converted back to T.
// Alternative kept for reference: (T)__fdividef((float)x, 1.0f + __expf((float)-x))
template <typename T>
__device__ __forceinline__ 
static T silu(const T &x) {
    const float xf = static_cast<float>(x);
    return static_cast<T>(xf * cuda_sigmoidf(xf));
}

muyangli's avatar
muyangli committed
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
// Lane-wise a / b for half2. Both operands are widened to fp32, divided with the fast
// __fdividef, and rounded back to half2.
__device__ __forceinline__
static half2 h2div(half2 a, half2 b) {
    const float2 num = half22float2(a);
    const float2 den = half22float2(b);
    return float22half2<half2>(make_float2(__fdividef(num.x, den.x),
                                           __fdividef(num.y, den.y)));
}
// Lane-wise a / b for __nv_bfloat162. Both operands are widened to fp32, divided with
// the fast __fdividef, and rounded back to bf16.
__device__ __forceinline__
static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bfloat162 b) {
    const float2 num = half22float2(a);
    const float2 den = half22float2(b);
    return float22half2<__nv_bfloat162>(make_float2(__fdividef(num.x, den.x),
                                                    __fdividef(num.y, den.y)));
}

Zhekai Zhang's avatar
Zhekai Zhang committed
357
358
359
360
361
362
363
364
365
366
367
368
// Atomically add `val` to the float at global address `addr` without returning the old
// value (PTX `red`), using relaxed ordering at GPU scope. Requires sm_70+ for the
// .relaxed qualifier (per the PTX ISA).
__device__ __forceinline__
static void reduce_add(float *addr, float val) {
    asm volatile ("red.relaxed.gpu.global.add.f32 [%0], %1;" :: "l"(addr), "f"(val));
}

// Compile-time unrolled loop: calls `lambda.template operator()<I>()` for
// I = 0, 1, ..., cnt - 1, in that order (a fold over the comma operator is sequenced
// left to right). `lambda` must have a template call operator taking an int
// non-type parameter, e.g. a C++20 `[&]<int i>() { ... }`.
// std::integer_sequence / std::make_integer_sequence are declared in <utility>.
template<int cnt, typename F>
__device__ __forceinline__
static void unrolled_loop(F &&lambda) {
    auto call = [&]<int ...Is>(std::integer_sequence<int, Is...>) {
        (lambda.template operator()<Is>(), ...);
    };
    call(std::make_integer_sequence<int, cnt>());
}

371
372
373
374
375
376
377
378
379
380
// int2float is slow on sm_80 and before, so this converts with bit manipulation instead.
// Valid only for val in [-4194304, 4194303], i.e. [-2^22, 2^22 - 1].
//
// 0x4B400000 is the bit pattern of 12582912.0f (1.5 * 2^23), whose 23-bit mantissa
// has only bit 22 set. XOR-ing the low 23 bits of val (two's complement) into it gives
// exactly 12582912.0f + val for every val in the range above. Subtracting
// 12582912.0f therefore recovers val exactly.
__device__ __forceinline__
static float int2float_fast(int val) {
    float fval;
    // fval = (val & 0x7FFFFF) ^ 0x4B400000
    // lop3 LUT (0xF0 & 0xCC) ^ 0xAA encodes (a & b) ^ c for a = val, b = 0x7FFFFF, c = 0x4B400000.
    asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=f"(fval) : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
    return fval - 12582912.0f;
}

381
382
383
384
385
386
387
388
// Reinterpret the bytes of `input` as a `To` of the same size (a host- and
// device-usable equivalent of std::bit_cast).
//
// The bytes are copied with memcpy instead of dereferencing a reinterpret_cast'ed
// pointer. The old form violated strict aliasing, and it could issue a misaligned
// access when alignof(To) > alignof(From), e.g. a 16-byte vector load from a
// 4-byte-aligned object. Compilers typically lower this fixed-size memcpy to plain
// register moves.
// Requires `To` to be default-constructible. Both types should be trivially copyable.
template<typename To, typename From>
__host__ __device__ __forceinline__
static To bit_cast(const From &input) {
    static_assert(sizeof(To) == sizeof(From), "bit_cast requires types of equal size");
    To output;
    memcpy(&output, &input, sizeof(To));
    return output;
}

muyangli's avatar
muyangli committed
389
};  // namespace nunchaku::kernels