#pragma once #include #include "common.h" #include "../utils.cuh" namespace nunchaku::kernels { static constexpr int clamp(int val, int min, int max) { if (val < min) return min; if (val > max) return max; return val; } template __device__ __forceinline__ static T load(const T *addr) { if constexpr (shmem) { if constexpr (sizeof(T) == 8) { uint2 data; asm volatile ("ld.shared.v2.b32 {%0, %1}, [%2];" : "=r"(data.x), "=r"(data.y) : "l"(__cvta_generic_to_shared(addr))); return *reinterpret_cast(&data); } if constexpr (sizeof(T) == 16) { uint4 data; asm volatile ("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];" : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) : "l"(__cvta_generic_to_shared(addr))); return *reinterpret_cast(&data); } return *addr; } if constexpr (sizeof(T) == 8) { uint2 data = __ldg(reinterpret_cast(addr)); return *reinterpret_cast(&data); } if constexpr (sizeof(T) == 16) { uint4 data = __ldg(reinterpret_cast(addr)); return *reinterpret_cast(&data); } return *addr; } template __device__ __forceinline__ static T load_pred(const T *addr, bool pred) { if constexpr (sizeof(T) == 4) { uint32_t data; asm volatile ( "{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;" "@loadpred ld.global.nc.b32 %0, [%1];" "}" : "=r"(data) : "l"(addr), "r"((int)pred)); return *reinterpret_cast(&data); } if constexpr (sizeof(T) == 8) { uint2 data; asm volatile ( "{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;" "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];" "}" : "=r"(data.x), "=r"(data.y) : "l"(addr), "r"((int)pred)); return *reinterpret_cast(&data); } if constexpr (sizeof(T) == 16) { uint4 data; asm volatile ( "{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;" "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];" "}" : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) : "l"(addr), "r"((int)pred)); return *reinterpret_cast(&data); } T result; if (pred) { result = *addr; } return result; } template __device__ __forceinline__ static void store(T *addr, T val) { if constexpr (shmem) { if constexpr (sizeof(T) == 8) { uint2 data = *reinterpret_cast(&val); asm volatile ("st.shared.v2.b32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(addr)), "r"(data.x), "r"(data.y)); return; } if constexpr (sizeof(T) == 16) { uint4 data = *reinterpret_cast(&val); asm volatile ("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(addr)), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w)); return; } *addr = val; return; } if constexpr (sizeof(T) == 4) { __stcg(reinterpret_cast(addr), *reinterpret_cast(&val)); return; } if constexpr (sizeof(T) == 8) { __stcg(reinterpret_cast(addr), *reinterpret_cast(&val)); return; } if constexpr (sizeof(T) == 16) { __stcg(reinterpret_cast(addr), *reinterpret_cast(&val)); return; } *addr = val; } template __device__ __forceinline__ static void store_pred(T *addr, T val, bool pred) { if constexpr (sizeof(T) == 4) { uint32_t data = *reinterpret_cast(&val); asm volatile ( "{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;" "@storepred st.global.cg.b32 [%1], %2;" "}" :: "r"((int)pred), "l"(addr), "r"(data)); return; } if constexpr (sizeof(T) == 8) { uint2 data = *reinterpret_cast(&val); asm volatile ( "{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;" "@storepred st.global.cg.v2.b32 [%1], {%2, %3};" "}" :: "r"((int)pred), "l"(addr), "r"(data.x), "r"(data.y)); return; } if constexpr (sizeof(T) == 16) { uint4 data = *reinterpret_cast(&val); asm volatile ( "{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;" "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};" "}" :: "r"((int)pred), "l"(addr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w)); return; } if (pred) { *addr = val; } } __device__ __forceinline__ static float2 half22float2(half2 val) { return __half22float2(val); } __device__ __forceinline__ static float2 half22float2(__nv_bfloat162 val) { return __bfloat1622float2(val); } template __device__ __forceinline__ static T float22half2(float2 val) = delete; template<> __device__ __forceinline__ half2 float22half2(float2 val) { return __float22half2_rn(val); } template<> __device__ __forceinline__ __nv_bfloat162 float22half2<__nv_bfloat162>(float2 val) { return __float22bfloat162_rn(val); } template __device__ __forceinline__ static void unused_var(T &val, bool alwaysfalse) { volatile T *ptr = nullptr; if (alwaysfalse) { *ptr = val; } } __device__ __forceinline__ static void ldmatrix(const void *ptr, uint4 &out) { asm volatile( "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w) : "l"(__cvta_generic_to_shared(ptr)) ); } template __device__ __forceinline__ static T movmatrix(T x) { asm volatile ("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" : "=r"(*reinterpret_cast(&x)) : "r"(*reinterpret_cast(&x))); return x; } namespace mma_helper { struct f32 { static constexpr const char value[] = "f32"; }; struct f16 { static constexpr const char value[] = "f16"; }; struct bf16 { static constexpr const char value[] = "bf16"; }; struct s32 { static constexpr const char value[] = "s32"; }; struct s4 { static constexpr const char value[] = "s4"; }; struct u4 { static constexpr const char value[] = "u4"; }; template using f16bf16 = std::conditional_t; template using s4u4 = std::conditional_t; }; __device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) { uint2 d; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " "{%0, %1}," "{%2, %3, %4, %5}," "{%6, %7}," "{%8, %9};\n" : "=r"(d.x), "=r"(d.y) : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y) ); #else asm volatile( "{" ".reg .b32 tmp0, tmp1;" "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " "{tmp0, tmp1}," "{%2, %3}," "{%6}," "{%8, %9};\n" "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " "{%0, %1}," "{%4, %5}," "{%7}," "{tmp0, tmp1};" "}\n" : "=r"(d.x), "=r"(d.y) : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y) ); #endif return d; } template __device__ __forceinline__ static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) { uint4 d; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.%14.%14.f32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w) : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "C"(mma_helper::f16bf16::value) ); #else static_assert(!is_bf16); asm volatile( "{" ".reg .b32 tmp0, tmp1, tmp2, tmp3;" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " "{tmp0, tmp1, tmp2, tmp3}," "{%4, %5}," "{%8}," "{%10, %11, %12, %13};\n" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}," "{%6, %7}," "{%9}," "{tmp0, tmp1, tmp2, tmp3};" "}\n" : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w) : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w) ); #endif return d; } template __device__ __forceinline__ static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) { uint4 d; static constexpr int K = (std::is_same_v || std::is_same_v) ? 64 : 32; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile( "mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w) : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K) "C"(AType::value), "C"(BType::value) ); #else asm volatile( "{" ".reg .b32 tmp0, tmp1, tmp2, tmp3;" "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 " "{tmp0, tmp1}," "{%4}," "{%8}," "{%10, %11};\n" "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 " "{tmp2, tmp3}," "{%5}," "{%8}," "{%12, %13};\n" "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 " "{%0, %1}," "{%6}," "{%9}," "{tmp0, tmp1};\n" "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 " "{%2, %3}," "{%7}," "{%9}," "{tmp2, tmp3};\n" "}\n" : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w) : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K / 2) "C"(AType::value), "C"(BType::value) ); #endif return d; } // x in low bit, y in high bit template __device__ __forceinline__ uint32_t quantize_float2(float2 value) = delete; template<> __device__ __forceinline__ uint32_t quantize_float2<4, false>(float2 value) { int v1, v2; uint32_t result; asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x)); asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y)); asm volatile ("cvt.pack.sat.s4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1)); return result; } template<> __device__ __forceinline__ uint32_t quantize_float2<4, true>(float2 value) { int v1, v2; uint32_t result; asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x)); asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y)); asm volatile ("cvt.pack.sat.u4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1)); return result; } template<> __device__ __forceinline__ uint32_t quantize_float2<8, false>(float2 value) { int v1, v2; uint32_t result; asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x)); asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y)); asm volatile ("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1)); return result; } __device__ __forceinline__ uint32_t quantize_float2_fp4(float2 value) { uint32_t result; asm volatile ("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }" : "=r"(result) : "f"(value.y), "f"(value.x)); return result; } __device__ __forceinline__ uint32_t quantize_float4_fp8(float4 value) { uint16_t lo, hi; asm volatile ("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(lo) : "f"(value.y), "f"(value.x)); asm volatile ("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(hi) : "f"(value.w), "f"(value.z)); return uint32_t(lo) | (uint32_t(hi) << 16); } __device__ __forceinline__ static float cuda_tanhf(float x) { float result; asm ("tanh.approx.f32 %0, %1;" : "=f"(result) : "f"(x)); return result; } __device__ __forceinline__ static float cuda_frcp(float x) { float result; asm ("rcp.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x)); return result; } __device__ __forceinline__ static float cuda_frsqrt(float x) { float result; asm ("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x)); return result; } __device__ __forceinline__ static float cuda_sin(float x) { float result; asm ("sin.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x)); return result; } __device__ __forceinline__ static float cuda_cos(float x) { float result; asm ("cos.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x)); return result; } __device__ __forceinline__ static float cuda_exp2(float x) { float result; asm ("ex2.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x)); return result; } // https://forums.developer.nvidia.com/t/hardware-accelerated-computation-of-the-sigmoid-logistic-function/266206 __forceinline__ __device__ static float cuda_sigmoidf (float a) { #if USE_TANH return fmaf (0.5, __tanhf (0.5f * a), 0.5f); #else // USE_TANH const float L2E = 1.442695041f; // log2(exp(1)) float t, d, e, r; t = -L2E * a; asm ("ex2.approx.ftz.f32 %0,%1;\n\t" : "=f"(e) : "f"(t)); d = e + 1.0f; asm ("rcp.approx.ftz.f32 %0,%1;\n\t" : "=f"(r) : "f"(d)); return r; #endif // USE_TANH } template __device__ __forceinline__ static T gelu_half2(T x) { float2 xf = half22float2(x); float2 x3f = xf * xf * xf; float t1 = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.x + (0.044715f * x3f.x))); float t2 = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.y + (0.044715f * x3f.y))); return float22half2(xf * make_float2(t1, t2)); } template __device__ __forceinline__ static T gelu_half(T x) { float xf = float(x); float x3f = xf * xf * xf; float t = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf + (0.044715f * x3f))); return (T)(xf * t); } template __device__ __forceinline__ static T silu(const T &x) { // x * sigmoid(x) return (T)((float)x * cuda_sigmoidf((float)x)); // return (T)__fdividef((float)x, 1.0f + __expf((float)-x)); } __device__ __forceinline__ static half2 h2div(half2 a, half2 b) { float2 af = half22float2(a); float2 bf = half22float2(b); float2 of; of.x = __fdividef(af.x, bf.x); of.y = __fdividef(af.y, bf.y); return float22half2(of); }; __device__ __forceinline__ static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bfloat162 b) { float2 af = half22float2(a); float2 bf = half22float2(b); float2 of; of.x = __fdividef(af.x, bf.x); of.y = __fdividef(af.y, bf.y); return float22half2<__nv_bfloat162>(of); }; __device__ __forceinline__ static void reduce_add(float *addr, float val) { asm volatile ("red.relaxed.gpu.global.add.f32 [%0], %1;" :: "l"(addr), "f"(val)); } template __device__ __forceinline__ static void unrolled_loop(F &&lambda) { auto call = [&](std::integer_sequence) { (lambda.template operator()(), ...); }; call(std::make_integer_sequence()); } // int2float is slow on sm_80 and before // val in [-4194304, 4194303] __device__ __forceinline__ static float int2float_fast(int val) { float fval; // fval = (val & 0x7FFFFF) ^ 0x4B400000 asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=f"(fval) : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA)); return fval - 12582912.0f; } template __device__ __forceinline__ static To bit_cast(const From &input) { static_assert(sizeof(To) == sizeof(From)); // not safe but anyway return *reinterpret_cast(&input); } // both int2float and float2half are slow on sm_75 and before // val in [-8192, 8191], steps of 16, round to negative inf __device__ __forceinline__ static half2 int2half2_fast_8192(int x, int y) { uint32_t ival; uint32_t hval; // ival.lo = x.lo; ival.hi = y.lo; asm volatile ("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410)); ival = ival >> 4; // (val & 0x03FF03FF) ^ 0x76007600 asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=r"(hval) : "r"(ival), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA)); return __hadd2(kernels::bit_cast(hval), half2(-24576.0f, -24576.0f)); } // val in [-4096, 4095], steps of 8, round to nearest __device__ __forceinline__ static half2 int2half2_fast_4096_rn(int x, int y) { // x = max(min(x, 4095), -4096); // y = max(min(y, 4095), -4096); // TODO: round to even? x = x * 8192 + 32768; y = y * 8192 + 32768; uint32_t ival; uint32_t hval; // ival.lo = x.hi; ival.hi = y.hi; // <=> divide x and y by 65536 and pack them asm volatile ("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632)); // (val & 0x03FF03FF) ^ 0x72007200 asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=r"(hval) : "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA)); return __hadd2(kernels::bit_cast(hval), half2(-12288.0f, -12288.0f)); } // val in [-512, 511] __device__ __forceinline__ static half2 int2half2_fast_512(int x, int y) { uint32_t ival; uint32_t hval; // ival.lo = x.lo; ival.hi = y.lo; // <=> divide x and y by 65536 and pack them asm volatile ("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410)); // (val & 0x03FF03FF) ^ 0x66006600 asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=r"(hval) : "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA)); return __hadd2(kernels::bit_cast(hval), half2(-1536.0f, -1536.0f)); } }; // namespace nunchaku::kernels