#pragma once

#include <utility> // std::integer_sequence, std::make_integer_sequence
#include <cstring> // memcpy

#include "common.h"
#include "../utils.cuh"

namespace nunchaku::kernels {

static constexpr int clamp(int val, int min, int max) {
    if (val < min)
        return min;
    if (val > max)
        return max;
    return val;
}

template<bool shmem = false, typename T>
__device__ __forceinline__ static T load(const T *addr) {
    if constexpr (shmem) {
        if constexpr (sizeof(T) == 8) {
            uint2 data;
            asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];"
                         : "=r"(data.x), "=r"(data.y)
                         : "l"(addr));
            return *reinterpret_cast<T *>(&data);
        }
        if constexpr (sizeof(T) == 16) {
            uint4 data;
            asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];"
                         : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
                         : "l"(addr));
            return *reinterpret_cast<T *>(&data);
        }
        return *addr;
    }
    if constexpr (sizeof(T) == 8) {
        uint2 data = __ldg(reinterpret_cast<const uint2 *>(addr));
        return *reinterpret_cast<T *>(&data);
    }
    if constexpr (sizeof(T) == 16) {
        uint4 data = __ldg(reinterpret_cast<const uint4 *>(addr));
        return *reinterpret_cast<T *>(&data);
    }
    return *addr;
}

// template<typename T>
// __device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
//     if constexpr (sizeof(T) == 4) {
//         uint32_t data;
//         // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
//         //              "@loadpred ld.global.nc.b32 %0, [%1];"
//         //              "}"
//         //              : "=r"(data)
//         //              : "l"(addr), "r"((int)pred));
//         return *reinterpret_cast<T *>(&data);
//     }
//     if constexpr (sizeof(T) == 8) {
//         uint2 data;
//         // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
//         //              "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
//         //              "}"
//         //              : "=r"(data.x), "=r"(data.y)
//         //              : "l"(addr), "r"((int)pred));
//         return *reinterpret_cast<T *>(&data);
//     }
//     if constexpr (sizeof(T) == 16) {
//         uint4 data;
//         // asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
//         //              "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
//         //              "}"
//         //              : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
//         //              : "l"(addr), "r"((int)pred));
//         return *reinterpret_cast<T *>(&data);
//     }
//     T result;
//     if (pred) {
//         result = *addr;
//     }
//     return result;
// }

template<typename T>
__device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
    if constexpr (sizeof(T) == 4) {
        uint32_t data;
        if (pred) {
            const unsigned char *src = reinterpret_cast<const unsigned char *>(addr);
            unsigned char *dst = reinterpret_cast<unsigned char *>(&data);
#pragma unroll
            for (int i = 0; i < 4; ++i)
                dst[i] = src[i];
        }
        return *reinterpret_cast<T *>(&data);
    }
    if constexpr (sizeof(T) == 8) {
        uint2 data;
        if (pred) {
            const unsigned char *src = reinterpret_cast<const unsigned char *>(addr);
            unsigned char *dst = reinterpret_cast<unsigned char *>(&data);
#pragma unroll
            for (int i = 0; i < 8; ++i)
                dst[i] = src[i];
        }
        return *reinterpret_cast<T *>(&data);
    }
    if constexpr (sizeof(T) == 16) {
        uint4 data;
        if (pred) {
            const unsigned char *src = reinterpret_cast<const unsigned char *>(addr);
            unsigned char *dst = reinterpret_cast<unsigned char *>(&data);
#pragma unroll
            for (int i = 0; i < 16; ++i)
                dst[i] = src[i];
        }
        return *reinterpret_cast<T *>(&data);
    }
    T result;
    if (pred) {
        result = *addr;
    }
    return result;
}
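// --- Illustrative usage sketch (not part of the original header) ---
// Shows how load_pred() is intended to be used for a warp-cooperative,
// bounds-checked vectorized copy. The helper name `copy_float4_warp_example`
// and its signature are assumptions made for illustration only.
__device__ __forceinline__ static void copy_float4_warp_example(float4 *dst, const float4 *src, int numVecs) {
    const int lane = threadIdx.x % 32;
    // Lanes past the end keep executing the same instructions but are masked
    // by `pred`, so no out-of-bounds address is ever dereferenced.
    const bool pred = lane < numVecs;
    float4 v = load_pred(src + lane, pred);
    if (pred) {
        dst[lane] = v;
    }
}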
template<bool shmem = false, typename T>
__device__ __forceinline__ static void store(T *addr, T val) {
    if constexpr (shmem) {
        if constexpr (sizeof(T) == 8) {
            uint2 data = *reinterpret_cast<uint2 *>(&val);
            asm volatile("st.shared.v2.b32 [%0], {%1, %2};" ::"l"(addr), "r"(data.x), "r"(data.y));
            return;
        }
        if constexpr (sizeof(T) == 16) {
            uint4 data = *reinterpret_cast<uint4 *>(&val);
            asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"(addr),
                         "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
            return;
        }
        *addr = val;
        return;
    }
    if constexpr (sizeof(T) == 4) {
        // __stcg(reinterpret_cast<uint32_t *>(addr), *reinterpret_cast<uint32_t *>(&val));
        *reinterpret_cast<uint32_t *>(addr) = *reinterpret_cast<uint32_t *>(&val);
        return;
    }
    if constexpr (sizeof(T) == 8) {
        // __stcg(reinterpret_cast<uint2 *>(addr), *reinterpret_cast<uint2 *>(&val));
        *reinterpret_cast<uint2 *>(addr) = *reinterpret_cast<uint2 *>(&val);
        return;
    }
    if constexpr (sizeof(T) == 16) {
        // __stcg(reinterpret_cast<uint4 *>(addr), *reinterpret_cast<uint4 *>(&val));
        *reinterpret_cast<uint4 *>(addr) = *reinterpret_cast<uint4 *>(&val);
        return;
    }
    *addr = val;
}

template<typename T>
__device__ __forceinline__ static void store_pred(T *addr, T val, bool pred) {
    if constexpr (sizeof(T) == 4) {
        uint32_t data = *reinterpret_cast<uint32_t *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.b32 [%1], %2;"
                     "}" ::"r"((int)pred),
                     "l"(addr), "r"(data));
        return;
    }
    if constexpr (sizeof(T) == 8) {
        uint2 data = *reinterpret_cast<uint2 *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
                     "}" ::"r"((int)pred),
                     "l"(addr), "r"(data.x), "r"(data.y));
        return;
    }
    if constexpr (sizeof(T) == 16) {
        uint4 data = *reinterpret_cast<uint4 *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
                     "}" ::"r"((int)pred),
                     "l"(addr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
        return;
    }
    if (pred) {
        *addr = val;
    }
}

__device__ __forceinline__ static float2 half22float2(half2 val) {
    return __half22float2(val);
}
__device__ __forceinline__ static float2 half22float2(__nv_bfloat162 val) {
    return __bfloat1622float2(val);
}

template<typename T>
__device__ __forceinline__ static T float22half2(float2 val) = delete;
template<>
__device__ __forceinline__ half2 float22half2<half2>(float2 val) {
    return __float22half2_rn(val);
}
template<>
__device__ __forceinline__ __nv_bfloat162 float22half2<__nv_bfloat162>(float2 val) {
    return __float22bfloat162_rn(val);
}

template<typename T>
__device__ __forceinline__ static void unused_var(T &val, bool alwaysfalse) {
    volatile T *ptr = nullptr;
    if (alwaysfalse) {
        *ptr = val;
    }
}

__device__ __forceinline__ static void ldmatrix(const void *ptr, uint4 &out) {
    asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                 : "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w)
                 : "l"(ptr));
}

template<typename T>
__device__ __forceinline__ static T movmatrix(T x) {
    // asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
    //              : "=r"(*reinterpret_cast<uint32_t *>(&x))
    //              : "r"(*reinterpret_cast<uint32_t *>(&x)));
    return x;
}

// x in low bits, y in high bits
template<int bits, bool is_unsigned>
__device__ __forceinline__ uint32_t quantize_float2(float2 value) = delete;

template<>
__device__ __forceinline__ uint32_t quantize_float2<4, false>(float2 value) {
    int v1, v2;
    uint32_t result;
    // asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
    // asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
    // asm volatile("cvt.pack.sat.s4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
    v1 = __float2int_rn(value.x);
    v2 = __float2int_rn(value.y);
    // Saturate to the signed 4-bit range [-8, 7], then keep the low 4 bits of each.
    int s1 = max(-8, min(7, v1));
    int s2 = max(-8, min(7, v2));
    unsigned int u1 = s1 & 0xF;
    unsigned int u2 = s2 & 0xF;
    result = (u2 << 4) | u1;
    return result;
}

template<>
__device__ __forceinline__ uint32_t quantize_float2<4, true>(float2 value) {
    int v1, v2;
    uint32_t result;
    // asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
    // asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
    // asm volatile("cvt.pack.sat.u4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
    v1 = __float2int_rn(value.x);
    v2 = __float2int_rn(value.y);
    // Saturate to the unsigned 4-bit range [0, 15].
    unsigned int u1 = static_cast<unsigned int>(max(0, min(15, v1)));
    unsigned int u2 = static_cast<unsigned int>(max(0, min(15, v2)));
    result = (u2 << 4) | u1;
    return result;
}
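// --- Illustrative usage sketch (not part of the original header) ---
// Packs eight already-scaled floats into eight signed 4-bit codes inside one
// uint32_t using quantize_float2<4, false> above; element i ends up in bits
// [4*i+3 : 4*i]. The helper name `pack_int4x8_example` is an assumption for
// illustration only.
__device__ __forceinline__ static uint32_t pack_int4x8_example(const float (&vals)[8]) {
    uint32_t packed = 0;
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        // Each call converts one float2 pair into two 4-bit codes in the low byte
        // (x in the low nibble, y in the high nibble).
        uint32_t pair = quantize_float2<4, false>(make_float2(vals[2 * i], vals[2 * i + 1]));
        packed |= pair << (8 * i);
    }
    return packed;
}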
template<>
__device__ __forceinline__ uint32_t quantize_float2<8, false>(float2 value) {
    int v1, v2;
    uint32_t result;
    // asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
    // asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
    // asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
    v1 = __float2int_rn(value.x); // equivalent to roundf(value.x)
    v2 = __float2int_rn(value.y);
    // Step 2: saturate to the signed 8-bit range [-128, 127].
    int s1 = max(-128, min(127, v1));
    int s2 = max(-128, min(127, v2));
    // Step 3: convert the signed values to unsigned bit patterns,
    // i.e. keep the low 8 bits of the two's-complement representation.
    unsigned int u1 = s1 & 0xFF;
    unsigned int u2 = s2 & 0xFF;
    result = (u2 << 8) | u1;
    return result;
}

__device__ __forceinline__ uint32_t quantize_float2_fp4(float2 value) {
    uint32_t result;
    asm volatile("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
                 : "=r"(result)
                 : "f"(value.y), "f"(value.x));
    return result;
}

__device__ __forceinline__ uint32_t quantize_float4_fp8(float4 value) {
    uint16_t lo, hi;
    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(lo) : "f"(value.y), "f"(value.x));
    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(hi) : "f"(value.w), "f"(value.z));
    return uint32_t(lo) | (uint32_t(hi) << 16);
}

__device__ __forceinline__ static float cuda_tanhf(float x) {
    float result;
    asm("tanh.approx.f32 %0, %1;" : "=f"(result) : "f"(x));
    return result;
}

__device__ __forceinline__ static float cuda_frcp(float x) {
    float result;
    asm("rcp.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
    return result;
}

__device__ __forceinline__ static float cuda_frsqrt(float x) {
    float result;
    asm("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
    return result;
}

__device__ __forceinline__ static float cuda_sin(float x) {
    float result;
    asm("sin.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
    return result;
}

__device__ __forceinline__ static float cuda_cos(float x) {
    float result;
    asm("cos.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
    return result;
}

__device__ __forceinline__ static float cuda_exp2(float x) {
    float result;
    asm("ex2.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
    return result;
}

// https://forums.developer.nvidia.com/t/hardware-accelerated-computation-of-the-sigmoid-logistic-function/266206
__forceinline__ __device__ static float cuda_sigmoidf(float a) {
#if USE_TANH
    return fmaf(0.5f, cuda_tanhf(0.5f * a), 0.5f);
#else  // USE_TANH
    const float L2E = 1.442695041f; // log2(exp(1))
    float t, d, e, r;
    t = -L2E * a;
    asm("ex2.approx.ftz.f32 %0,%1;\n\t" : "=f"(e) : "f"(t));
    d = e + 1.0f;
    asm("rcp.approx.ftz.f32 %0,%1;\n\t" : "=f"(r) : "f"(d));
    return r;
#endif // USE_TANH
}

template<typename T>
__device__ __forceinline__ static T gelu_half2(T x) {
    float2 xf = half22float2(x);
    float2 x3f = xf * xf * xf;
    float t1 = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.x + (0.044715f * x3f.x)));
    float t2 = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.y + (0.044715f * x3f.y)));
    return float22half2<T>(xf * make_float2(t1, t2));
}

template<typename T>
__device__ __forceinline__ static T gelu_half(T x) {
    float xf = float(x);
    float x3f = xf * xf * xf;
    float t = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf + (0.044715f * x3f)));
    return (T)(xf * t);
}

template<typename T>
__device__ __forceinline__ static T silu(const T &x) {
    // x * sigmoid(x)
    return (T)((float)x * cuda_sigmoidf((float)x));
    // return (T)__fdividef((float)x, 1.0f + __expf((float)-x));
}
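// --- Illustrative usage sketch (not part of the original header) ---
// Applies the tanh-approximation GELU above to a small register tile of packed
// half2 values, two elements per call. The helper name `gelu_tile_example` is
// an assumption for illustration only.
template<int N>
__device__ __forceinline__ static void gelu_tile_example(half2 (&vals)[N]) {
#pragma unroll
    for (int i = 0; i < N; ++i) {
        vals[i] = gelu_half2(vals[i]);
    }
}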
__device__ __forceinline__ static half2 h2div(half2 a, half2 b) {
    float2 af = half22float2(a);
    float2 bf = half22float2(b);
    float2 of;
    of.x = __fdividef(af.x, bf.x);
    of.y = __fdividef(af.y, bf.y);
    return float22half2<half2>(of);
}

__device__ __forceinline__ static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bfloat162 b) {
    float2 af = half22float2(a);
    float2 bf = half22float2(b);
    float2 of;
    of.x = __fdividef(af.x, bf.x);
    of.y = __fdividef(af.y, bf.y);
    return float22half2<__nv_bfloat162>(of);
}

__device__ __forceinline__ static void reduce_add(float *addr, float val) {
    asm volatile("red.relaxed.gpu.global.add.f32 [%0], %1;" ::"l"(addr), "f"(val));
}

__device__ __forceinline__ static void reduce_add_pred(float *addr, float val, bool pred) {
    asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                 "@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
                 "}" ::"r"((int)pred),
                 "l"(addr), "f"(val));
}

template<int N, typename F>
__device__ __forceinline__ static void unrolled_loop(F &&lambda) {
    auto call = [&]<int... Is>(std::integer_sequence<int, Is...>) {
        (lambda.template operator()<Is>(), ...);
    };
    call(std::make_integer_sequence<int, N>());
}

// int2float is slow on sm_80 and before
// val in [-4194304, 4194303]
__device__ __forceinline__ static float int2float_fast(int val) {
    // float fval;
    // // fval = (val & 0x7FFFFF) ^ 0x4B400000
    // asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
    //              : "=f"(fval)
    //              : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
    unsigned int temp = (val & 0x7FFFFF) ^ 0x4B400000;
    float result;
    memcpy(&result, &temp, sizeof(float));
    return result - 12582912.0f;
}

template<typename To, typename From>
__device__ __forceinline__ static To bit_cast(const From &input) {
    static_assert(sizeof(To) == sizeof(From));
    // not safe but anyway
    return *reinterpret_cast<const To *>(&input);
}

// both int2float and float2half are slow on sm_75 and before
// val in [-8192, 8191], steps of 16, round to negative inf
__device__ __forceinline__ static half2 int2half2_fast_8192(int x, int y) {
    uint32_t ival;
    uint32_t hval;
    // ival.lo = x.lo; ival.hi = y.lo;
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
    ival = ival >> 4;
    // (val & 0x03FF03FF) ^ 0x76007600
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(hval)
                 : "r"(ival), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
    return __hadd2(kernels::bit_cast<half2>(hval), half2(-24576.0f, -24576.0f));
}

// val in [-4096, 4095], steps of 8, round to nearest
__device__ __forceinline__ static half2 int2half2_fast_4096_rn(int x, int y) {
    // x = max(min(x, 4095), -4096);
    // y = max(min(y, 4095), -4096);

    // TODO: round to even?
    x = x * 8192 + 32768;
    y = y * 8192 + 32768;

    uint32_t ival;
    uint32_t hval;
    // ival.lo = x.hi; ival.hi = y.hi; <=> divide x and y by 65536 and pack them
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632));
    // (val & 0x03FF03FF) ^ 0x72007200
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(hval)
                 : "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
    return __hadd2(kernels::bit_cast<half2>(hval), half2(-12288.0f, -12288.0f));
}

// val in [-512, 511]
__device__ __forceinline__ static half2 int2half2_fast_512(int x, int y) {
    uint32_t ival;
    uint32_t hval;
    // ival.lo = x.lo; ival.hi = y.lo;
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
    // (val & 0x03FF03FF) ^ 0x66006600
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(hval)
                 : "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
    return __hadd2(kernels::bit_cast<half2>(hval), half2(-1536.0f, -1536.0f));
}

} // namespace nunchaku::kernels
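// --- Illustrative usage sketch (not part of the original header) ---
// Demonstrates the intended use of int2float_fast(): converting a bounded
// integer accumulator to float cheaply before applying a dequantization scale.
// The helper name `dequant_acc_example` and this reopened namespace block are
// assumptions for illustration only.
namespace nunchaku::kernels {

__device__ __forceinline__ static float dequant_acc_example(int acc, float scale) {
    // int2float_fast() is only valid while acc stays in [-4194304, 4194303],
    // as noted in its comment above.
    return int2float_fast(acc) * scale;
}

} // namespace nunchaku::kernels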