Unverified Commit 57e50f8d authored by Muyang Li, committed by GitHub

style: upgrade the linter (#339)

* style: reformatted code

* style: reformatted code
parent b737368d
@@ -7,7 +7,7 @@
namespace nunchaku::kernels {
static constexpr int clamp(int val, int min, int max) {
if (val < min)
if (val < min)
return min;
if (val > max)
return max;
@@ -15,17 +15,20 @@ static constexpr int clamp(int val, int min, int max) {
}
template<bool shmem = false, typename T>
__device__ __forceinline__
static T load(const T *addr) {
__device__ __forceinline__ static T load(const T *addr) {
if constexpr (shmem) {
if constexpr (sizeof(T) == 8) {
uint2 data;
asm volatile ("ld.shared.v2.b32 {%0, %1}, [%2];" : "=r"(data.x), "=r"(data.y) : "l"(__cvta_generic_to_shared(addr)));
asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];"
: "=r"(data.x), "=r"(data.y)
: "l"(__cvta_generic_to_shared(addr)));
return *reinterpret_cast<T *>(&data);
}
if constexpr (sizeof(T) == 16) {
uint4 data;
asm volatile ("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];" : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) : "l"(__cvta_generic_to_shared(addr)));
asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];"
: "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
: "l"(__cvta_generic_to_shared(addr)));
return *reinterpret_cast<T *>(&data);
}
return *addr;
@@ -44,30 +47,32 @@ static T load(const T *addr) {
}
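// Example use (illustrative only): with shmem = true and an 8- or 16-byte type,
// load<>() issues a single vectorized ld.shared; tile and gmem_ptr below are
// hypothetical names, not part of this header.
//
//   __shared__ float4 tile[256];
//   float4 v = load<true>(&tile[threadIdx.x]);  // one ld.shared.v4.b32
//   float4 g = load(gmem_ptr + threadIdx.x);    // default shmem = false: non-shared path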
template<typename T>
__device__ __forceinline__
static T load_pred(const T *addr, bool pred) {
__device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
if constexpr (sizeof(T) == 4) {
uint32_t data;
asm volatile (
"{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
"@loadpred ld.global.nc.b32 %0, [%1];"
"}" : "=r"(data) : "l"(addr), "r"((int)pred));
asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
"@loadpred ld.global.nc.b32 %0, [%1];"
"}"
: "=r"(data)
: "l"(addr), "r"((int)pred));
return *reinterpret_cast<T *>(&data);
}
if constexpr (sizeof(T) == 8) {
uint2 data;
asm volatile (
"{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
"@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
"}" : "=r"(data.x), "=r"(data.y) : "l"(addr), "r"((int)pred));
asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
"@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
"}"
: "=r"(data.x), "=r"(data.y)
: "l"(addr), "r"((int)pred));
return *reinterpret_cast<T *>(&data);
}
if constexpr (sizeof(T) == 16) {
uint4 data;
asm volatile (
"{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
"@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
"}" : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) : "l"(addr), "r"((int)pred));
asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
"@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
"}"
: "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
: "l"(addr), "r"((int)pred));
return *reinterpret_cast<T *>(&data);
}
@@ -79,17 +84,21 @@ static T load_pred(const T *addr, bool pred) {
}
template<bool shmem = false, typename T>
__device__ __forceinline__
static void store(T *addr, T val) {
__device__ __forceinline__ static void store(T *addr, T val) {
if constexpr (shmem) {
if constexpr (sizeof(T) == 8) {
uint2 data = *reinterpret_cast<uint2 *>(&val);
asm volatile ("st.shared.v2.b32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(addr)), "r"(data.x), "r"(data.y));
asm volatile(
"st.shared.v2.b32 [%0], {%1, %2};" ::"l"(__cvta_generic_to_shared(addr)), "r"(data.x), "r"(data.y));
return;
}
if constexpr (sizeof(T) == 16) {
uint4 data = *reinterpret_cast<uint4 *>(&val);
asm volatile ("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(addr)), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"(__cvta_generic_to_shared(addr)),
"r"(data.x),
"r"(data.y),
"r"(data.z),
"r"(data.w));
return;
}
*addr = val;
@@ -107,35 +116,41 @@ static void store(T *addr, T val) {
if constexpr (sizeof(T) == 16) {
__stcg(reinterpret_cast<uint4 *>(addr), *reinterpret_cast<uint4 *>(&val));
return;
}
}
*addr = val;
}
template<typename T>
__device__ __forceinline__
static void store_pred(T *addr, T val, bool pred) {
__device__ __forceinline__ static void store_pred(T *addr, T val, bool pred) {
if constexpr (sizeof(T) == 4) {
uint32_t data = *reinterpret_cast<uint32_t *>(&val);
asm volatile (
"{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.b32 [%1], %2;"
"}" :: "r"((int)pred), "l"(addr), "r"(data));
asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.b32 [%1], %2;"
"}" ::"r"((int)pred),
"l"(addr),
"r"(data));
return;
}
if constexpr (sizeof(T) == 8) {
uint2 data = *reinterpret_cast<uint2 *>(&val);
asm volatile (
"{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
"}" :: "r"((int)pred), "l"(addr), "r"(data.x), "r"(data.y));
asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
"}" ::"r"((int)pred),
"l"(addr),
"r"(data.x),
"r"(data.y));
return;
}
if constexpr (sizeof(T) == 16) {
uint4 data = *reinterpret_cast<uint4 *>(&val);
asm volatile (
"{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
"}" :: "r"((int)pred), "l"(addr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
"}" ::"r"((int)pred),
"l"(addr),
"r"(data.x),
"r"(data.y),
"r"(data.z),
"r"(data.w));
return;
}
@@ -144,198 +159,174 @@ static void store_pred(T *addr, T val, bool pred) {
}
}
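// Rough scalar equivalents of the predicated accessors above (illustrative sketch only;
// the asm versions emit a single predicated ld.global.nc / st.global.cg and leave the
// destination register unspecified when pred is false rather than zeroing it):
template<typename T>
__device__ __forceinline__ static T load_pred_ref(const T *addr, bool pred) {
    T v{};
    if (pred)
        v = *addr;
    return v;
}
template<typename T>
__device__ __forceinline__ static void store_pred_ref(T *addr, T val, bool pred) {
    if (pred)
        *addr = val;
}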
__device__ __forceinline__
static float2 half22float2(half2 val) {
__device__ __forceinline__ static float2 half22float2(half2 val) {
return __half22float2(val);
}
__device__ __forceinline__
static float2 half22float2(__nv_bfloat162 val) {
__device__ __forceinline__ static float2 half22float2(__nv_bfloat162 val) {
return __bfloat1622float2(val);
}
template<typename T>
__device__ __forceinline__
static T float22half2(float2 val) = delete;
__device__ __forceinline__ static T float22half2(float2 val) = delete;
template<>
__device__ __forceinline__
half2 float22half2<half2>(float2 val) {
__device__ __forceinline__ half2 float22half2<half2>(float2 val) {
return __float22half2_rn(val);
}
template<>
__device__ __forceinline__
__nv_bfloat162 float22half2<__nv_bfloat162>(float2 val) {
__device__ __forceinline__ __nv_bfloat162 float22half2<__nv_bfloat162>(float2 val) {
return __float22bfloat162_rn(val);
}
template<typename T>
__device__ __forceinline__
static void unused_var(T &val, bool alwaysfalse) {
__device__ __forceinline__ static void unused_var(T &val, bool alwaysfalse) {
volatile T *ptr = nullptr;
if (alwaysfalse) {
*ptr = val;
}
}
__device__ __forceinline__
static void ldmatrix(const void *ptr, uint4 &out) {
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w)
: "l"(__cvta_generic_to_shared(ptr))
);
__device__ __forceinline__ static void ldmatrix(const void *ptr, uint4 &out) {
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w)
: "l"(__cvta_generic_to_shared(ptr)));
}
template<typename T>
__device__ __forceinline__
static T movmatrix(T x) {
asm volatile ("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" : "=r"(*reinterpret_cast<uint32_t *>(&x)) : "r"(*reinterpret_cast<uint32_t *>(&x)));
__device__ __forceinline__ static T movmatrix(T x) {
asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
: "=r"(*reinterpret_cast<uint32_t *>(&x))
: "r"(*reinterpret_cast<uint32_t *>(&x)));
return x;
}
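// Background for the two wrappers above (summary of the PTX semantics, nothing new):
// ldmatrix.x4.m8n8.b16 loads four 8x8 tiles of 16-bit elements from shared memory;
// each of the 32 lanes supplies the address of one 8-element row, and each lane
// receives four 32-bit registers (out.x..out.w), each holding two consecutive 16-bit
// values laid out as an MMA fragment. movmatrix transposes one such 8x8 b16 tile that
// is already distributed across the warp, which is why T is reinterpreted as a single
// uint32_t per lane.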
// x in low bit, y in high bit
template<int bitwidth, bool use_unsigned>
__device__ __forceinline__
uint32_t quantize_float2(float2 value) = delete;
__device__ __forceinline__ uint32_t quantize_float2(float2 value) = delete;
template<>
__device__ __forceinline__
uint32_t quantize_float2<4, false>(float2 value) {
__device__ __forceinline__ uint32_t quantize_float2<4, false>(float2 value) {
int v1, v2;
uint32_t result;
asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
asm volatile ("cvt.pack.sat.s4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
asm volatile("cvt.pack.sat.s4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
return result;
}
template<>
__device__ __forceinline__
uint32_t quantize_float2<4, true>(float2 value) {
__device__ __forceinline__ uint32_t quantize_float2<4, true>(float2 value) {
int v1, v2;
uint32_t result;
asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
asm volatile ("cvt.pack.sat.u4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
asm volatile("cvt.pack.sat.u4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
return result;
}
template<>
__device__ __forceinline__
uint32_t quantize_float2<8, false>(float2 value) {
__device__ __forceinline__ uint32_t quantize_float2<8, false>(float2 value) {
int v1, v2;
uint32_t result;
asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
asm volatile ("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
return result;
}
__device__ __forceinline__
uint32_t quantize_float2_fp4(float2 value) {
__device__ __forceinline__ uint32_t quantize_float2_fp4(float2 value) {
uint32_t result;
asm volatile ("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }" : "=r"(result) : "f"(value.y), "f"(value.x));
asm volatile("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
: "=r"(result)
: "f"(value.y), "f"(value.x));
return result;
}
__device__ __forceinline__
uint32_t quantize_float4_fp8(float4 value) {
__device__ __forceinline__ uint32_t quantize_float4_fp8(float4 value) {
uint16_t lo, hi;
asm volatile ("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(lo) : "f"(value.y), "f"(value.x));
asm volatile ("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(hi) : "f"(value.w), "f"(value.z));
asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(lo) : "f"(value.y), "f"(value.x));
asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(hi) : "f"(value.w), "f"(value.z));
return uint32_t(lo) | (uint32_t(hi) << 16);
}
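// Packing convention shared by the quantizers above (per the "x in low bit, y in high
// bit" note): value.x always lands in the least-significant field. Illustrative values,
// assuming exact representability:
//   quantize_float2<4, false>({3.f, -2.f}) == 0x000000E3   // low nibble +3, next nibble -2 (0xE)
//   quantize_float2<8, false>({3.f, -2.f}) == 0x0000FE03   // low byte  +3, next byte  -2 (0xFE)
//   quantize_float2_fp4({x, y})    -> e2m1(x) in the low nibble, e2m1(y) above it, upper 24 bits zero
//   quantize_float4_fp8({x,y,z,w}) -> e4m3(x) | e4m3(y) << 8 | e4m3(z) << 16 | e4m3(w) << 24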
__device__ __forceinline__
static float cuda_tanhf(float x) {
__device__ __forceinline__ static float cuda_tanhf(float x) {
float result;
asm ("tanh.approx.f32 %0, %1;" : "=f"(result) : "f"(x));
asm("tanh.approx.f32 %0, %1;" : "=f"(result) : "f"(x));
return result;
}
__device__ __forceinline__
static float cuda_frcp(float x) {
__device__ __forceinline__ static float cuda_frcp(float x) {
float result;
asm ("rcp.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
asm("rcp.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
return result;
}
__device__ __forceinline__
static float cuda_frsqrt(float x) {
__device__ __forceinline__ static float cuda_frsqrt(float x) {
float result;
asm ("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
asm("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
return result;
}
__device__ __forceinline__
static float cuda_sin(float x) {
__device__ __forceinline__ static float cuda_sin(float x) {
float result;
asm ("sin.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
asm("sin.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
return result;
}
__device__ __forceinline__
static float cuda_cos(float x) {
__device__ __forceinline__ static float cuda_cos(float x) {
float result;
asm ("cos.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
asm("cos.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
return result;
}
__device__ __forceinline__
static float cuda_exp2(float x) {
__device__ __forceinline__ static float cuda_exp2(float x) {
float result;
asm ("ex2.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
asm("ex2.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
return result;
}
// https://forums.developer.nvidia.com/t/hardware-accelerated-computation-of-the-sigmoid-logistic-function/266206
__forceinline__ __device__
static float cuda_sigmoidf (float a)
{
__forceinline__ __device__ static float cuda_sigmoidf(float a) {
#if USE_TANH
return fmaf (0.5, __tanhf (0.5f * a), 0.5f);
#else // USE_TANH
return fmaf(0.5, __tanhf(0.5f * a), 0.5f);
#else // USE_TANH
const float L2E = 1.442695041f; // log2(exp(1))
float t, d, e, r;
t = -L2E * a;
asm ("ex2.approx.ftz.f32 %0,%1;\n\t" : "=f"(e) : "f"(t));
asm("ex2.approx.ftz.f32 %0,%1;\n\t" : "=f"(e) : "f"(t));
d = e + 1.0f;
asm ("rcp.approx.ftz.f32 %0,%1;\n\t" : "=f"(r) : "f"(d));
asm("rcp.approx.ftz.f32 %0,%1;\n\t" : "=f"(r) : "f"(d));
return r;
#endif // USE_TANH
}
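// Derivation behind the two branches above:
//   USE_TANH path: sigmoid(a) = 0.5 + 0.5 * tanh(0.5 * a)
//   default path:  sigmoid(a) = 1 / (1 + exp(-a)) = 1 / (1 + 2^(-a * log2(e))),
// so one ex2.approx plus one rcp.approx suffices; both paths trade a few ULPs of
// accuracy for throughput.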
template<typename T>
__device__ __forceinline__
static T gelu_half2(T x) {
__device__ __forceinline__ static T gelu_half2(T x) {
float2 xf = half22float2(x);
float2 x3f = xf * xf * xf;
float t1 = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.x + (0.044715f * x3f.x)));
float t2 = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.y + (0.044715f * x3f.y)));
float t1 = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.x + (0.044715f * x3f.x)));
float t2 = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.y + (0.044715f * x3f.y)));
return float22half2<T>(xf * make_float2(t1, t2));
}
template<typename T>
__device__ __forceinline__
static T gelu_half(T x) {
__device__ __forceinline__ static T gelu_half(T x) {
float xf = float(x);
float x3f = xf * xf * xf;
float t = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf + (0.044715f * x3f)));
float t = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf + (0.044715f * x3f)));
return (T)(xf * t);
}
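// Both helpers implement the tanh approximation of GELU,
//   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// with sqrt(2/pi) ~= 0.79788456. Host-side reference (illustrative only, needs <cmath>):
static inline float gelu_ref(float x) {
    return 0.5f * x * (1.0f + tanhf(0.79788456f * (x + 0.044715f * x * x * x)));
}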
template <typename T>
__device__ __forceinline__
static T silu(const T &x) {
// x * sigmoid(x)
return (T)((float)x * cuda_sigmoidf((float)x));
// return (T)__fdividef((float)x, 1.0f + __expf((float)-x));
template<typename T>
__device__ __forceinline__ static T silu(const T &x) {
// x * sigmoid(x)
return (T)((float)x * cuda_sigmoidf((float)x));
// return (T)__fdividef((float)x, 1.0f + __expf((float)-x));
}
__device__ __forceinline__
static half2 h2div(half2 a, half2 b) {
__device__ __forceinline__ static half2 h2div(half2 a, half2 b) {
float2 af = half22float2(a);
float2 bf = half22float2(b);
float2 of;
@@ -343,8 +334,7 @@ static half2 h2div(half2 a, half2 b) {
of.y = __fdividef(af.y, bf.y);
return float22half2<half2>(of);
};
__device__ __forceinline__
static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bfloat162 b) {
__device__ __forceinline__ static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bfloat162 b) {
float2 af = half22float2(a);
float2 bf = half22float2(b);
float2 of;
@@ -353,41 +343,37 @@ static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bfloat162 b) {
return float22half2<__nv_bfloat162>(of);
};
__device__ __forceinline__
static void reduce_add(float *addr, float val) {
asm volatile ("red.relaxed.gpu.global.add.f32 [%0], %1;" :: "l"(addr), "f"(val));
__device__ __forceinline__ static void reduce_add(float *addr, float val) {
asm volatile("red.relaxed.gpu.global.add.f32 [%0], %1;" ::"l"(addr), "f"(val));
}
__device__ __forceinline__
static void reduce_add_pred(float *addr, float val, bool pred) {
asm volatile (
"{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
"}" :: "r"((int)pred), "l"(addr), "f"(val));
__device__ __forceinline__ static void reduce_add_pred(float *addr, float val, bool pred) {
asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
"}" ::"r"((int)pred),
"l"(addr),
"f"(val));
}
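// The two helpers above behave like a fire-and-forget atomic add: red.relaxed returns
// no old value and gives only relaxed ordering. Illustrative sketch of the semantics:
__device__ __forceinline__ static void reduce_add_ref(float *addr, float val, bool pred = true) {
    if (pred)
        atomicAdd(addr, val); // sketch only; the asm versions skip the returned old value
}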
template<int cnt, typename F>
__device__ __forceinline__
static void unrolled_loop(F &&lambda) {
auto call = [&]<int ...Is>(std::integer_sequence<int, Is...>) {
(lambda.template operator()<Is>(), ...);
};
__device__ __forceinline__ static void unrolled_loop(F &&lambda) {
auto call = [&]<int... Is>(std::integer_sequence<int, Is...>) { (lambda.template operator()<Is>(), ...); };
call(std::make_integer_sequence<int, cnt>());
}
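// Example use of unrolled_loop (illustrative; buf is a hypothetical per-thread array).
// The explicit lambda template parameter receives each index as a compile-time constant:
//
//   float buf[4];
//   unrolled_loop<4>([&]<int I>() { buf[I] = static_cast<float>(I); }); // buf[0]=0 ... buf[3]=3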
// int2float is slow on sm_80 and before
// val in [-4194304, 4194303]
__device__ __forceinline__
static float int2float_fast(int val) {
__device__ __forceinline__ static float int2float_fast(int val) {
float fval;
// fval = (val & 0x7FFFFF) ^ 0x4B400000
asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=f"(fval) : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
: "=f"(fval)
: "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
return fval - 12582912.0f;
}
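// Why the lop3/subtract above works (the classic "magic number" int-to-float trick):
// 0x4B400000 is the bit pattern of 12582912.0f = 1.5 * 2^23. For val in
// [-4194304, 4194303] (i.e. [-2^22, 2^22)), (val & 0x7FFFFF) ^ 0x4B400000 is a float
// whose mantissa field equals val + 2^22, so its value is exactly 12582912.0f + val,
// and subtracting 12582912.0f recovers val without an I2F instruction.
// Host-side sketch (illustrative only; needs <cstdint> and <cstring>):
static inline float int2float_fast_ref(int val) {
    uint32_t bits = (uint32_t(val) & 0x7FFFFFu) ^ 0x4B400000u;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f - 12582912.0f;
}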
template<typename To, typename From>
__device__ __forceinline__
static To bit_cast(const From &input) {
__device__ __forceinline__ static To bit_cast(const From &input) {
static_assert(sizeof(To) == sizeof(From));
// not safe but anyway
return *reinterpret_cast<const To *>(&input);
@@ -395,20 +381,20 @@ static To bit_cast(const From &input) {
// both int2float and float2half are slow on sm_75 and before
// val in [-8192, 8191], steps of 16, round to negative inf
__device__ __forceinline__
static half2 int2half2_fast_8192(int x, int y) {
__device__ __forceinline__ static half2 int2half2_fast_8192(int x, int y) {
uint32_t ival;
uint32_t hval;
// ival.lo = x.lo; ival.hi = y.lo;
asm volatile ("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
ival = ival >> 4;
// (val & 0x03FF03FF) ^ 0x76007600
asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=r"(hval) : "r"(ival), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
: "=r"(hval)
: "r"(ival), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
return __hadd2(kernels::bit_cast<half2>(hval), half2(-24576.0f, -24576.0f));
}
// val in [-4096, 4095], steps of 8, round to nearest
__device__ __forceinline__
static half2 int2half2_fast_4096_rn(int x, int y) {
__device__ __forceinline__ static half2 int2half2_fast_4096_rn(int x, int y) {
// x = max(min(x, 4095), -4096);
// y = max(min(y, 4095), -4096);
// TODO: round to even?
@@ -416,24 +402,27 @@ static half2 int2half2_fast_4096_rn(int x, int y) {
y = y * 8192 + 32768;
uint32_t ival;
uint32_t hval;
// ival.lo = x.hi; ival.hi = y.hi;
// ival.lo = x.hi; ival.hi = y.hi;
// <=> divide x and y by 65536 and pack them
asm volatile ("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632));
asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632));
// (val & 0x03FF03FF) ^ 0x72007200
asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=r"(hval) : "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
: "=r"(hval)
: "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
return __hadd2(kernels::bit_cast<half2>(hval), half2(-12288.0f, -12288.0f));
}
// val in [-512, 511]
__device__ __forceinline__
static half2 int2half2_fast_512(int x, int y) {
__device__ __forceinline__ static half2 int2half2_fast_512(int x, int y) {
uint32_t ival;
uint32_t hval;
// ival.lo = x.lo; ival.hi = y.lo;
// ival.lo = x.lo; ival.hi = y.lo;
// <=> divide x and y by 65536 and pack them
asm volatile ("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
// (val & 0x03FF03FF) ^ 0x66006600
asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=r"(hval) : "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
: "=r"(hval)
: "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
return __hadd2(kernels::bit_cast<half2>(hval), half2(-1536.0f, -1536.0f));
}
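// The three int2half2_fast_* helpers above reuse the int2float_fast magic-number idea
// in fp16, two values at a time:
//   - prmt packs the relevant 16 bits of x and y into one register (selector 0x5410
//     takes the low halves, 0x7632 the high halves),
//   - lop3 keeps the 10 mantissa bits of each half and XORs in a biased constant
//     (fp16 0x7600 = 24576.0, 0x7200 = 12288.0, 0x6600 = 1536.0),
//   - __hadd2 subtracts that constant, leaving the original integers as a half2.
// The ranges and step sizes in the comments follow from how many integer bits fit in
// the 10-bit fp16 mantissa at each chosen exponent.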
}; // namespace nunchaku::kernels
\ No newline at end of file
}; // namespace nunchaku::kernels
@@ -14,7 +14,6 @@ struct FasterI2FMode {
static bool check(bool act_unsigned);
};
template<typename F>
static void invoke_launch(Tensor::ScalarType dtype, bool use_fp4, bool fasterI2F, F &&launch) {
if (fasterI2F && dtype == Tensor::FP16) {
@@ -32,37 +31,35 @@ static void invoke_launch(Tensor::ScalarType dtype, bool use_fp4, bool fasterI2F
}
}
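// The helper above dispatches on runtime flags by calling a templated lambda with the
// selected template arguments. Stripped-down sketch of the idiom (the config types here
// are placeholders passed in as template parameters, not the real kernel configs):
template<typename ConfigFP16, typename ConfigBF16, typename F>
static void invoke_launch_sketch(Tensor::ScalarType dtype, bool use_fp4, F &&launch) {
    if (dtype == Tensor::FP16) {
        if (use_fp4)
            launch.template operator()<ConfigFP16, true>();
        else
            launch.template operator()<ConfigFP16, false>();
    } else {
        if (use_fp4)
            launch.template operator()<ConfigBF16, true>();
        else
            launch.template operator()<ConfigBF16, false>();
    }
}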
void gemm_w4a4(
Tensor act, // packed act [M, K / 2]
Tensor wgt, // packed act [N, K / 2]
Tensor out, // linear [M, N]
Tensor qout, // packed act [M, N / 2]
Tensor ascales, // packed as [K / 64, M]
Tensor wscales, // packed ws [K / 64, N]
Tensor oscales, // packed as [N / 64, M]
Tensor poolout, // linear [M / PoolSize, N]
Tensor lora_act_in, // packed lora_act [M, R]
Tensor lora_up, // packed lora_wgt [N, R]
Tensor lora_down, // packed lora_wgt [N, R]
Tensor lora_act_out, // packed lora_act [M, R]
Tensor norm_q, // linear [HEAD_DIM]
Tensor norm_k, // linear [HEAD_DIM]
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor bias, // packed ws [N]
Tensor smooth_factor, // packed ws [N], for quantization of the next layer
Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
Tensor out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales,
Tensor out_q, // packed attention [B, H, M, D]
Tensor out_k, // packed attention [B, H, M, D]
Tensor out_v, // packed attention [B, H, M, D]
int attn_tokens
) {
void gemm_w4a4(Tensor act, // packed act [M, K / 2]
Tensor wgt, // packed act [N, K / 2]
Tensor out, // linear [M, N]
Tensor qout, // packed act [M, N / 2]
Tensor ascales, // packed as [K / 64, M]
Tensor wscales, // packed ws [K / 64, N]
Tensor oscales, // packed as [N / 64, M]
Tensor poolout, // linear [M / PoolSize, N]
Tensor lora_act_in, // packed lora_act [M, R]
Tensor lora_up, // packed lora_wgt [N, R]
Tensor lora_down, // packed lora_wgt [N, R]
Tensor lora_act_out, // packed lora_act [M, R]
Tensor norm_q, // linear [HEAD_DIM]
Tensor norm_k, // linear [HEAD_DIM]
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor bias, // packed ws [N]
Tensor smooth_factor, // packed ws [N], for quantization of the next layer
Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
Tensor out_linearattn, // linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales,
Tensor out_q, // packed attention [B, H, M, D]
Tensor out_k, // packed attention [B, H, M, D]
Tensor out_v, // packed attention [B, H, M, D]
int attn_tokens) {
Tensor::ScalarType dtype = Tensor::INVALID_SCALAR_TYPE;
if (!fp4) {
dtype = ascales.dtype();
@@ -75,37 +72,35 @@ void gemm_w4a4(
}
}
invoke_launch(dtype, fp4, FasterI2FMode::check(act_unsigned), [&]<typename Config, bool USE_FP4>() {
GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(
act,
wgt,
out,
qout,
ascales,
wscales,
oscales,
poolout,
lora_act_in,
lora_up,
lora_down,
lora_act_out,
norm_q,
norm_k,
rotary_emb,
bias,
smooth_factor,
out_vk,
out_linearattn,
act_unsigned,
lora_scales,
fuse_silu,
fp4,
alpha,
wcscales,
out_q,
out_k,
out_v,
attn_tokens
);
GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(act,
wgt,
out,
qout,
ascales,
wscales,
oscales,
poolout,
lora_act_in,
lora_up,
lora_down,
lora_act_out,
norm_q,
norm_k,
rotary_emb,
bias,
smooth_factor,
out_vk,
out_linearattn,
act_unsigned,
lora_scales,
fuse_silu,
fp4,
alpha,
wcscales,
out_q,
out_k,
out_v,
attn_tokens);
});
}
@@ -115,26 +110,28 @@ void linearattn_vk_mul_q(Tensor q, Tensor vk) {
});
}
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
void quantize_w4a4_act_fuse_lora(Tensor input,
Tensor output,
Tensor oscales,
Tensor lora_down,
Tensor lora_act_out,
Tensor smooth,
bool fuse_glu,
bool fp4) {
invoke_launch(input.dtype(), fp4, false, [&]<typename Config, bool USE_FP4>() {
GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(
input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4
);
input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4);
});
}
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
invoke_launch(input.dtype(), false, false, [&]<typename Config, bool USE_FP4>() {
GEMM_W4A4_Launch<Config, false>::quantize_w4a4_act(
input, output, oscales
);
GEMM_W4A4_Launch<Config, false>::quantize_w4a4_act(input, output, oscales);
});
}
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
invoke_launch(input.dtype(), false, false, [&]<typename Config, bool USE_FP4>() {
GEMM_W4A4_Launch<Config, false>::quantize_w4a4_wgt(
input, output, oscales
);
GEMM_W4A4_Launch<Config, false>::quantize_w4a4_wgt(input, output, oscales);
});
}
@@ -143,7 +140,7 @@ bool FasterI2FMode::check(bool act_unsigned) {
if (prop->major != 7 || prop->minor != 5) {
return false;
}
if (mode == Always) {
return true;
} else if (mode == Enabled && !act_unsigned) {
@@ -162,4 +159,4 @@ void set_faster_i2f_mode(std::string mode) {
FasterI2FMode::mode = mapping.at(mode);
}
};
\ No newline at end of file
}; // namespace nunchaku::kernels
@@ -31,8 +31,7 @@ public:
static constexpr bool FP4_AVAILABLE = false;
#endif
__device__ __forceinline__
static void trap_no_fp4() {
__device__ __forceinline__ static void trap_no_fp4() {
if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
printf("FP4 is not available on this device\n");
}
@@ -44,12 +43,12 @@ public:
static_assert(WARP_N % 32 == 0);
static_assert(WARP_M % 32 == 0);
static constexpr int WMSCALES_PACK_SIZE = clamp(WARP_N / 32, 1, 4);
static constexpr int WMSCALES_NUM_PACKS = ceilDiv(WARP_N / 32, WMSCALES_PACK_SIZE);
static constexpr int WMSCALES_PACK_SIZE = clamp(WARP_N / 32, 1, 4);
static constexpr int WMSCALES_NUM_PACKS = ceilDiv(WARP_N / 32, WMSCALES_PACK_SIZE);
static constexpr int WMSCALES_VALID_LANES = WARP_SIZE;
static constexpr int AMSCALES_PACK_SIZE = clamp(WARP_M / 32, 1, 4);
static constexpr int AMSCALES_NUM_PACKS = ceilDiv(WARP_M / 32, AMSCALES_PACK_SIZE);
static constexpr int AMSCALES_PACK_SIZE = clamp(WARP_M / 32, 1, 4);
static constexpr int AMSCALES_NUM_PACKS = ceilDiv(WARP_M / 32, AMSCALES_PACK_SIZE);
static constexpr int AMSCALES_VALID_LANES = WARP_SIZE;
struct packed_wmscale_t {
@@ -62,48 +61,50 @@ public:
using wmscale_warp = std::array<packed_wmscale_t, WMSCALES_NUM_PACKS>;
// amscales: [M / BLOCK_M, K / group size, NUM_WARPS, AMSCALES_NUM_PACKS, WARP_SIZE] of packed_amscale_t
__device__ __forceinline__
static void load_amscale(const packed_amscale_t *ptr, int group, amscale_warp &out, bool pred) {
__device__ __forceinline__ static void
load_amscale(const packed_amscale_t *ptr, int group, amscale_warp &out, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
#pragma unroll
for (int i = 0; i < AMSCALES_NUM_PACKS; i++) {
out[i] = load_pred(&ptr[(group * NUM_WARPS + warpId) * AMSCALES_NUM_PACKS * AMSCALES_VALID_LANES + i * AMSCALES_VALID_LANES + laneId], pred);
out[i] = load_pred(&ptr[(group * NUM_WARPS + warpId) * AMSCALES_NUM_PACKS * AMSCALES_VALID_LANES +
i * AMSCALES_VALID_LANES + laneId],
pred);
}
}
// wmscales: [N / BLOCK_N, 1, K / group size, WMSCALES_NUM_PACKS, WMSCALES_VALID_LANES] of packed_wmscale_t
__device__ __forceinline__
static void load_wmscale(const packed_wmscale_t *ptr, int group, wmscale_warp &out, bool pred) {
__device__ __forceinline__ static void
load_wmscale(const packed_wmscale_t *ptr, int group, wmscale_warp &out, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
#pragma unroll
#pragma unroll
for (int i = 0; i < WMSCALES_NUM_PACKS; i++) {
out[i] = load_pred(&ptr[(group * WMSCALES_NUM_PACKS + i) * WMSCALES_VALID_LANES + laneId], pred);
}
}
__device__ __forceinline__
static void quantize_w4a4_fp4_from_fpsum_warp(const packed_fpsum_t (&fpsum)[INSN_K / INSN_N], packed_act_t &output, uint32_t &output_scale, int ida) {
__device__ __forceinline__ static void quantize_w4a4_fp4_from_fpsum_warp(
const packed_fpsum_t (&fpsum)[INSN_K / INSN_N], packed_act_t &output, uint32_t &output_scale, int ida) {
constexpr int NUM_GROUPS = 4;
static_assert(NUM_GROUPS == INSN_K / INSN_N);
constexpr float QVALUE_MAX = 6.0f;
constexpr float QVALUE_MAX = 6.0f;
constexpr float RECPI_QVALUE_MAX = 1 / QVALUE_MAX;
constexpr float MSCALE_MAX = 448.0f;
constexpr float MSCALE_MAX = 448.0f;
const int laneId = threadIdx.x % WARP_SIZE;
// 0 for row 0-7; 1 for row 8-15
// each half2_t represents a 8*8 matrix
half2_t input[2][INSN_K / INSN_N * 2];
#pragma unroll
#pragma unroll
for (int i = 0; i < INSN_K / INSN_N; i++) {
input[0][i * 2 + 0] = fpsum[i].data[0];
input[0][i * 2 + 1] = fpsum[i].data[2];
input[1][i * 2 + 0] = fpsum[i].data[1];
input[1][i * 2 + 1] = fpsum[i].data[3];
}
auto maxabs = [](half2_t val) ALWAYSINLINE {
val = __habs2(val);
return __hmax(val.x, val.y);
@@ -111,14 +112,14 @@ public:
// each half_t represents maxvalue in a 8*16 matrix
half_t maxvalue[2][NUM_GROUPS];
#pragma unroll
#pragma unroll
for (int i = 0; i < NUM_GROUPS; i++) {
maxvalue[0][i] = __hmax(maxabs(input[0][i * 2]), maxabs(input[0][i * 2 + 1]));
maxvalue[1][i] = __hmax(maxabs(input[1][i * 2]), maxabs(input[1][i * 2 + 1]));
}
#pragma unroll
#pragma unroll
for (int mask = 2; mask > 0; mask /= 2) {
#pragma unroll
#pragma unroll
for (int i = 0; i < NUM_GROUPS; i++) {
maxvalue[0][i] = __hmax(maxvalue[0][i], __shfl_xor_sync(~0, maxvalue[0][i], mask));
maxvalue[1][i] = __hmax(maxvalue[1][i], __shfl_xor_sync(~0, maxvalue[1][i], mask));
@@ -128,10 +129,10 @@ public:
float scale[2][NUM_GROUPS];
float rscale[2][NUM_GROUPS];
#pragma unroll
#pragma unroll
for (int i = 0; i < NUM_GROUPS; i++) {
scale[0][i] = fminf(float(maxvalue[0][i]) * RECPI_QVALUE_MAX, MSCALE_MAX);
scale[1][i] = fminf(float(maxvalue[1][i]) * RECPI_QVALUE_MAX, MSCALE_MAX);
scale[0][i] = fminf(float(maxvalue[0][i]) * RECPI_QVALUE_MAX, MSCALE_MAX);
scale[1][i] = fminf(float(maxvalue[1][i]) * RECPI_QVALUE_MAX, MSCALE_MAX);
// TODO: check whether (1 / scale) or (1 / fp8scale) is better
rscale[0][i] = cuda_frcp(scale[0][i]);
rscale[1][i] = cuda_frcp(scale[1][i]);
@@ -152,30 +153,29 @@ public:
if (laneId % 4 / 2 == ida) {
output_scale = (laneId % 2 == 0) ? fp8scale[0] : fp8scale[1];
}
uint32_t qpacks[2][INSN_K / INSN_M * 2];
#pragma unroll
#pragma unroll
for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
#pragma unroll
for (int j = 0; j < 2; j++) {
float2 fval = half22float2(input[j][i]) * make_float2(rscale[j][i / 2], rscale[j][i / 2]);
float2 fval = half22float2(input[j][i]) * make_float2(rscale[j][i / 2], rscale[j][i / 2]);
qpacks[j][i] = quantize_float2_fp4(fval) << (laneId % 4 * 8);
}
}
#pragma unroll
#pragma unroll
for (int mask = 1; mask <= 2; mask *= 2) {
#pragma unroll
#pragma unroll
for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
#pragma unroll
for (int j = 0; j < 2; j++) {
qpacks[j][i] |= __shfl_xor_sync(~0, qpacks[j][i], mask);
}
}
}
// lane 0,1,2,3 / 4,5,6,7 / ... should have identical qpacks now
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++) {
if (laneId % 4 == i) {
output.x = qpacks[0][0 + i];
@@ -188,88 +188,110 @@ public:
// m16n16k64 MMA
// ida, idb in {0, 1}
__device__ __forceinline__
static packed_f32psum_t mma_fp4(packed_act_t act, packed_wgt_t wgt, packed_f32psum_t psum, uint32_t amscale, uint32_t wmscale, int ida, int idb) {
__device__ __forceinline__ static packed_f32psum_t mma_fp4(packed_act_t act,
packed_wgt_t wgt,
packed_f32psum_t psum,
uint32_t amscale,
uint32_t wmscale,
int ida,
int idb) {
packed_f32psum_t out;
asm volatile (
asm volatile(
"mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%10, %11, %12, %13}, "
"{%14}, {%15, %16}, "
"{%17}, {%18, %19};"
:
"=f"(out.data[0]), "=f"(out.data[1]), "=f"(out.data[2]), "=f"(out.data[3])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.x), "r"(wgt.y),
"f"(psum.data[0]), "f"(psum.data[1]), "f"(psum.data[2]), "f"(psum.data[3]),
"r"(amscale), "n"(0), "h"((short)ida),
"r"(wmscale), "n"(0), "h"((short)(idb * 2))
);
asm volatile (
"{%17}, {%18, %19};"
: "=f"(out.data[0]), "=f"(out.data[1]), "=f"(out.data[2]), "=f"(out.data[3])
: "r"(act.x),
"r"(act.y),
"r"(act.z),
"r"(act.w),
"r"(wgt.x),
"r"(wgt.y),
"f"(psum.data[0]),
"f"(psum.data[1]),
"f"(psum.data[2]),
"f"(psum.data[3]),
"r"(amscale),
"n"(0),
"h"((short)ida),
"r"(wmscale),
"n"(0),
"h"((short)(idb * 2)));
asm volatile(
"mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%10, %11, %12, %13}, "
"{%14}, {%15, %16}, "
"{%17}, {%18, %19};"
:
"=f"(out.data[4]), "=f"(out.data[5]), "=f"(out.data[6]), "=f"(out.data[7])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.z), "r"(wgt.w),
"f"(psum.data[4]), "f"(psum.data[5]), "f"(psum.data[6]), "f"(psum.data[7]),
"r"(amscale), "n"(0), "h"((short)ida),
"r"(wmscale), "n"(0), "h"((short)(idb * 2 + 1))
);
"{%17}, {%18, %19};"
: "=f"(out.data[4]), "=f"(out.data[5]), "=f"(out.data[6]), "=f"(out.data[7])
: "r"(act.x),
"r"(act.y),
"r"(act.z),
"r"(act.w),
"r"(wgt.z),
"r"(wgt.w),
"f"(psum.data[4]),
"f"(psum.data[5]),
"f"(psum.data[6]),
"f"(psum.data[7]),
"r"(amscale),
"n"(0),
"h"((short)ida),
"r"(wmscale),
"n"(0),
"h"((short)(idb * 2 + 1)));
return out;
}
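// Shape note for mma_fp4: the target tile is 16x16x64, but the block-scaled mxf4nvf4
// MMA is m16n8k64, so the instruction is issued twice -- once with {wgt.x, wgt.y}
// accumulating into psum.data[0..3] (left n8 half, scale selector idb*2) and once with
// {wgt.z, wgt.w} into psum.data[4..7] (right n8 half, selector idb*2+1). The ue4m3
// operands amscale/wmscale supply the per-16-element block scales required by
// scale_vec::4X.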
__device__ __forceinline__
static void compute_fp4(act_warp A, wgt_warp W, amscale_warp amscale, wmscale_warp wmscale, f32psum_warp &psum) {
__device__ __forceinline__ static void
compute_fp4(act_warp A, wgt_warp W, amscale_warp amscale, wmscale_warp wmscale, f32psum_warp &psum) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
psum[i * WARP_N_TILES + j] = mma_fp4(
A[i], W[j], psum[i * WARP_N_TILES + j],
amscale[i / 2 / AMSCALES_PACK_SIZE].data[i / 2 % AMSCALES_PACK_SIZE],
wmscale[j / 2 / WMSCALES_PACK_SIZE].data[j / 2 % WMSCALES_PACK_SIZE],
i % 2, j % 2
);
psum[i * WARP_N_TILES + j] =
mma_fp4(A[i],
W[j],
psum[i * WARP_N_TILES + j],
amscale[i / 2 / AMSCALES_PACK_SIZE].data[i / 2 % AMSCALES_PACK_SIZE],
wmscale[j / 2 / WMSCALES_PACK_SIZE].data[j / 2 % WMSCALES_PACK_SIZE],
i % 2,
j % 2);
}
}
}
template<typename Epilogue, bool USE_ALPHA>
__device__ __forceinline__
static void gemm_w4a4_fp4_block(
const BlockInfo binfo,
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_amscale_t *ascales,
const packed_wmscale_t *wscales,
float alpha, // per-tensor scale of weight
int M, int N, int K,
const Epilogue::Arguments &epilogueArgs,
bool alwaysfalse)
{
__device__ __forceinline__ static void gemm_w4a4_fp4_block(const BlockInfo binfo,
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_amscale_t *ascales,
const packed_wmscale_t *wscales,
float alpha, // per-tensor scale of weight
int M,
int N,
int K,
const Epilogue::Arguments &epilogueArgs,
bool alwaysfalse) {
constexpr int NUM_STAGES = 2;
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
act_warp A[NUM_STAGES]; // 8 * 2
wgt_warp W[NUM_STAGES]; // 32 * 2
amscale_warp amscale[NUM_STAGES]; // 1 * 2
wmscale_warp wmscale[NUM_STAGES]; // 4 * 2
f32psum_warp fpsum; // 128
act_warp A[NUM_STAGES]; // 8 * 2
wgt_warp W[NUM_STAGES]; // 32 * 2
amscale_warp amscale[NUM_STAGES]; // 1 * 2
wmscale_warp wmscale[NUM_STAGES]; // 4 * 2
f32psum_warp fpsum; // 128
for (int k = 0; k < NUM_STAGES - 1; k++) {
load_act(act, k, K, A[k], true);
@@ -278,21 +300,21 @@ public:
load_wmscale(wscales, k, wmscale[k], true);
}
#pragma unroll
#pragma unroll
for (auto &pack : fpsum) {
#pragma unroll
#pragma unroll
for (int i = 0; i < 8; i++) {
pack.data[i] = 0;
}
}
int dummy = 0;
for (int k1 = 0; k1 < K / WARP_K; k1 += NUM_STAGES) {
#pragma unroll
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < K / WARP_K;
load_act(act, nextk, K, A[idx], pred);
load_wgt(wgt, nextk, K, W[idx], pred);
@@ -317,15 +339,14 @@ public:
unused_var(dummy, alwaysfalse);
if constexpr (USE_ALPHA) {
#pragma unroll
#pragma unroll
for (auto &pack : fpsum) {
#pragma unroll
#pragma unroll
for (int i = 0; i < 8; i++) {
pack.data[i] *= alpha;
}
}
}
auto f16psum = packed_fp32_to_fp16(fpsum);
@@ -337,21 +358,20 @@ public:
template<typename Epilogue, bool USE_ALPHA>
struct gemm_w4a4_fp4_kernel {
static constexpr int MIN_ARCH = 1200;
__device__
void operator()(
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_amscale_t *ascales,
const packed_wmscale_t *wscales,
float alpha,
int M, int N, int K,
Epilogue::Arguments epilogueArgs,
bool swapBlockXY,
bool alwaysfalse)
{
__device__ void operator()(const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_amscale_t *ascales,
const packed_wmscale_t *wscales,
float alpha,
int M,
int N,
int K,
Epilogue::Arguments epilogueArgs,
bool swapBlockXY,
bool alwaysfalse) {
BlockInfo binfo = {
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y,
};
@@ -372,27 +392,27 @@ public:
ascales + bm * (K / WARP_K) * NUM_WARPS * AMSCALES_NUM_PACKS * AMSCALES_VALID_LANES,
wscales + bn * (K / WARP_K) * WMSCALES_NUM_PACKS * WMSCALES_VALID_LANES,
alpha,
M, N, K,
M,
N,
K,
epilogueArgs,
alwaysfalse
);
alwaysfalse);
} else {
trap_no_fp4();
}
}
};
public:
template<bool ACT_UNSIGNED>
__device__ __forceinline__
static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt) {
__device__ __forceinline__ static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt) {
packed_psum_t psum;
uint4 out1 = mma_m16n8kx_s32common<mma_helper::s4u4<ACT_UNSIGNED>, mma_helper::s4>(act, uint2(wgt.x, wgt.y), uint4(0, 0, 0, 0));
uint4 out2 = mma_m16n8kx_s32common<mma_helper::s4u4<ACT_UNSIGNED>, mma_helper::s4>(act, uint2(wgt.z, wgt.w), uint4(0, 0, 0, 0));
uint4 out1 = mma_m16n8kx_s32common<mma_helper::s4u4<ACT_UNSIGNED>, mma_helper::s4>(
act, uint2(wgt.x, wgt.y), uint4(0, 0, 0, 0));
uint4 out2 = mma_m16n8kx_s32common<mma_helper::s4u4<ACT_UNSIGNED>, mma_helper::s4>(
act, uint2(wgt.z, wgt.w), uint4(0, 0, 0, 0));
psum.data[0] = out1.x;
psum.data[1] = out1.y;
psum.data[2] = out1.z;
@@ -401,29 +421,30 @@ public:
psum.data[5] = out2.y;
psum.data[6] = out2.z;
psum.data[7] = out2.w;
return psum;
}
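// Same two-call pattern as mma_fp4 above, but on the integer path: each
// mma_m16n8kx_s32common call is one s4/u4 tensor-core MMA with zero-initialized
// accumulators, and the two calls fill the left ({wgt.x, wgt.y}) and right
// ({wgt.z, wgt.w}) n8 halves of the 16x16 int32 accumulator tile.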
// template<bool si>
template<bool use_unsigned>
__device__ __forceinline__
static void quantize_w4a4_from_fpsum_warp(const packed_fpsum_t (&fpsum)[INSN_K / INSN_N], packed_act_t &output, half_t *output_scale) {
__device__ __forceinline__ static void quantize_w4a4_from_fpsum_warp(const packed_fpsum_t (&fpsum)[INSN_K / INSN_N],
packed_act_t &output,
half_t *output_scale) {
const int laneId = threadIdx.x % WARP_SIZE;
constexpr float QVALUE_MAX_SIGNED = 7.0f;
constexpr float QVALUE_MAX_UNSIGNED = 15.0f;
constexpr float RECPI_QVALUE_MAX_SIGNED = 1 / QVALUE_MAX_SIGNED;
constexpr float QVALUE_MAX_SIGNED = 7.0f;
constexpr float QVALUE_MAX_UNSIGNED = 15.0f;
constexpr float RECPI_QVALUE_MAX_SIGNED = 1 / QVALUE_MAX_SIGNED;
constexpr float RECPI_QVALUE_MAX_UNSIGNED = 1 / QVALUE_MAX_UNSIGNED;
constexpr float QVALUE_MAX = use_unsigned ? QVALUE_MAX_UNSIGNED : QVALUE_MAX_SIGNED;
constexpr float QVALUE_MAX = use_unsigned ? QVALUE_MAX_UNSIGNED : QVALUE_MAX_SIGNED;
constexpr float RECPI_QVALUE_MAX = use_unsigned ? RECPI_QVALUE_MAX_UNSIGNED : RECPI_QVALUE_MAX_SIGNED;
// constexpr int QUANTIZE_BITMASK = 0xf;
// 0 for row 0-7; 1 for row 8-15
half2_t input[2][INSN_K / INSN_N * 2];
#pragma unroll
#pragma unroll
for (int i = 0; i < INSN_K / INSN_N; i++) {
input[0][i * 2 + 0] = fpsum[i].data[0];
input[0][i * 2 + 1] = fpsum[i].data[2];
@@ -434,14 +455,14 @@ public:
half_t maxvalue[2];
maxvalue[0] = 0;
maxvalue[1] = 0;
#pragma unroll
#pragma unroll
for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
half2_t abs0 = __habs2(input[0][i]);
half2_t abs1 = __habs2(input[1][i]);
maxvalue[0] = __hmax(maxvalue[0], __hmax(abs0.x, abs0.y));
maxvalue[1] = __hmax(maxvalue[1], __hmax(abs1.x, abs1.y));
maxvalue[0] = __hmax(maxvalue[0], __hmax(abs0.x, abs0.y));
maxvalue[1] = __hmax(maxvalue[1], __hmax(abs1.x, abs1.y));
}
#pragma unroll
#pragma unroll
for (int mask = 2; mask > 0; mask /= 2) {
maxvalue[0] = __hmax(maxvalue[0], __shfl_xor_sync(~0, maxvalue[0], mask));
maxvalue[1] = __hmax(maxvalue[1], __shfl_xor_sync(~0, maxvalue[1], mask));
@@ -455,7 +476,7 @@ public:
scale[0] = float(maxvalue[0]) * RECPI_QVALUE_MAX;
scale[1] = float(maxvalue[1]) * RECPI_QVALUE_MAX;
if (laneId % 4 == 0) {
output_scale[laneId / 4] = half_t(scale[0]);
output_scale[laneId / 4] = half_t(scale[0]);
output_scale[laneId / 4 + 8] = half_t(scale[1]);
}
@@ -466,23 +487,23 @@ public:
rscale[1] = cuda_frcp(scale[1]);
uint32_t qpacks[2][INSN_K / INSN_M * 2];
#pragma unroll
#pragma unroll
for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
#pragma unroll
for (int j = 0; j < 2; j++) {
// half2_t hval = __hmul2(input[j][i], half2_t(rscale[j], rscale[j]));
// float2 fval = half22float2(hval);
float2 fval = half22float2(input[j][i]) * make_float2(rscale[j], rscale[j]);
float2 fval = half22float2(input[j][i]) * make_float2(rscale[j], rscale[j]);
qpacks[j][i] = quantize_float2<4, use_unsigned>(fval) << (laneId % 4 * 8);
}
}
// 2 * 8 * 2 = 32 instructions => 256 cycles
#pragma unroll
#pragma unroll
for (int mask = 1; mask <= 2; mask *= 2) {
#pragma unroll
#pragma unroll
for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
#pragma unroll
for (int j = 0; j < 2; j++) {
qpacks[j][i] |= __shfl_xor_sync(~0, qpacks[j][i], mask);
}
@@ -490,7 +511,7 @@ public:
}
// lane 0,1,2,3 / 4,5,6,7 / ... should have identical qpacks now
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++) {
if (laneId % 4 == i) {
output.x = qpacks[0][0 + i];
@@ -501,73 +522,74 @@ public:
}
}
/**
* each warp quantizes a INSN_M * INSN_K (16 * 64) matrix
* input is per-warp (in global memory)
* output is per-thread (in regs)
* output_scale is per-warp (in shared memory)
* shmem must be at least INSN_M * INSN_K * sizeof(element) (16 * 64 * 0.5 = 512 Bytes)
* default to quantize activation, if quantize weight, input should be column-majored and output should be transposed ({x, y, z, w} = {x, z, y, w})
* default to quantize activation, if quantize weight, input should be column-majored and output should be
* transposed ({x, y, z, w} = {x, z, y, w})
*/
__device__ __forceinline__
static void quantize_w4a4_warp(const half_t *input, int stride, packed_act_t &output, half_t *output_scale, void *shmem) {
__device__ __forceinline__ static void
quantize_w4a4_warp(const half_t *input, int stride, packed_act_t &output, half_t *output_scale, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
constexpr int QUANTIZE_BITWIDTH = 4;
constexpr int QVALUE_MAX = 7; // 4 bit => [-8, 7]
constexpr int QVALUE_MAX = 7; // 4 bit => [-8, 7]
// 1 lane = 1 pack
// 1 warp = 32 lanes = 32 packs = 1 packwarp
// a pack is {a0, ..., a7} in figure https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a
// PACK_SIZE * 4 = INSN_K / 2
constexpr int PACK_SIZE = INSN_K / 8; // = 8 for 4bit
constexpr int NUM_PACKS_PER_ROW = INSN_K / PACK_SIZE;
// a pack is {a0, ..., a7} in figure
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a PACK_SIZE * 4 =
// INSN_K / 2
constexpr int PACK_SIZE = INSN_K / 8; // = 8 for 4bit
constexpr int NUM_PACKS_PER_ROW = INSN_K / PACK_SIZE;
constexpr int NUM_ROWS_PER_PACKWARP = PACK_SIZE * WARP_SIZE / INSN_K;
constexpr int NUM_PACKWARPS = INSN_M / NUM_ROWS_PER_PACKWARP;
using packed_input = std::array<half_t, PACK_SIZE>;
constexpr int NUM_PACKWARPS = INSN_M / NUM_ROWS_PER_PACKWARP;
using packed_input = std::array<half_t, PACK_SIZE>;
packed_input packs[NUM_PACKWARPS];
// load
#pragma unroll
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
int rowId = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW;
int colId = laneId % NUM_PACKS_PER_ROW * PACK_SIZE;
packs[i] = load(reinterpret_cast<const packed_input *>(input + rowId * stride + colId));
packs[i] = load(reinterpret_cast<const packed_input *>(input + rowId * stride + colId));
}
// find max
half_t maxvalue[NUM_PACKWARPS];
#pragma unroll
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
maxvalue[i] = __habs(packs[i][0]);
#pragma unroll
maxvalue[i] = __habs(packs[i][0]);
#pragma unroll
for (int j = 1; j < PACK_SIZE; j++) {
maxvalue[i] = __hmax(maxvalue[i], __habs(packs[i][j]));
}
}
// warp reduce (max)
#pragma unroll
#pragma unroll
for (int mask = NUM_PACKS_PER_ROW / 2; mask > 0; mask /= 2) {
#pragma unroll
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
maxvalue[i] = __hmax(maxvalue[i], __shfl_xor_sync(~0, maxvalue[i], mask));
}
}
// broadcast (max)
#pragma unroll
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
maxvalue[i] = __shfl_sync(~0, maxvalue[i], laneId / NUM_PACKS_PER_ROW * NUM_PACKS_PER_ROW);
}
// quantize
using matrix_t = uint32_t[INSN_M][NUM_PACKS_PER_ROW];
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
#pragma unroll
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
half_t scale = maxvalue[i] / half_t(QVALUE_MAX);
half_t rscale = half_t(QVALUE_MAX) / maxvalue[i];
@@ -576,13 +598,13 @@ public:
}
uint32_t qpack = 0;
// #pragma unroll
// for (int j = 0; j < PACK_SIZE; j++) {
// int intvalue = __half2int_rn(packs[i][j] / scale);
// intvalue = clamp(intvalue, -QVALUE_MAX, QVALUE_MAX);
// qpack |= (intvalue & QUANTIZE_BITMASK) << (QUANTIZE_BITWIDTH * j);
// }
#pragma unroll
// #pragma unroll
// for (int j = 0; j < PACK_SIZE; j++) {
// int intvalue = __half2int_rn(packs[i][j] / scale);
// intvalue = clamp(intvalue, -QVALUE_MAX, QVALUE_MAX);
// qpack |= (intvalue & QUANTIZE_BITMASK) << (QUANTIZE_BITWIDTH * j);
// }
#pragma unroll
for (int j = 0; j < PACK_SIZE; j += 2) {
half2_t hval = __hmul2(half2_t(rscale, rscale), half2_t(packs[i][j], packs[i][j + 1]));
qpack |= quantize_float2<QUANTIZE_BITWIDTH, false>(half22float2(hval)) << (j * QUANTIZE_BITWIDTH);
@@ -590,7 +612,7 @@ public:
mat[i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW][laneId % NUM_PACKS_PER_ROW] = qpack;
}
__syncwarp();
// convert to imma format
int row = laneId % 16;
int col = laneId / 16 * 4;
@@ -602,12 +624,11 @@ public:
// each thread block (1 warp) quantize WARP_M * WARP_K tile (32 * 64)
struct quantize_w4a4_act_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
__device__
void operator()(const half_t *input, packed_act_t *output, packed_ascale_t *oscales, int K) {
__device__ void operator()(const half_t *input, packed_act_t *output, packed_ascale_t *oscales, int K) {
const int laneId = threadIdx.x % WARP_SIZE;
const int bm = blockIdx.x / (BLOCK_M / WARP_M);
const int bk = blockIdx.y;
const int bm = blockIdx.x / (BLOCK_M / WARP_M);
const int bk = blockIdx.y;
const int warpId = blockIdx.x % (BLOCK_M / WARP_M);
const int row = blockIdx.x * WARP_M;
@@ -620,28 +641,27 @@ public:
packed_act_t tmpout;
quantize_w4a4_warp(
input + (row + tileId * INSN_M) * K + col,
K,
tmpout,
oscale_shmem + tileId * INSN_M,
tmp_shmem
);
input + (row + tileId * INSN_M) * K + col, K, tmpout, oscale_shmem + tileId * INSN_M, tmp_shmem);
store(&output[(((bm * K / WARP_K + bk) * NUM_WARPS + warpId) * WARP_M_TILES + tileId) * WARP_SIZE + laneId], tmpout);
store(&output[(((bm * K / WARP_K + bk) * NUM_WARPS + warpId) * WARP_M_TILES + tileId) * WARP_SIZE +
laneId],
tmpout);
}
// if (threadIdx.x == 0) {
// printf("Block (%d, %d) => offset = %d\n", blockIdx.x, blockIdx.y, (bm * K / WARP_K + bk) * NUM_WARPS + warpId);
// printf("Block (%d, %d) => offset = %d\n", blockIdx.x, blockIdx.y, (bm * K / WARP_K + bk) * NUM_WARPS
// + warpId);
// }
pack_ascales(oscale_shmem, &oscales[((bm * K / WARP_K + bk) * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
pack_ascales(
oscale_shmem,
&oscales[((bm * K / WARP_K + bk) * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
}
};
// each thread block (1 warp) quantize WARP_N * WARP_K tile (128 * 64)
struct quantize_w4a4_wgt_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
__device__
void operator()(const half_t *input, packed_wgt_t *output, packed_wscale_t *oscales, int K) {
__device__ void operator()(const half_t *input, packed_wgt_t *output, packed_wscale_t *oscales, int K) {
const int laneId = threadIdx.x % WARP_SIZE;
const int bn = blockIdx.x / (BLOCK_N / WARP_N);
@@ -657,12 +677,7 @@ public:
packed_wgt_t tmpout;
quantize_w4a4_warp(
input + (col + tileId * INSN_N) * K + row,
K,
tmpout,
oscale_shmem + tileId * INSN_N,
tmp_shmem
);
input + (col + tileId * INSN_N) * K + row, K, tmpout, oscale_shmem + tileId * INSN_N, tmp_shmem);
std::swap(tmpout.y, tmpout.z);
@@ -674,59 +689,52 @@ public:
};
struct i2f_sm80 {
__device__ __forceinline__
static float2 int2float2(int x, int y) {
__device__ __forceinline__ static float2 int2float2(int x, int y) {
return make_float2(int2float_fast(x), int2float_fast(y));
}
__device__ __forceinline__
static half2_t int2half2(int x, int y) {
__device__ __forceinline__ static half2_t int2half2(int x, int y) {
return float22half2<half2_t>(int2float2(x, y));
}
};
struct i2f_sm75 {
__device__ __forceinline__
static float2 int2float2(int x, int y) {
__device__ __forceinline__ static float2 int2float2(int x, int y) {
return make_float2(int2float_fast(x), int2float_fast(y));
}
__device__ __forceinline__
static half2_t int2half2(int x, int y) {
__device__ __forceinline__ static half2_t int2half2(int x, int y) {
return half2(__int2half_rn(x), __int2half_rn(y));
}
};
struct i2f_sm75_fast {
__device__ __forceinline__
static float2 int2float2(int x, int y) {
__device__ __forceinline__ static float2 int2float2(int x, int y) {
return make_float2(int2float_fast(x), int2float_fast(y));
}
__device__ __forceinline__
static half2_t int2half2(int x, int y) {
__device__ __forceinline__ static half2_t int2half2(int x, int y) {
return int2half2_fast_512(x, y);
}
};
template<bool ACT_UNSIGNED, typename T>
__device__ __forceinline__
static void compute(act_warp A, wgt_warp W, ascale_warp ascale, wscale_warp wscale, T &fpsum) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800
__device__ __forceinline__ static void
compute(act_warp A, wgt_warp W, ascale_warp ascale, wscale_warp wscale, T &fpsum) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800
using int2half2 = i2f_sm80;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
using int2half2 = std::conditional_t<Config::FASTER_I2F, i2f_sm75_fast, i2f_sm75>;;
#else
using int2half2 = Base::i2f_normal;
#endif
Base::template apply_scales<int2half2>([&](int i, int j) {
return mma<ACT_UNSIGNED>(A[i], W[j]);
}, ascale, wscale, fpsum);
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
using int2half2 = std::conditional_t<Config::FASTER_I2F, i2f_sm75_fast, i2f_sm75>;
;
#else
using int2half2 = Base::i2f_normal;
#endif
Base::template apply_scales<int2half2>(
[&](int i, int j) { return mma<ACT_UNSIGNED>(A[i], W[j]); }, ascale, wscale, fpsum);
}
__device__ __forceinline__
static void checkNan(fpsum_warp fpsum, const char *info = "") {
__device__ __forceinline__ static void checkNan(fpsum_warp fpsum, const char *info = "") {
#if ENABLE_NAN_CHECK
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
@@ -735,14 +743,17 @@ public:
for (int j = 0; j < 4; j++) {
bool abnormal = !isfinite((float)fpsum[i].data[j].x) || !isfinite((float)fpsum[i].data[j].y);
if (abnormal) {
printf("abnormal value detected at block.x=%d block.y=%d warpId=%d laneId=%d fpsum_warp (%s) i=%d j=%d data.x=%f data.y=%f\n",
blockIdx.x, blockIdx.y,
warpId, laneId,
info,
i, j,
(float)fpsum[i].data[j].x,
(float)fpsum[i].data[j].y
);
printf("abnormal value detected at block.x=%d block.y=%d warpId=%d laneId=%d fpsum_warp (%s) i=%d "
"j=%d data.x=%f data.y=%f\n",
blockIdx.x,
blockIdx.y,
warpId,
laneId,
info,
i,
j,
(float)fpsum[i].data[j].x,
(float)fpsum[i].data[j].y);
__trap();
}
}
@@ -750,8 +761,7 @@ public:
#endif
}
__device__ __forceinline__
static void checkNan(packed_f32psum_t fpsum, const char *info = "") {
__device__ __forceinline__ static void checkNan(packed_f32psum_t fpsum, const char *info = "") {
#if ENABLE_NAN_CHECK
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
@@ -759,21 +769,22 @@ public:
for (int j = 0; j < 8; j++) {
bool abnormal = !isfinite(fpsum.data[j]);
if (abnormal) {
printf("abnormal value detected at bm=%d bn=%d warpId=%d laneId=%d packed_f32psum_t (%s) j=%d data=%f\n",
blockIdx.x, blockIdx.y,
warpId, laneId,
printf(
"abnormal value detected at bm=%d bn=%d warpId=%d laneId=%d packed_f32psum_t (%s) j=%d data=%f\n",
blockIdx.x,
blockIdx.y,
warpId,
laneId,
info,
j,
fpsum.data[j]
);
fpsum.data[j]);
__trap();
}
}
#endif
}
__device__ __forceinline__
static void checkNan(packed_fpsum_t fpsum, const char *info = "") {
__device__ __forceinline__ static void checkNan(packed_fpsum_t fpsum, const char *info = "") {
#if ENABLE_NAN_CHECK
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
@@ -781,34 +792,36 @@ public:
for (int j = 0; j < 4; j++) {
bool abnormal = !isfinite((float)fpsum.data[j].x) || !isfinite((float)fpsum.data[j].y);
if (abnormal) {
printf("abnormal value detected at bm=%d bn=%d warpId=%d laneId=%d packed_fpsum_t (%s) j=%d data.x=%f data.y=%f\n",
blockIdx.x, blockIdx.y,
warpId, laneId,
info,
j,
(float)fpsum.data[j].x,
(float)fpsum.data[j].y
);
printf("abnormal value detected at bm=%d bn=%d warpId=%d laneId=%d packed_fpsum_t (%s) j=%d data.x=%f "
"data.y=%f\n",
blockIdx.x,
blockIdx.y,
warpId,
laneId,
info,
j,
(float)fpsum.data[j].x,
(float)fpsum.data[j].y);
__trap();
}
}
#endif
}
__device__ __forceinline__
static void checkNan(float data, const char *info = "") {
__device__ __forceinline__ static void checkNan(float data, const char *info = "") {
#if ENABLE_NAN_CHECK
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
bool abnormal = !isfinite(data);
if (abnormal) {
printf("abnormal value detected at bm=%d bn=%d warpId=%d laneId=%d packed_fpsum_t (%s) data=%f\n",
blockIdx.x, blockIdx.y,
warpId, laneId,
info,
data
);
printf("abnormal value detected at bm=%d bn=%d warpId=%d laneId=%d packed_fpsum_t (%s) data=%f\n",
blockIdx.x,
blockIdx.y,
warpId,
laneId,
info,
data);
__trap();
}
#endif
@@ -816,19 +829,18 @@ public:
// out: [M / BLOCK_M, N / BLOCK_N, NUM_WARPS, 1, NUM_M_TILES, NUM_N_TILES, WARP_SIZE] of fpsum_warp
template<typename Epilogue, bool ACT_UNSIGNED, bool USE_FP32_ACCUM>
__device__ __forceinline__
static void gemm_w4a4_block(
const BlockInfo binfo,
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
// const packed_wscale_t *bias_ptr,
// half_t *out,
int M, int N, int K,
const Epilogue::Arguments &epilogueArgs,
bool alwaysfalse)
{
__device__ __forceinline__ static void gemm_w4a4_block(const BlockInfo binfo,
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
// const packed_wscale_t *bias_ptr,
// half_t *out,
int M,
int N,
int K,
const Epilogue::Arguments &epilogueArgs,
bool alwaysfalse) {
constexpr int NUM_STAGES = 2;
const int laneId = threadIdx.x % WARP_SIZE;
@@ -838,11 +850,11 @@ public:
fpsum_warp fpsum;
GEMM_W4A4_Block<Config>()(act, wgt, ascales, wscales, K, fpsum, alwaysfalse);
#else
act_warp A[NUM_STAGES]; // 8
wgt_warp W[NUM_STAGES]; // 32
ascale_warp ascale[NUM_STAGES]; // 1
wscale_warp wscale[NUM_STAGES]; // 2
std::conditional_t<USE_FP32_ACCUM, f32psum_warp, fpsum_warp> fpsum; // 64
act_warp A[NUM_STAGES]; // 8
wgt_warp W[NUM_STAGES]; // 32
ascale_warp ascale[NUM_STAGES]; // 1
wscale_warp wscale[NUM_STAGES]; // 2
std::conditional_t<USE_FP32_ACCUM, f32psum_warp, fpsum_warp> fpsum; // 64
// load_wscale<true>(wscales, wscale[0], true);
// load_wscale<false>(wscales, wscale[1], true);
......@@ -867,14 +879,14 @@ public:
}
}
}
int dummy = 0;
for (int k1 = 0; k1 < K / WARP_K; k1 += NUM_STAGES) {
#pragma unroll
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < K / WARP_K;
load_act(act, nextk, K, A[idx], pred);
load_wgt(wgt, nextk, K, W[idx], pred);
......@@ -889,11 +901,11 @@ public:
compute<ACT_UNSIGNED>(A[k2], W[k2], ascale[k2], wscale[k2], fpsum);
//#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
if (alwaysfalse) {
dummy = clock();
}
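// This branch is never taken at runtime (alwaysfalse); the fake dependency on clock() appears
// to be a workaround that keeps the compiler from reordering or sinking the loads above and
// destroying the pipelining.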
//#endif
// #endif
// asm volatile ("membar.cta;");
}
......@@ -927,11 +939,17 @@ public:
const packed_wscale_t *smooth_factor;
};
static constexpr int NUM_PACKS = INSN_K / INSN_N;
static constexpr int NUM_PACKS = INSN_K / INSN_N;
static constexpr int NUM_GROUPS = WARP_N_TILES / NUM_PACKS;
__device__ __forceinline__
void apply_quantize(fpsum_warp fpsum, int M, int N, int K, packed_act_t *qout, oscales_t *oscales, half_t shift_value, const packed_wscale_t *smooth_factor) {
__device__ __forceinline__ void apply_quantize(fpsum_warp fpsum,
int M,
int N,
int K,
packed_act_t *qout,
oscales_t *oscales,
half_t shift_value,
const packed_wscale_t *smooth_factor) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
......@@ -940,21 +958,21 @@ public:
wscale_warp smooth;
load_wscale(smooth_factor, 0, N, smooth, true);
#pragma unroll
#pragma unroll
for (int group = 0; group < NUM_GROUPS; group++) {
amscale_warp omscale;
#pragma unroll
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
packed_fpsum_t tmp[NUM_PACKS];
#pragma unroll
#pragma unroll
for (int j = 0; j < NUM_PACKS; j++) {
half2_t ws1 = broadcast_wscale(smooth, (group * NUM_PACKS + j) * 4, laneId);
half2_t ws2 = broadcast_wscale(smooth, (group * NUM_PACKS + j) * 4 + 2, laneId);
#pragma unroll
#pragma unroll
for (int k = 0; k < 4; k++) {
half2_t src = fpsum[i * WARP_N_TILES + group * NUM_PACKS + j].data[k];
half2_t src = fpsum[i * WARP_N_TILES + group * NUM_PACKS + j].data[k];
half2_t &dst = tmp[j].data[k];
// dst.x = gelu(src.x);
......@@ -977,7 +995,8 @@ public:
packed_act_t qresult;
if constexpr (USE_FP4) {
quantize_w4a4_fp4_from_fpsum_warp(tmp, qresult, omscale[i / 2 / AMSCALES_PACK_SIZE].data[i / 2 % AMSCALES_PACK_SIZE], i % 2);
quantize_w4a4_fp4_from_fpsum_warp(
tmp, qresult, omscale[i / 2 / AMSCALES_PACK_SIZE].data[i / 2 % AMSCALES_PACK_SIZE], i % 2);
} else {
quantize_w4a4_from_fpsum_warp<USE_UNSIGNED>(tmp, qresult, &oscale_shmem[warpId][i * INSN_M]);
}
......@@ -985,34 +1004,38 @@ public:
}
if constexpr (USE_FP4) {
#pragma unroll
#pragma unroll
for (int k = 0; k < AMSCALES_NUM_PACKS; k++) {
store(&oscales[((group * NUM_WARPS + warpId) * AMSCALES_NUM_PACKS + k) * AMSCALES_VALID_LANES + laneId], omscale[k]);
store(&oscales[((group * NUM_WARPS + warpId) * AMSCALES_NUM_PACKS + k) * AMSCALES_VALID_LANES +
laneId],
omscale[k]);
}
}
if constexpr (!USE_FP4) {
__syncwarp();
pack_ascales(&oscale_shmem[warpId][0], &oscales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
pack_ascales(&oscale_shmem[warpId][0],
&oscales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
__syncwarp();
}
}
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
__device__ __forceinline__ void
operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
if constexpr (!USE_FP4 || FP4_AVAILABLE) {
apply_quantize(
fpsum, M, N, K,
args.qout + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
args.oscales + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS *
(USE_FP4 ? AMSCALES_NUM_PACKS * AMSCALES_VALID_LANES : ASCALES_NUM_PACKS * ASCALES_VALID_LANES),
args.shift_value,
args.smooth_factor + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES
);
apply_quantize(fpsum,
M,
N,
K,
args.qout + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
args.oscales + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS *
(USE_FP4 ? AMSCALES_NUM_PACKS * AMSCALES_VALID_LANES
: ASCALES_NUM_PACKS * ASCALES_VALID_LANES),
args.shift_value,
args.smooth_factor + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES);
} else {
trap_no_fp4();
}
......@@ -1025,22 +1048,21 @@ public:
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
static constexpr int MAX_ARCH = Config::FASTER_I2F ? 750 : INT_MAX; // FASTER_I2F is only needed on sm_75
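// i.e. the bf16 configuration requires sm_80+, while fp16 also runs on Turing (sm_75); MAX_ARCH
// additionally restricts the FASTER_I2F variant to sm_75, the only arch where it is needed.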
__device__
void operator()(
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
int M, int N, int K,
Epilogue::Arguments epilogueArgs,
bool swapBlockXY,
bool alwaysfalse)
{
__device__ void operator()(const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
int M,
int N,
int K,
Epilogue::Arguments epilogueArgs,
bool swapBlockXY,
bool alwaysfalse) {
// printf("Device sizeof(args) = %d", (int)sizeof(epilogueArgs));
BlockInfo binfo = {
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y,
};
......@@ -1064,20 +1086,21 @@ public:
// bias ? bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES : nullptr,
// out + (bm * BLOCK_M * N) + bn * BLOCK_N,
// out + (bm * N / BLOCK_N + bn) * NUM_WARPS * WARP_M_TILES * WARP_N_TILES * WARP_SIZE,
M, N, K,
M,
N,
K,
epilogueArgs,
alwaysfalse
);
alwaysfalse);
}
};
template<bool fuse_glu, bool use_fp4>
struct quantize_w4a4_fuse_lora_kernel {
using oscales_t = typename std::conditional_t<use_fp4, packed_amscale_t, packed_ascale_t>;
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(Base::template load_act_to_fpsum<fuse_glu>::SHMEM_SIZE, 128) * 128;
static constexpr size_t SHMEM_PER_WARP =
ceilDiv<size_t>(Base::template load_act_to_fpsum<fuse_glu>::SHMEM_SIZE, 128) * 128;
static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;
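// Rounding SHMEM_PER_WARP up to a multiple of 128 bytes keeps each warp's slice of the dynamic
// shared memory 128-byte aligned.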
struct Arguments {
......@@ -1091,25 +1114,23 @@ public:
int lora_rank;
// aligned to BLOCK_M and BLOCK_N
int M, N; // N should be the actual K in the next GEMM (needs /2 if fuse_glu)
int M, N; // N should be the actual K in the next GEMM (needs /2 if fuse_glu)
// the actual M and N (no need to /2 if fuse_glu)
int actualM, actualN;
bool alwaysfalse;
};
__device__ __forceinline__
void operator()(Arguments args)
{
__device__ __forceinline__ void operator()(Arguments args) {
const BlockInfo binfo = {
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y,
};
const int bm = binfo.bm;
const int bn = binfo.bn;
const int bm = binfo.bm;
const int bn = binfo.bn;
const int warpId = threadIdx.x / WARP_SIZE;
const int m_offset = bm * BLOCK_M + warpId * WARP_M;
......@@ -1119,42 +1140,48 @@ public:
fpsum_warp fpsum;
Base::template load_act_to_fpsum<fuse_glu>()(
args.input + m_offset * args.actualN + n_offset,
args.actualN,
args.actualM - m_offset,
args.actualN - n_offset,
fpsum,
shmem + warpId * SHMEM_PER_WARP
// args.smooth_factor ? args.smooth_factor + n_offset : nullptr
Base::template load_act_to_fpsum<fuse_glu>()(args.input + m_offset * args.actualN + n_offset,
args.actualN,
args.actualM - m_offset,
args.actualN - n_offset,
fpsum,
shmem + warpId * SHMEM_PER_WARP
// args.smooth_factor ? args.smooth_factor + n_offset : nullptr
);
CHECK_NAN(fpsum, "fpsum");
// for (int i = 0; i < 16; i++) {
// printf("bm=%d bn=%d warp=%d lane=%d fpsum[%d][0:1]=%f %f\n",
// printf("bm=%d bn=%d warp=%d lane=%d fpsum[%d][0:1]=%f %f\n",
// bm, bn, warpId, threadIdx.x % WARP_SIZE, i,
// (float)fpsum[i].data[0].x, (float)fpsum[i].data[0].y);
// }
using EpilogueLoraDown = typename Lora<Config>::EpilogueLoraDown;
EpilogueLoraDown()(binfo, fpsum, args.M, args.N, 0, typename EpilogueLoraDown::Arguments{
.lora_wgt_down = args.lora_wgt_down,
.lora_act = args.lora_act,
.rank = args.lora_rank,
.alwaysfalse = args.alwaysfalse,
});
EpilogueQuantize<false, false, use_fp4>()(binfo, fpsum, args.M, args.N, 0, typename EpilogueQuantize<false, false, use_fp4>::Arguments{
.qout = args.output,
.oscales = args.oscales,
.shift_value = 0,
.smooth_factor = args.smooth_factor
});
EpilogueLoraDown()(binfo,
fpsum,
args.M,
args.N,
0,
typename EpilogueLoraDown::Arguments{
.lora_wgt_down = args.lora_wgt_down,
.lora_act = args.lora_act,
.rank = args.lora_rank,
.alwaysfalse = args.alwaysfalse,
});
EpilogueQuantize<false, false, use_fp4>()(
binfo,
fpsum,
args.M,
args.N,
0,
typename EpilogueQuantize<false, false, use_fp4>::Arguments{.qout = args.output,
.oscales = args.oscales,
.shift_value = 0,
.smooth_factor = args.smooth_factor});
}
};
};
}; // namespace nunchaku::kernels
\ No newline at end of file
}; // namespace nunchaku::kernels
......@@ -5,57 +5,61 @@ namespace nunchaku::kernels {
template<typename Config, bool USE_FP4>
class GEMM_W4A4_Launch {
using GEMM = GEMM_W4A4<Config>;
using GEMM = GEMM_W4A4<Config>;
using Epilogues = Epilogues<Config>;
using Lora = Lora<Config>;
using Lora = Lora<Config>;
using packed_act_t = typename GEMM::packed_act_t;
using packed_wgt_t = typename GEMM::packed_wgt_t;
using packed_ascale_t = typename GEMM::packed_ascale_t;
using packed_wscale_t = typename GEMM::packed_wscale_t;
using packed_act_t = typename GEMM::packed_act_t;
using packed_wgt_t = typename GEMM::packed_wgt_t;
using packed_ascale_t = typename GEMM::packed_ascale_t;
using packed_wscale_t = typename GEMM::packed_wscale_t;
using packed_amscale_t = typename GEMM::packed_amscale_t;
using packed_wmscale_t = typename GEMM::packed_wmscale_t;
using packed_fpsum_t = typename GEMM::packed_fpsum_t;
using half_t = typename GEMM::half_t;
using packed_fpsum_t = typename GEMM::packed_fpsum_t;
using half_t = typename GEMM::half_t;
public:
static void gemm_w4a4(
Tensor act, // packed act [M, K / 2]
Tensor wgt, // packed act [N, K / 2]
Tensor out, // linear [M, N]
Tensor qout, // packed act [M, N / 2]
Tensor ascales, // packed as [K / 64, M]
Tensor wscales, // packed ws [K / 64, N]
Tensor oscales, // packed as [N / 64, M]
Tensor poolout, // linear [M / PoolSize, N]
Tensor lora_act_in, // packed lora_act [M, R]
Tensor lora_up, // packed lora_wgt [N, R]
Tensor lora_down, // packed lora_wgt [N, R]
Tensor lora_act_out, // packed lora_act [M, R]
Tensor norm_q, // linear [HEAD_DIM]
Tensor norm_k, // linear [HEAD_DIM]
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor bias, // packed ws [N]
Tensor smooth_factor, // packed ws [N], for quantization of the next layer
Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
Tensor out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales, // packed ws [N]
Tensor out_q, // packed attention [B, H, M, D]
Tensor out_k, // packed attention [B, H, M, D]
Tensor out_v, // packed attention [B, H, M, D]
int attn_tokens
);
static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4);
static void gemm_w4a4(Tensor act, // packed act [M, K / 2]
Tensor wgt, // packed wgt [N, K / 2]
Tensor out, // linear [M, N]
Tensor qout, // packed act [M, N / 2]
Tensor ascales, // packed as [K / 64, M]
Tensor wscales, // packed ws [K / 64, N]
Tensor oscales, // packed as [N / 64, M]
Tensor poolout, // linear [M / PoolSize, N]
Tensor lora_act_in, // packed lora_act [M, R]
Tensor lora_up, // packed lora_wgt [N, R]
Tensor lora_down, // packed lora_wgt [N, R]
Tensor lora_act_out, // packed lora_act [M, R]
Tensor norm_q, // linear [HEAD_DIM]
Tensor norm_k, // linear [HEAD_DIM]
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor bias, // packed ws [N]
Tensor smooth_factor, // packed ws [N], for quantization of the next layer
Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
Tensor out_linearattn, // linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales, // packed ws [N]
Tensor out_q, // packed attention [B, H, M, D]
Tensor out_k, // packed attention [B, H, M, D]
Tensor out_v, // packed attention [B, H, M, D]
int attn_tokens);
static void quantize_w4a4_act_fuse_lora(Tensor input,
Tensor output,
Tensor oscales,
Tensor lora_down,
Tensor lora_act_out,
Tensor smooth,
bool fuse_glu,
bool fp4);
static void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
static void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
static void linearattn_vk_mul_q(Tensor q, Tensor vk);
};
}; // namespace nunchaku::kernels
\ No newline at end of file
}; // namespace nunchaku::kernels
#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16, true>;
};
\ No newline at end of file
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16, true>;
};
#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16, false>;
};
\ No newline at end of file
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16, false>;
};
#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, true>;
};
\ No newline at end of file
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, true>;
};
#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>;
};
\ No newline at end of file
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>;
};
#include "gemm_w4a4_launch_impl.cuh"
namespace nunchaku::kernels {
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16_FasterI2F, false>;
};
\ No newline at end of file
template class GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16_FasterI2F, false>;
};
......@@ -9,36 +9,35 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(
template<>
void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
#endif
Tensor act, // packed act [M, K / 2]
Tensor wgt, // packed act [N, K / 2]
Tensor out, // linear [M, N]
Tensor qout, // packed act [M, N / 2]
Tensor ascales, // packed as [K / 64, M]
Tensor wscales, // packed ws [K / 64, N]
Tensor oscales, // packed as [N / 64, M]
Tensor poolout, // linear [M / PoolSize, N]
Tensor lora_act_in, // packed lora_act [M, R]
Tensor lora_up, // packed lora_wgt [N, R]
Tensor lora_down, // packed lora_wgt [N, R]
Tensor lora_act_out, // packed lora_act [M, R]
Tensor norm_q, // linear [HEAD_DIM]
Tensor norm_k, // linear [HEAD_DIM]
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor bias, // packed ws [N]
Tensor smooth_factor, // packed ws [N], for quantization of the next layer
Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
Tensor out_linearattn,// linear [B, (M), N / 3]
Tensor act, // packed act [M, K / 2]
Tensor wgt, // packed wgt [N, K / 2]
Tensor out, // linear [M, N]
Tensor qout, // packed act [M, N / 2]
Tensor ascales, // packed as [K / 64, M]
Tensor wscales, // packed ws [K / 64, N]
Tensor oscales, // packed as [N / 64, M]
Tensor poolout, // linear [M / PoolSize, N]
Tensor lora_act_in, // packed lora_act [M, R]
Tensor lora_up, // packed lora_wgt [N, R]
Tensor lora_down, // packed lora_wgt [N, R]
Tensor lora_act_out, // packed lora_act [M, R]
Tensor norm_q, // linear [HEAD_DIM]
Tensor norm_k, // linear [HEAD_DIM]
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor bias, // packed ws [N]
Tensor smooth_factor, // packed ws [N], for quantization of the next layer
Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
Tensor out_linearattn, // linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales, // packed ws [N]
Tensor out_q, // packed attention [B, H, M, D]
Tensor out_k, // packed attention [B, H, M, D]
Tensor out_v, // packed attention [B, H, M, D]
int attn_tokens
) {
Tensor wcscales, // packed ws [N]
Tensor out_q, // packed attention [B, H, M, D]
Tensor out_k, // packed attention [B, H, M, D]
Tensor out_v, // packed attention [B, H, M, D]
int attn_tokens) {
#ifdef __INTELLISENSE__
static constexpr bool USE_FP4 = false;
#endif
......@@ -89,32 +88,35 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
if constexpr (!USE_FP4) {
dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>,
const packed_act_t *,
const packed_wgt_t *,
const packed_ascale_t *,
const packed_wscale_t *,
int, int, int,
typename Epilogue::Arguments,
bool,
bool>;
auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>,
const packed_act_t *,
const packed_wgt_t *,
const packed_ascale_t *,
const packed_wscale_t *,
int,
int,
int,
typename Epilogue::Arguments,
bool,
bool>;
if (shmem >= 24 * 1024) {
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
}
assert(alpha == 1.0f);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
act.data_ptr<packed_act_t>(),
wgt.data_ptr<packed_wgt_t>(),
ascales.data_ptr<packed_ascale_t>(),
wscales.data_ptr<packed_wscale_t>(),
M, N, K,
M,
N,
K,
args,
swapBlockMN,
false
);
false);
checkCUDA(cudaGetLastError());
});
return;
......@@ -124,16 +126,18 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() {
assert(!act_unsigned);
auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>,
const packed_act_t *,
const packed_wgt_t *,
const packed_amscale_t *,
const packed_wmscale_t *,
float,
int, int, int,
typename Epilogue::Arguments,
bool,
bool>;
auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>,
const packed_act_t *,
const packed_wgt_t *,
const packed_amscale_t *,
const packed_wmscale_t *,
float,
int,
int,
int,
typename Epilogue::Arguments,
bool,
bool>;
if (shmem >= 24 * 1024) {
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
......@@ -141,21 +145,22 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert(ascales.dtype() == Tensor::FP8_E4M3);
assert(wscales.dtype() == Tensor::FP8_E4M3);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
act.data_ptr<packed_act_t>(),
wgt.data_ptr<packed_wgt_t>(),
ascales.data_ptr<packed_amscale_t>(),
wscales.data_ptr<packed_wmscale_t>(),
alpha,
M, N, K,
M,
N,
K,
args,
swapBlockMN,
false
);
false);
checkCUDA(cudaGetLastError());
});
return;
}
......@@ -171,35 +176,37 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
dispatchBool(bias.valid(), [&]<bool USE_BIAS>() {
dispatchBool(wcscales.valid(), [&]<bool USE_SCALE>() {
using EpilogueBias = typename GEMM::EpilogueBias<USE_BIAS, USE_SCALE>;
// append EpilogueNop to work around mismatched memory layout of std::tuple between device and host code on Windows
// append EpilogueNop to work around mismatched memory layout of std::tuple between device and host code
// on Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using Epilogue = typename GEMM::EpilogueCombination<EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
return launch.template operator()<Epilogue>({
typename EpilogueBias::Arguments{
.bias = USE_BIAS ? bias.data_ptr<packed_wscale_t>() : nullptr,
.scale = USE_SCALE ? wcscales.data_ptr<packed_wscale_t>() : nullptr,
},
nextArgs,
{}
});
using Epilogue =
typename GEMM::EpilogueCombination<EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
return launch.template operator()<Epilogue>(
{typename EpilogueBias::Arguments{
.bias = USE_BIAS ? bias.data_ptr<packed_wscale_t>() : nullptr,
.scale = USE_SCALE ? wcscales.data_ptr<packed_wscale_t>() : nullptr,
},
nextArgs,
{}});
});
});
};
// auto launch_bias = launch;
auto launch_lora = [&]<typename NextEpilogue, typename MidEpilogue>(NextEpilogue::Arguments nextArgs, MidEpilogue::Arguments midArgs) {
auto launch_lora = [&]<typename NextEpilogue, typename MidEpilogue>(NextEpilogue::Arguments nextArgs,
MidEpilogue::Arguments midArgs) {
assert(lora_up.valid() == lora_act_in.valid());
assert(lora_down.valid() == lora_act_out.valid());
const int rank_up = lora_up.valid() ? lora_up.shape[1] : 0;
const int rank_up = lora_up.valid() ? lora_up.shape[1] : 0;
const int rank_down = lora_down.valid() ? lora_down.shape[1] : 0;
if (rank_up == 0) {
assert(rank_down == 0);
return launch_bias.template operator()<typename GEMM::EpilogueCombination<MidEpilogue, NextEpilogue>>({midArgs, nextArgs});
return launch_bias.template operator()<typename GEMM::EpilogueCombination<MidEpilogue, NextEpilogue>>(
{midArgs, nextArgs});
}
assert(rank_up % 16 == 0);
assert(lora_up.shape[0] == N);
......@@ -207,7 +214,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert(lora_act_in.shape[0] == M);
assert(lora_act_in.shape[1] == rank_up);
using LoraUp = Lora;
using LoraUp = Lora;
using scale_t = typename LoraUp::scale_t;
scale_t scales;
......@@ -218,19 +225,20 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
}
if (rank_down == 0) {
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue, typename GEMM::EpilogueNop>;
return launch_bias.template operator()<Epilogue>({
typename LoraUp::EpilogueLoraUp::Arguments{
.lora_act = lora_act_in.data_ptr<float>(),
.lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
.rank = rank_up,
.scales = scales,
.alwaysfalse = false,
},
midArgs,
nextArgs,
{}
});
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp,
MidEpilogue,
NextEpilogue,
typename GEMM::EpilogueNop>;
return launch_bias.template operator()<Epilogue>({typename LoraUp::EpilogueLoraUp::Arguments{
.lora_act = lora_act_in.data_ptr<float>(),
.lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
.rank = rank_up,
.scales = scales,
.alwaysfalse = false,
},
midArgs,
nextArgs,
{}});
}
// assert(rank_down == rank_up);
......@@ -246,25 +254,27 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
// dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
using LoraDown = LoraUp; // GEMM::Lora<RANK_DOWN>;
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue, typename GEMM::EpilogueNop>;
return launch_bias.template operator()<Epilogue>({
typename LoraUp::EpilogueLoraUp::Arguments{
.lora_act = lora_act_in.data_ptr<float>(),
.lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
.rank = rank_up,
.scales = scales,
.alwaysfalse = false,
},
midArgs,
typename LoraDown::EpilogueLoraDown::Arguments{
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
.rank = rank_down,
.alwaysfalse = false,
},
nextArgs,
{}
});
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp,
MidEpilogue,
typename LoraDown::EpilogueLoraDown,
NextEpilogue,
typename GEMM::EpilogueNop>;
return launch_bias.template operator()<Epilogue>({typename LoraUp::EpilogueLoraUp::Arguments{
.lora_act = lora_act_in.data_ptr<float>(),
.lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
.rank = rank_up,
.scales = scales,
.alwaysfalse = false,
},
midArgs,
typename LoraDown::EpilogueLoraDown::Arguments{
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
.rank = rank_down,
.alwaysfalse = false,
},
nextArgs,
{}});
// });
};
......@@ -276,29 +286,28 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
static constexpr float SHIFT_GELU = 0.171875f;
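// GELU's minimum value is about -0.17, so shifting by 0.171875 (= 11/64) keeps the activation
// non-negative; presumably this is what makes the unsigned 4-bit quantization (USE_UNSIGNED) safe.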
constexpr bool USE_UNSIGNED = !USE_FP4;
using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
auto argsQuantize = typename EpilogueQuantize::Arguments{
.qout = qout.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
.shift_value = USE_FP4 ? 0.0f : SHIFT_GELU,
.smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
};
using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
auto argsQuantize =
typename EpilogueQuantize::Arguments{.qout = qout.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
.shift_value = USE_FP4 ? 0.0f : SHIFT_GELU,
.smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()};
// TODO: check if gelu is needed
if (out.valid()) {
launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename Epilogues::EpilogueGelu>({
typename GEMM::EpilogueDefault::Arguments{
.out = out.data_ptr<half_t>(),
.actualM = actualM,
.actualN = actualN,
},
argsQuantize
}, {});
launch_lora.template
operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>,
typename Epilogues::EpilogueGelu>({typename GEMM::EpilogueDefault::Arguments{
.out = out.data_ptr<half_t>(),
.actualM = actualM,
.actualN = actualN,
},
argsQuantize},
{});
} else {
launch_lora.template operator()<EpilogueQuantize, typename Epilogues::EpilogueGelu>(argsQuantize, {});
}
} else if (out_linearattn.valid()) {
assert(out_vk.valid());
......@@ -311,7 +320,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert(out_vk.shape[3] == Epilogue::LITELA_HEAD_DIM);
assert(out_vk.shape[1] * Epilogue::LITELA_HEAD_DIM * 3 == N);
int batch_size = out_vk.shape[0];
int num_heads = out_vk.shape[1];
int num_heads = out_vk.shape[1];
assert(isTypeMatch<half_t>(out_linearattn.dtype()));
assert(out_linearattn.ndims() == 3);
......@@ -326,12 +335,14 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
out_vk.zero_();
launch_lora.template operator()<Epilogue, typename GEMM::EpilogueNop>(typename Epilogue::Arguments{
.out_q = out_linearattn.data_ptr<half_t>(),
.out_vk = out_vk.data_ptr<float>(),
.num_blocks_per_batch = num_blocks_per_batch,
.actualM = M,
}, {});
launch_lora.template operator()<Epilogue, typename GEMM::EpilogueNop>(
typename Epilogue::Arguments{
.out_q = out_linearattn.data_ptr<half_t>(),
.out_vk = out_vk.data_ptr<float>(),
.num_blocks_per_batch = num_blocks_per_batch,
.actualM = M,
},
{});
} else if (rotary_emb.valid()) {
assert(norm_q.valid());
......@@ -342,8 +353,9 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert(rotary_emb.shape[0] * rotary_emb.shape[1] == M);
assert(rotary_emb.shape[2] == Epilogues::EpilogueRMSNormRope::HEAD_DIM);
// assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
// launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
// assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 *
//        GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
// launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(
//     typename GEMM::EpilogueQKVProj::Arguments{
// .out = out.data_ptr<half_t>(),
// .actualM = actualM,
// .actualN = actualN,
......@@ -355,42 +367,48 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
// }, {});
using EpilogueRope = typename Epilogues::EpilogueRMSNormRope;
auto argsRope = typename Epilogues::EpilogueRMSNormRope::Arguments{
.rotary_emb = rotary_emb.data_ptr<typename EpilogueRope::packed_rotemb_t>(),
.rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
.rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
.epsilon = 1e-6,
auto argsRope = typename Epilogues::EpilogueRMSNormRope::Arguments{
.rotary_emb = rotary_emb.data_ptr<typename EpilogueRope::packed_rotemb_t>(),
.rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
.rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
.epsilon = 1e-6,
};
if (out_q.valid()) {
launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename Epilogues::EpiloguePackQKV>, typename GEMM::EpilogueNop>({
argsRope,
typename Epilogues::EpiloguePackQKV::Arguments{
.out_q = out_q.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
.out_k = out_k.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
.out_v = out_v.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
.actualM = attn_tokens,
.strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
.strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
.strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
}
}, {});
launch_lora.template
operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename Epilogues::EpiloguePackQKV>,
typename GEMM::EpilogueNop>(
{argsRope,
typename Epilogues::EpiloguePackQKV::Arguments{
.out_q = out_q.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
.out_k = out_k.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
.out_v = out_v.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
.actualM = attn_tokens,
.strideHead_q = int(out_q.stride(1) * out_q.scalar_size() /
sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
.strideHead_k = int(out_k.stride(1) * out_k.scalar_size() /
sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
.strideHead_v = int(out_v.stride(1) * out_v.scalar_size() /
sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
}},
{});
} else {
launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpilogueDefault>, typename GEMM::EpilogueNop>({
argsRope,
typename GEMM::EpilogueDefault::Arguments{
.out = out.data_ptr<half_t>(),
.actualM = actualM,
.actualN = actualN,
}
}, {});
launch_lora
.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpilogueDefault>,
typename GEMM::EpilogueNop>({argsRope,
typename GEMM::EpilogueDefault::Arguments{
.out = out.data_ptr<half_t>(),
.actualM = actualM,
.actualN = actualN,
}},
{});
}
} else if (out.valid()) {
using Epilogue = typename GEMM::EpilogueDefault;
typename Epilogue::Arguments args{
.out = out.data_ptr<half_t>(),
.out = out.data_ptr<half_t>(),
.actualM = actualM,
.actualN = actualN,
};
......@@ -410,7 +428,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk)
using Epilogue = typename Epilogues::EpilogueLiteLA;
int batch_size = vk.shape[0];
int num_heads = vk.shape[1];
int num_heads = vk.shape[1];
int num_tokens = q.shape[1];
assert(isTypeMatch<half_t>(q.scalar_type()));
......@@ -423,17 +441,21 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk)
BLOCK_SIZE = 128;
}
invoke_kernel<typename Epilogue::vk_mul_q_kernel><<<dim3(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size), BLOCK_SIZE, 0, getCurrentCUDAStream()>>>(
q.data_ptr<half_t>(),
vk.data_ptr<float>(),
1e-6f,
num_tokens
);
invoke_kernel<typename Epilogue::vk_mul_q_kernel>
<<<dim3(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size), BLOCK_SIZE, 0, getCurrentCUDAStream()>>>(
q.data_ptr<half_t>(), vk.data_ptr<float>(), 1e-6f, num_tokens);
checkCUDA(cudaGetLastError());
}
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input,
Tensor output,
Tensor oscales,
Tensor lora_down,
Tensor lora_act_out,
Tensor smooth,
bool fuse_glu,
bool fp4) {
const int actualM = input.numel() / input.shape[-1];
const int actualN = input.shape[-1];
......@@ -475,24 +497,24 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N,
// input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<half_t>(),
.input = input.data_ptr<half_t>(),
.smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
.output = output.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename kernel::oscales_t>(),
.output = output.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename kernel::oscales_t>(),
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
.lora_rank = rank,
.M = M,
.N = N,
.actualM = actualM,
.actualN = actualN,
.alwaysfalse = false,
}
);
.lora_act = lora_act_out.data_ptr<float>(),
.lora_rank = rank,
.M = M,
.N = N,
.actualM = actualM,
.actualN = actualN,
.alwaysfalse = false,
});
checkCUDA(cudaGetLastError());
});
// });
......@@ -501,7 +523,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
if constexpr (USE_FP4) {
assert(false); // not implemented
assert(false); // not implemented
return;
}
......@@ -518,11 +540,7 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor o
dim3 grid(M / GEMM::WARP_M, K / GEMM::WARP_K);
invoke_kernel<typename GEMM::quantize_w4a4_act_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
input.data_ptr<half_t>(),
output.data_ptr<packed_act_t>(),
oscales.data_ptr<packed_ascale_t>(),
K
);
input.data_ptr<half_t>(), output.data_ptr<packed_act_t>(), oscales.data_ptr<packed_ascale_t>(), K);
checkCUDA(cudaGetLastError());
}
......@@ -540,19 +558,15 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_wgt(Tensor input, Tensor o
assert(output.ndims() == 2);
assert(output.shape[0] == N);
assert(output.shape[1] == K / 2);
assert(isTypeMatch<half_t>(oscales.dtype()));
// assert(oscales.dtype() == Tensor::FP16);
assert(oscales.numel() == N * K / GEMM::WARP_K);
dim3 grid(N / GEMM::WARP_N, K / GEMM::WARP_K);
invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
input.data_ptr<half_t>(),
output.data_ptr<packed_wgt_t>(),
oscales.data_ptr<packed_wscale_t>(),
K
);
input.data_ptr<half_t>(), output.data_ptr<packed_wgt_t>(), oscales.data_ptr<packed_wscale_t>(), K);
checkCUDA(cudaGetLastError());
}
}; // namespace nunchaku::kernels
\ No newline at end of file
}; // namespace nunchaku::kernels
......@@ -11,7 +11,7 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
assert(input.shape.dataExtent == output.shape.dataExtent);
assert(input.scalar_type() == Tensor::FP16);
using GEMM = Epilogues<GEMMConfig_W4A4_FP16>;
using GEMM = Epilogues<GEMMConfig_W4A4_FP16>;
using Epilogue = GEMM::EpilogueRMSNormRope;
assert(M % GEMM::BLOCK_M == 0);
......@@ -26,21 +26,18 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<GEMM::half_t>(),
.output = output.data_ptr<GEMM::half_t>(),
.M = M,
.N = N,
.actualM = M,
.actualN = N,
.argsEpilogue = typename Epilogue::Arguments{
.rotary_emb = rotary_emb.data_ptr<typename Epilogue::packed_rotemb_t>(),
.rmsnorm_weight_q = norm_q.data_ptr<GEMM::half_t>(),
.rmsnorm_weight_k = norm_k.data_ptr<GEMM::half_t>(),
.epsilon = 1e-6,
}
}
);
typename kernel::Arguments{.input = input.data_ptr<GEMM::half_t>(),
.output = output.data_ptr<GEMM::half_t>(),
.M = M,
.N = N,
.actualM = M,
.actualN = N,
.argsEpilogue = typename Epilogue::Arguments{
.rotary_emb = rotary_emb.data_ptr<typename Epilogue::packed_rotemb_t>(),
.rmsnorm_weight_q = norm_q.data_ptr<GEMM::half_t>(),
.rmsnorm_weight_k = norm_k.data_ptr<GEMM::half_t>(),
.epsilon = 1e-6,
}});
checkCUDA(cudaGetLastError());
}
......@@ -52,7 +49,7 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
Tensor output = Tensor::empty_like(input);
using GEMM = Epilogues<GEMMConfig_W4A4_FP16>;
using GEMM = Epilogues<GEMMConfig_W4A4_FP16>;
using Epilogue = GEMM::EpiloguePackQKV;
assert(M % GEMM::BLOCK_M == 0);
......@@ -68,24 +65,25 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<GEMM::half_t>(),
.output = output.data_ptr<GEMM::half_t>(),
.M = M,
.N = N,
.actualM = M,
.actualN = N,
.input = input.data_ptr<GEMM::half_t>(),
.output = output.data_ptr<GEMM::half_t>(),
.M = M,
.N = N,
.actualM = M,
.actualN = N,
.argsEpilogue = typename Epilogue::Arguments{
.out_q = out_q.data_ptr<typename Epilogue::packed_qkv_t>(),
.out_k = out_k.data_ptr<typename Epilogue::packed_qkv_t>(),
.out_v = out_v.data_ptr<typename Epilogue::packed_qkv_t>(),
.out_q = out_q.data_ptr<typename Epilogue::packed_qkv_t>(),
.out_k = out_k.data_ptr<typename Epilogue::packed_qkv_t>(),
.out_v = out_v.data_ptr<typename Epilogue::packed_qkv_t>(),
.actualM = numTokens,
.strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
}
}
);
.strideHead_q =
int(out_q.stride(1) * out_q.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_k =
int(out_k.stride(1) * out_k.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_v =
int(out_v.stride(1) * out_v.scalar_size() / sizeof(GEMM::EpiloguePackQKV::packed_qkv_t)),
}});
checkCUDA(cudaGetLastError());
}
}; // namespace nunchaku::kernels
\ No newline at end of file
}; // namespace nunchaku::kernels
......@@ -17,24 +17,22 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
assert(oscales.numel() == M * 1);
auto launch = [&]<bool FUSE_GLU>() {
using kernel = GEMM::quantize_w8a8_act_kernel<FUSE_GLU>;
assert(kernel::check(M, K));
dim3 grid = kernel::gridSize(M, K);
dim3 grid = kernel::gridSize(M, K);
dim3 block = kernel::blockSize(M, K);
auto func = invoke_kernel<kernel, const GEMM::half_t *, GEMM::packed_act_t *, GEMM::packed_ascale_t *, int, bool>;
auto func =
invoke_kernel<kernel, const GEMM::half_t *, GEMM::packed_act_t *, GEMM::packed_ascale_t *, int, bool>;
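// 92160 bytes = 90 KiB of dynamic shared memory, above the default 48 KiB per-block limit.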
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, 92160));
func<<<grid, block, kernel::smemSize(M, K)>>>(
input.data_ptr<GEMM::half_t>(),
output.data_ptr<GEMM::packed_act_t>(),
oscales.data_ptr<GEMM::packed_ascale_t>(),
K,
false
);
func<<<grid, block, kernel::smemSize(M, K)>>>(input.data_ptr<GEMM::half_t>(),
output.data_ptr<GEMM::packed_act_t>(),
oscales.data_ptr<GEMM::packed_ascale_t>(),
K,
false);
checkCUDA(cudaGetLastError());
};
......@@ -45,14 +43,12 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
}
}
void gemm_w8a8(Tensor act, // [M, K]
Tensor wgt, // [N, K]
Tensor out, // [M, N]
Tensor ascales, // [1, M]
Tensor wscales, // [1, N]
Tensor bias
)
{
void gemm_w8a8(Tensor act, // [M, K]
Tensor wgt, // [N, K]
Tensor out, // [M, N]
Tensor ascales, // [1, M]
Tensor wscales, // [1, N]
Tensor bias) {
using GEMM = GEMM_W8A8;
int M = act.numel() / act.shape[-1];
......@@ -78,16 +74,18 @@ void gemm_w8a8(Tensor act, // [M, K]
std::swap(grid.x, grid.y);
}
invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>><<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>(
act.data_ptr<GEMM::packed_act_t>(),
wgt.data_ptr<GEMM::packed_wgt_t>(),
ascales.data_ptr<GEMM::packed_ascale_t>(),
wscales.data_ptr<GEMM::packed_wscale_t>(),
// out.valid() ? out.data_ptr<GEMM::half_t>() : nullptr,
M, N, K, args,
swapBlockMN,
false
);
invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>>
<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>(act.data_ptr<GEMM::packed_act_t>(),
wgt.data_ptr<GEMM::packed_wgt_t>(),
ascales.data_ptr<GEMM::packed_ascale_t>(),
wscales.data_ptr<GEMM::packed_wscale_t>(),
// out.valid() ? out.data_ptr<GEMM::half_t>() : nullptr,
M,
N,
K,
args,
swapBlockMN,
false);
checkCUDA(cudaGetLastError());
};
......@@ -98,20 +96,19 @@ void gemm_w8a8(Tensor act, // [M, K]
assert(bias.numel() == N);
// append EpilogueNop to work around mismatched memory layout of std::tuple between device and host code on Windows
// append EpilogueNop to work around mismatched memory layout of std::tuple between device and host code on
// Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias<true, false>, NextEpilogue, GEMM::EpilogueNop>;
return launch.template operator()<Epilogue>({
GEMM::EpilogueBias<true, false>::Arguments{
.bias = bias.data_ptr<GEMM::packed_wscale_t>(),
},
nextArgs,
{}
});
return launch.template operator()<Epilogue>({GEMM::EpilogueBias<true, false>::Arguments{
.bias = bias.data_ptr<GEMM::packed_wscale_t>(),
},
nextArgs,
{}});
};
launch_bias.template operator()<GEMM::EpilogueDefault>(GEMM::EpilogueDefault::Arguments{
.out = out.data_ptr<GEMM::half_t>(),
.out = out.data_ptr<GEMM::half_t>(),
.actualM = actualM,
.actualN = actualN,
});
......@@ -152,9 +149,9 @@ void gemm_w8a8_fuse_litela(
checkCUDA(cudaMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize()));
auto func = invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>,
const GEMM::packed_act_t *,
const GEMM::packed_wgt_t *,
auto func = invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>,
const GEMM::packed_act_t *,
const GEMM::packed_wgt_t *,
const GEMM::packed_ascale_t *,
const GEMM::packed_wscale_t *,
// GEMM::half_t *,
......@@ -178,7 +175,7 @@ void gemm_w8a8_fuse_litela(
ascales.data_ptr<GEMM::packed_ascale_t>(),
wscales.data_ptr<GEMM::packed_wscale_t>(),
// nullptr,
M, N, K, epilogueArgs,
M, N, K, epilogueArgs,
swapBlockMN,
false
);
......@@ -193,4 +190,4 @@ void gemm_w8a8_fuse_litela(
}
#endif
}; // namespace nunchaku::kernels
\ No newline at end of file
}; // namespace nunchaku::kernels
......@@ -8,48 +8,52 @@ class GEMM_W8A8 : public GEMMBase<GEMMConfig_W8A8> {
public:
using psum_warp = std::array<packed_psum_t, WARP_M_TILES * WARP_N_TILES>;
__device__ __forceinline__
static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt, packed_psum_t psum) {
__device__ __forceinline__ static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt, packed_psum_t psum) {
// packed_psum_t psum;
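// Two m16n8k32 integer MMAs produce one 16x16 int32 tile: wgt.{x, y} feeds the first 8 output
// columns (psum.data[0..3]) and wgt.{z, w} the remaining 8 (psum.data[4..7]).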
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.x), "r"(wgt.y),
// "r"(0), "r"(0), "r"(0), "r"(0)
"r"(psum.data[0]), "r"(psum.data[1]), "r"(psum.data[2]), "r"(psum.data[3])
);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.z), "r"(wgt.w),
// "r"(0), "r"(0), "r"(0), "r"(0)
"r"(psum.data[4]), "r"(psum.data[5]), "r"(psum.data[6]), "r"(psum.data[7])
);
asm volatile("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
: "r"(act.x),
"r"(act.y),
"r"(act.z),
"r"(act.w),
"r"(wgt.x),
"r"(wgt.y),
// "r"(0), "r"(0), "r"(0), "r"(0)
"r"(psum.data[0]),
"r"(psum.data[1]),
"r"(psum.data[2]),
"r"(psum.data[3]));
asm volatile("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
: "r"(act.x),
"r"(act.y),
"r"(act.z),
"r"(act.w),
"r"(wgt.z),
"r"(wgt.w),
// "r"(0), "r"(0), "r"(0), "r"(0)
"r"(psum.data[4]),
"r"(psum.data[5]),
"r"(psum.data[6]),
"r"(psum.data[7]));
return psum;
}
__device__ __forceinline__
static void compute(act_warp A, wgt_warp W, psum_warp &psum) {
__device__ __forceinline__ static void compute(act_warp A, wgt_warp W, psum_warp &psum) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
psum[i * WARP_N_TILES + j] = mma(A[i], W[j], psum[i * WARP_N_TILES + j]);
}
......@@ -62,11 +66,12 @@ public:
* oscales is per-warp (in shared memory)
* output is per-thread (in regs)
* shmem must be at least INSN_M * (INSN_K * sizeof(element) + 16) (16 * 32 = 512 Bytes)
* quantizes activations by default; to quantize weights instead, the input must be column-major and the output must be transposed ({x, y, z, w} = {x, z, y, w})
* quantizes activations by default; to quantize weights instead, the input must be column-major and the
* output must be transposed ({x, y, z, w} = {x, z, y, w})
*/
template<bool input_shmem = false>
__device__ __forceinline__
static void quantize_w8a8_warp(const half_t *input, const half_t *oscales, int stride, packed_act_t &output, void *shmem) {
__device__ __forceinline__ static void
quantize_w8a8_warp(const half_t *input, const half_t *oscales, int stride, packed_act_t &output, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
constexpr int QUANTIZE_BITWIDTH = 8;
......@@ -75,28 +80,29 @@ public:
// 1 lane = 1 pack
// 1 warp = 32 lanes = 32 packs = 1 packwarp
// a pack is {a0, ..., a7} in figure https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a
// PACK_SIZE * 4 = INSN_K / 2
constexpr int PACK_SIZE = INSN_K / 8; // = 4 for 8bit
constexpr int NUM_PACKS_PER_ROW = INSN_K / PACK_SIZE;
// a pack is {a0, ..., a7} in figure
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a
// PACK_SIZE * 4 = INSN_K / 2
constexpr int PACK_SIZE = INSN_K / 8; // = 4 for 8bit
constexpr int NUM_PACKS_PER_ROW = INSN_K / PACK_SIZE;
constexpr int NUM_ROWS_PER_PACKWARP = PACK_SIZE * WARP_SIZE / INSN_K;
constexpr int NUM_PACKWARPS = INSN_M / NUM_ROWS_PER_PACKWARP;
using packed_input = std::array<half_t, PACK_SIZE>;
constexpr int NUM_PACKWARPS = INSN_M / NUM_ROWS_PER_PACKWARP;
using packed_input = std::array<half_t, PACK_SIZE>;
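// With the m16n8k32 int8 MMA used by this class (so INSN_M = 16, INSN_K = 32, WARP_SIZE = 32),
// these constants work out to PACK_SIZE = 4, NUM_PACKS_PER_ROW = 8, NUM_ROWS_PER_PACKWARP = 4 and
// NUM_PACKWARPS = 4: each lane loads 4 packs of 4 halves to cover the 16x32 input tile.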
packed_input packs[NUM_PACKWARPS];
// load
#pragma unroll
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
int rowId = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW;
int colId = laneId % NUM_PACKS_PER_ROW * PACK_SIZE;
packs[i] = load<input_shmem>(reinterpret_cast<const packed_input *>(input + rowId * stride + colId));
packs[i] = load<input_shmem>(reinterpret_cast<const packed_input *>(input + rowId * stride + colId));
}
// quantize
using matrix_t = uint32_t[INSN_M][NUM_PACKS_PER_ROW];
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
#pragma unroll
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
const int row = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW;
const int col = laneId % NUM_PACKS_PER_ROW;
......@@ -104,7 +110,7 @@ public:
float rscale = cuda_frcp(float(oscales[row]));
uint32_t qpack = 0;
#pragma unroll
#pragma unroll
for (int j = 0; j < PACK_SIZE; j += 2) {
// half2_t hval = __hmul2(half2_t(rscale, rscale), half2_t(packs[i][j], packs[i][j + 1]));
float2 fval = half22float2(half2_t(packs[i][j], packs[i][j + 1])) * float2(rscale, rscale);
......@@ -113,7 +119,7 @@ public:
mat[row][col] = qpack;
}
__syncwarp();
// convert to imma format
int row = laneId % 16;
int col = laneId / 16 * 4;
......@@ -126,20 +132,20 @@ public:
* each warp finds absmax from a row
*/
template<bool fuse_glu = false>
__device__ __forceinline__
static half_t findmax_warp(const half_t *input, half_t *output_shmem, int K, bool alwaysfalse) {
__device__ __forceinline__ static half_t
findmax_warp(const half_t *input, half_t *output_shmem, int K, bool alwaysfalse) {
const int laneId = threadIdx.x % WARP_SIZE;
using packed_input = std::array<half2_t, 4>;
using packed_input = std::array<half2_t, 4>;
using packed_gated_input = std::array<half_t, 4>;
constexpr int PACK_SIZE = sizeof(packed_input) / sizeof(half_t);
constexpr int PACK_SIZE = sizeof(packed_input) / sizeof(half_t);
constexpr int NUM_STAGES = 2;
half2_t maxvalue2 = { 0, 0 };
half2_t maxvalue2 = {0, 0};
packed_input pack[NUM_STAGES];
#pragma unroll
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
const int idx = k * PACK_SIZE * WARP_SIZE + laneId * PACK_SIZE;
if (idx < K) {
......@@ -155,11 +161,11 @@ public:
// TODO: store quantized data to shmem (instead of half)
for (int k1 = 0; k1 < ceilDiv(K, PACK_SIZE * WARP_SIZE); k1 += NUM_STAGES) {
#pragma unroll
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
const int nextidx = (k1 + k2 + NUM_STAGES - 1) * PACK_SIZE * WARP_SIZE + laneId * PACK_SIZE;
const int nextk2 = (k2 + NUM_STAGES - 1) % NUM_STAGES;
const int nextk2 = (k2 + NUM_STAGES - 1) % NUM_STAGES;
if (nextidx < K) {
pack[nextk2] = load(reinterpret_cast<const packed_input *>(&input[nextidx]));
......@@ -172,11 +178,11 @@ public:
if constexpr (fuse_glu) {
packed_gated_input gated;
#pragma unroll
#pragma unroll
for (int j = 0; j < p.size(); j++) {
gated[j] = p[j].x * gelu_half(p[j].y);
p[j].x = gated[j];
p[j].y = 0;
p[j].x = gated[j];
p[j].y = 0;
}
int idx = (k1 + k2) * PACK_SIZE / 2 * WARP_SIZE + laneId * PACK_SIZE / 2;
......@@ -185,7 +191,7 @@ public:
}
}
#pragma unroll
#pragma unroll
for (int j = 0; j < p.size(); j++) {
maxvalue2 = __hmax2(maxvalue2, __habs2(p[j]));
}
......@@ -194,7 +200,7 @@ public:
// unused_var(dummy, alwaysfalse);
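// Butterfly (XOR-shuffle) reduction: after log2(32) = 5 rounds, every lane holds the
// element-wise maximum of maxvalue2 across the whole warp.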
#pragma unroll
#pragma unroll
for (int mask = 32 / 2; mask > 0; mask /= 2) {
maxvalue2 = __hmax2(maxvalue2, __shfl_xor_sync(~0, maxvalue2, mask));
}
......@@ -223,8 +229,8 @@ public:
return INSN_M * K2 * sizeof(half_t);
}
__device__
void operator()(const half_t *input, packed_act_t *output, packed_ascale_t *oscales, int K, bool alwaysfalse) {
__device__ void
operator()(const half_t *input, packed_act_t *output, packed_ascale_t *oscales, int K, bool alwaysfalse) {
// for quantize kernel
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
......@@ -232,10 +238,9 @@ public:
const int numWarps = blockDim.x / WARP_SIZE;
// for GEMM kernel
const int bm = blockIdx.x / (BLOCK_M / WARP_M);
const int bm = blockIdx.x / (BLOCK_M / WARP_M);
const int gemmWarpId = blockIdx.x % (BLOCK_M / WARP_M);
__shared__ alignas(128) half_t oscale_shmem[WARP_M];
// __shared__ alignas(128) half_t maxv_shmem[WARP_M];
__shared__ alignas(128) uint8_t tmp_shmem[NUM_WARPS][512];
......@@ -249,7 +254,7 @@ public:
for (int tileM = 0; tileM < WARP_M_TILES; tileM++) {
for (int i = warpId; i < INSN_M; i += numWarps) {
const int rowLocal = tileM * INSN_M + i;
const int rowLocal = tileM * INSN_M + i;
const int rowGlobal = blockIdx.x * WARP_M + rowLocal;
half_t maxv = findmax_warp<fuse_glu>(input + rowGlobal * K, shmem + i * K2, K, alwaysfalse);
......@@ -260,76 +265,66 @@ public:
__syncthreads();
for (int bk = warpId; bk < K2 / WARP_K; bk += numWarps) {
const int rowLocal = tileM * INSN_M;
const int rowLocal = tileM * INSN_M;
const int rowGlobal = blockIdx.x * WARP_M + rowLocal;
const int col = bk * WARP_K;
const int col = bk * WARP_K;
packed_act_t tmpout;
if constexpr (fuse_glu) {
quantize_w8a8_warp<true>(
shmem + col,
oscale_shmem + rowLocal,
K2,
tmpout,
&tmp_shmem[warpId]
);
quantize_w8a8_warp<true>(shmem + col, oscale_shmem + rowLocal, K2, tmpout, &tmp_shmem[warpId]);
} else {
quantize_w8a8_warp<false>(
input + rowGlobal * K + col,
oscale_shmem + rowLocal,
K,
tmpout,
&tmp_shmem[warpId]
);
input + rowGlobal * K + col, oscale_shmem + rowLocal, K, tmpout, &tmp_shmem[warpId]);
}
store(&output[(((bm * K2 / WARP_K + bk) * NUM_WARPS + gemmWarpId) * WARP_M_TILES + tileM) * WARP_SIZE + laneId], tmpout);
store(&output[(((bm * K2 / WARP_K + bk) * NUM_WARPS + gemmWarpId) * WARP_M_TILES + tileM) *
WARP_SIZE +
laneId],
tmpout);
}
__syncthreads();
}
// [M / BLOCK_M, 1, NUM_WARPS, ASCALES_NUM_PACKS, ASCALES_VALID_LANES] of packed_ascale_t
pack_ascales(oscale_shmem, &oscales[(bm * NUM_WARPS + gemmWarpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
pack_ascales(oscale_shmem,
&oscales[(bm * NUM_WARPS + gemmWarpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
}
};
__device__ __forceinline__
static gated_fpsum_warp apply_glu(fpsum_warp fpsum) {
__device__ __forceinline__ static gated_fpsum_warp apply_glu(fpsum_warp fpsum) {
gated_fpsum_warp result;
for (int i = 0; i < WARP_M_TILES; i++) {
for (int j = 0; j < WARP_N_TILES; j++) {
for (int k = 0; k < 4; k++) {
half_t &dst = result[i * WARP_N_TILES + j].data[k];
half2_t src = fpsum[i * WARP_N_TILES + j].data[k];
dst = src.x * gelu_half(src.y);
dst = src.x * gelu_half(src.y);
}
}
}
return result;
}
static constexpr int unpack_gated_fpsum_shmem_size = INSN_M * (WARP_N / 2 + 8) * sizeof(half_t);
__device__ __forceinline__
static void unpack_gated_fpsum(gated_fpsum_warp fpsum, half_t *output, int stride, void *shmem) {
__device__ __forceinline__ static void
unpack_gated_fpsum(gated_fpsum_warp fpsum, half_t *output, int stride, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
constexpr int PACK_SIZE = WARP_N / 2 / WARP_SIZE;
using pack_t = std::array<half_t, PACK_SIZE>;
using pack_t = std::array<half_t, PACK_SIZE>;
// +8 to prevent bank conflicts
using matrix_t = half_t[INSN_M][WARP_N / 2 + 8];
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
for (int i = 0; i < WARP_M_TILES; i++) {
for (int j = 0; j < WARP_N_TILES; j++) {
packed_gated_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j];
int row = laneId / 4;
int col = laneId % 4 + j * INSN_N / 2;
*reinterpret_cast<half_t *>(&mat[row][col + 0]) = fsum.data[0];
*reinterpret_cast<half_t *>(&mat[row][col + 4]) = fsum.data[2];
packed_gated_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j];
int row = laneId / 4;
int col = laneId % 4 + j * INSN_N / 2;
*reinterpret_cast<half_t *>(&mat[row][col + 0]) = fsum.data[0];
*reinterpret_cast<half_t *>(&mat[row][col + 4]) = fsum.data[2];
*reinterpret_cast<half_t *>(&mat[row + 8][col + 0]) = fsum.data[1];
*reinterpret_cast<half_t *>(&mat[row + 8][col + 4]) = fsum.data[3];
}
......@@ -345,28 +340,27 @@ public:
// out: [M, N] <=> [..., NUM_WARPS, WARP_M, N] of half
template<typename Epilogue>
__device__ __forceinline__
static void gemm_w8a8_block(
const BlockInfo binfo,
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
// half_t *out,
int M, int N, int K,
Epilogue::Arguments epilogeParams,
bool alwaysfalse)
{
__device__ __forceinline__ static void gemm_w8a8_block(const BlockInfo binfo,
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
// half_t *out,
int M,
int N,
int K,
Epilogue::Arguments epilogeParams,
bool alwaysfalse) {
constexpr int NUM_STAGES = 2;
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
act_warp A[NUM_STAGES]; // 8
wgt_warp W[NUM_STAGES]; // 32
ascale_warp ascale; // 1
wscale_warp wscale; // 2
psum_warp psum; // 128
act_warp A[NUM_STAGES]; // 8
wgt_warp W[NUM_STAGES]; // 32
ascale_warp ascale; // 1
wscale_warp wscale; // 2
psum_warp psum; // 128
for (auto &pack : psum) {
for (int i = 0; i < 8; i++) {
......@@ -377,7 +371,7 @@ public:
// load_wscale<true>(wscales, wscale[0], true);
// load_wscale<false>(wscales, wscale[1], true);
// load_wscale<false>(wscales, wscale[2], true);
load_ascale(ascales, 0, M, ascale, true);
load_wscale(wscales, 0, N, wscale, true);
......@@ -385,14 +379,14 @@ public:
load_act(act, k, K, A[k], true);
load_wgt(wgt, k, K, W[k], true);
}
int dummy = 0;
for (int k1 = 0; k1 < K / WARP_K; k1 += NUM_STAGES) {
#pragma unroll
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < K / WARP_K;
load_act(act, nextk, K, A[idx], pred);
load_wgt(wgt, nextk, K, W[idx], pred);
......@@ -421,17 +415,15 @@ public:
f32psum_warp f32psum;
#pragma unroll
#pragma unroll
for (int i = 0; i < f32psum.size(); i++) {
#pragma unroll
#pragma unroll
for (int j = 0; j < 8; j++) {
f32psum[i].data[j] = 0;
}
}
apply_scales([&](int i, int j) {
return psum[i * WARP_N_TILES + j];
}, ascale, wscale, f32psum);
apply_scales([&](int i, int j) { return psum[i * WARP_N_TILES + j]; }, ascale, wscale, f32psum);
fpsum_warp fpsum = packed_fp32_to_fp16(f32psum);
......@@ -443,27 +435,24 @@ public:
Epilogue()(binfo, fpsum, M, N, K, epilogeParams);
}
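// Illustrative sketch, not part of the kernel: a scalar reference of what gemm_w8a8_block
// computes, ignoring tiling, the two-stage pipeline and the epilogue. Int8 inputs are
// accumulated in int32 (the psum), then dequantized with one activation scale per row and
// one weight scale per column, as apply_scales does above.
__device__ __forceinline__ static void gemm_w8a8_ref(const int8_t *act,   // [M, K]
                                                     const int8_t *wgt,   // [N, K]
                                                     const float *ascale, // [M]
                                                     const float *wscale, // [N]
                                                     float *out,          // [M, N]
                                                     int M, int N, int K) {
    for (int m = 0; m < M; m++) {
        for (int n = 0; n < N; n++) {
            int acc = 0;
            for (int k = 0; k < K; k++) {
                acc += int(act[m * K + k]) * int(wgt[n * K + k]); // integer psum
            }
            out[m * N + n] = float(acc) * ascale[m] * wscale[n]; // dequantize
        }
    }
}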
// out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
template<typename Epilogue>
struct gemm_w8a8_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
__device__
void operator()(
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
// half_t *out,
int M, int N, int K,
Epilogue::Arguments epilogueArgs,
bool swapBlockXY,
bool alwaysfalse)
{
__device__ void operator()(const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
// half_t *out,
int M,
int N,
int K,
Epilogue::Arguments epilogueArgs,
bool swapBlockXY,
bool alwaysfalse) {
BlockInfo binfo = {
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y,
};
......@@ -476,25 +465,25 @@ public:
const int bm = binfo.bm;
const int bn = binfo.bn;
gemm_w8a8_block<Epilogue>(
binfo,
act + bm * (K / WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
wgt + bn * (K / WARP_K) * WARP_N_TILES * WARP_SIZE,
ascales + bm * (1) * NUM_WARPS * ASCALES_NUM_PACKS * ASCALES_VALID_LANES, // only 1 group in W8A8
wscales + bn * (1) * WSCALES_NUM_PACKS * WSCALES_VALID_LANES,
// #if 1
// out + (bm * BLOCK_M * N) + bn * BLOCK_N,
// #else
// out + (bm * BLOCK_M * N / 2) + bn * BLOCK_N / 2,
// #endif
M, N, K,
epilogueArgs,
alwaysfalse
);
gemm_w8a8_block<Epilogue>(binfo,
act + bm * (K / WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
wgt + bn * (K / WARP_K) * WARP_N_TILES * WARP_SIZE,
ascales + bm * (1) * NUM_WARPS * ASCALES_NUM_PACKS *
ASCALES_VALID_LANES, // only 1 group in W8A8
wscales + bn * (1) * WSCALES_NUM_PACKS * WSCALES_VALID_LANES,
// #if 1
// out + (bm * BLOCK_M * N) + bn * BLOCK_N,
// #else
// out + (bm * BLOCK_M * N / 2) + bn * BLOCK_N / 2,
// #endif
M,
N,
K,
epilogueArgs,
alwaysfalse);
}
};
#if 0
struct EpilogueGLU {
struct Arguments { size_t unused; };
......@@ -510,9 +499,6 @@ public:
}
};
#endif
};
}; // namespace nunchaku::kernels
\ No newline at end of file
}; // namespace nunchaku::kernels
......@@ -2,7 +2,6 @@
#include "gemm_base.cuh"
namespace nunchaku::kernels {
template<typename Config>
......@@ -21,7 +20,7 @@ public:
public:
static constexpr int MAX_RANK = 1024;
static constexpr int WARP_R = 16;
static constexpr int WARP_R = 16;
// static constexpr int LORA_RANK = rank;
static constexpr int LORA_M_TILES = WARP_M / 16;
......@@ -30,57 +29,57 @@ public:
static_assert(LORA_M_TILES == WARP_M_TILES);
static_assert(LORA_N_TILES == WARP_N_TILES);
// lora_down: [WARP_M, WARP_N] x [WARP_N, R] (row-wise) = [WARP_M, R]
// lora up: [WARP_M, R] x [WARP_N, R] (col-wise) = [WARP_M, WARP_N]
// we use fp32 for lora activation since there's no bf16 reduction in sm_89 :(
using lora_act_warp = std::array<packed_f32psum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_act16_warp = std::array<packed_fpsum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_wgt_warp = std::array<packed_fpsum_t, LORA_N_TILES * LORA_R_TILES>;
using lora_wgt_warp = std::array<packed_fpsum_t, LORA_N_TILES * LORA_R_TILES>;
using scale_t = std::array<float, MAX_RANK / 16>;
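// Illustrative sketch, not part of the kernel: a scalar view of the two LoRA products
// described above, ignoring tiling and the warp-level fragment layout. R is the LoRA rank,
// scales follows scale_t (one entry per 16 ranks), and float accumulation mirrors the
// fp32 lora activation noted above.
__device__ __forceinline__ static void
lora_down_ref(const float *fpsum, const float *wgt_down, float *lora_act, int M, int N, int R) {
    for (int m = 0; m < M; m++) {
        for (int r = 0; r < R; r++) {
            float acc = 0.0f;
            for (int n = 0; n < N; n++) {
                acc += fpsum[m * N + n] * wgt_down[n * R + r]; // [M, N] x [N, R] (row-wise)
            }
            lora_act[m * R + r] = acc;
        }
    }
}
__device__ __forceinline__ static void lora_up_ref(
    float *fpsum, const float *lora_act, const float *wgt_up, const float *scales, int M, int N, int R) {
    for (int m = 0; m < M; m++) {
        for (int n = 0; n < N; n++) {
            float acc = 0.0f;
            for (int r = 0; r < R; r++) {
                acc += scales[r / 16] * lora_act[m * R + r] * wgt_up[n * R + r]; // [M, R] x [N, R] (col-wise)
            }
            fpsum[m * N + n] += acc;
        }
    }
}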
// lora_wgt: [N / 16, rank / WARP_R, LORA_R_TILES, WARP_SIZE] of packed_fpsum_t
// [N / 16, rank / 16, WARP_SIZE]
__device__ __forceinline__
static void load_lora_wgt(const packed_fpsum_t *ptr, int rtile, int rank, lora_wgt_warp &result, bool pred) {
__device__ __forceinline__ static void
load_lora_wgt(const packed_fpsum_t *ptr, int rtile, int rank, lora_wgt_warp &result, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
const packed_fpsum_t *ptr_lane = &ptr[rtile * LORA_R_TILES * WARP_SIZE + laneId];
const int stride_ntile = rank / 16 * WARP_SIZE;
const int stride_ntile = rank / 16 * WARP_SIZE;
unrolled_loop<LORA_N_TILES>([&]<int n>() {
unrolled_loop<LORA_R_TILES>([&]<int r>() {
constexpr int roffset = r * WARP_SIZE;
const int noffset = n * stride_ntile;
constexpr int roffset = r * WARP_SIZE;
const int noffset = n * stride_ntile;
result[n * LORA_R_TILES + r] = load_pred(ptr_lane + noffset + roffset, pred);
});
});
}
// lora_act: [M / BLOCK_M, rank / WARP_R, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float
__device__ __forceinline__
static void load_lora_act(const float *ptr, int rtile, lora_act_warp &result, bool pred) {
__device__ __forceinline__ static void
load_lora_act(const float *ptr, int rtile, lora_act_warp &result, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
const float *ptrlane = &ptr[(rtile * NUM_WARPS + warpId) * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE + laneId];
const float *ptrlane =
&ptr[(rtile * NUM_WARPS + warpId) * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE + laneId];
unrolled_loop<LORA_M_TILES>([&]<int m>() {
unrolled_loop<LORA_R_TILES>([&]<int r>{
unrolled_loop<LORA_R_TILES>([&]<int r> {
constexpr int i = m * LORA_R_TILES + r;
unrolled_loop<8>([&]<int j>() {
unrolled_loop<8>([&]<int j>() {
constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
result[i].data[j] = load_pred(ptrlane + offset, pred); // * scales[rtile * LORA_R_TILES + r];
result[i].data[j] = load_pred(ptrlane + offset, pred); // * scales[rtile * LORA_R_TILES + r];
});
// CHECK_NAN(tmp, "load_lora_act.tmp");
});
});
}
// no vector reduction in sm_89 :(
__device__ __forceinline__
static void reduce_lora_act(float *ptr, int rtile, lora_act_warp val, bool pred) {
__device__ __forceinline__ static void reduce_lora_act(float *ptr, int rtile, lora_act_warp val, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
......@@ -108,7 +107,6 @@ public:
// });
// }
struct EpilogueLoraUp {
struct Arguments {
const float *lora_act;
......@@ -120,19 +118,23 @@ public:
bool alwaysfalse;
};
__device__ __forceinline__
static void apply_lora_up(fpsum_warp &fpsum, const float *act, const packed_fpsum_t *wgt, const scale_t &scales, int rank, bool alwaysfalse) {
__device__ __forceinline__ static void apply_lora_up(fpsum_warp &fpsum,
const float *act,
const packed_fpsum_t *wgt,
const scale_t &scales,
int rank,
bool alwaysfalse) {
constexpr int NUM_STAGES = 2;
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
lora_act_warp lora_act[NUM_STAGES]; // 32
lora_wgt_warp lora_wgt[NUM_STAGES]; // 64
lora_act_warp lora_act[NUM_STAGES]; // 32
lora_wgt_warp lora_wgt[NUM_STAGES]; // 64
int dummy = 0;
#pragma unroll
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
// we have rank > 0
const bool pred = k == 0 ? true : k < rank / WARP_R;
......@@ -140,14 +142,14 @@ public:
load_lora_wgt(wgt, 0, rank, lora_wgt[k], pred);
}
f32psum_warp f32psum = packed_fp16_to_fp32(fpsum); // 128
f32psum_warp f32psum = packed_fp16_to_fp32(fpsum); // 128
auto compute = [&scales](lora_act_warp A, lora_wgt_warp W, f32psum_warp &f32psum, int rtile) ALWAYSINLINE {
lora_act16_warp A_fp16;
for (int m = 0; m < LORA_M_TILES; m++) {
for (int r = 0; r < LORA_R_TILES; r++) {
packed_f32psum_t pack = A[m * LORA_R_TILES + r];
#pragma unroll
#pragma unroll
for (int j = 0; j < 8; j++) {
pack.data[j] *= scales[rtile * LORA_R_TILES + r];
}
......@@ -159,28 +161,28 @@ public:
for (int r = 0; r < LORA_R_TILES; r++) {
CHECK_NAN(lora_act[m * LORA_R_TILES + r], "lora_act");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "lora_wgt");
f32psum[m * WARP_N_TILES + n] = mma_f16xf16_f32(A_fp16[m * LORA_R_TILES + r], W[n * LORA_R_TILES + r], f32psum[m * WARP_N_TILES + n]);
f32psum[m * WARP_N_TILES + n] = mma_f16xf16_f32(
A_fp16[m * LORA_R_TILES + r], W[n * LORA_R_TILES + r], f32psum[m * WARP_N_TILES + n]);
}
}
}
};
for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) {
#pragma unroll
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
if (k1 + k2 >= rank / WARP_R) {
break;
}
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < rank / WARP_R;
if (alwaysfalse) {
act += kernels::bit_cast<int>(lora_act[k2][0].data[0]);
}
if (alwaysfalse) {
dummy = clock();
}
......@@ -194,25 +196,24 @@ public:
// NVCC does not know rank > 0 :(
// it will generate a branch instruction to skip the initial load
// the branch splits the basic blocks and prevents the overlap of memory access and computing (packed_fp16_to_fp32)
// add fake dependency of loaded data so NVCC will not skip the load
#pragma unroll
// the branch splits the basic blocks and prevents the overlap of memory access and computing
// (packed_fp16_to_fp32) add fake dependency of loaded data so NVCC will not skip the load
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
#pragma unroll
#pragma unroll
for (auto &&data : lora_act[k]) {
#pragma unroll
#pragma unroll
for (int i = 0; i < 8; i++) {
dummy ^= kernels::bit_cast<int>(data.data[i]);
}
}
#pragma unroll
#pragma unroll
for (auto &&data : lora_wgt[k]) {
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++) {
dummy ^= kernels::bit_cast<int>(data.data[i]);
}
}
}
unused_var(dummy, alwaysfalse);
......@@ -220,21 +221,20 @@ public:
fpsum = packed_fp32_to_fp16(f32psum);
}
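// Illustrative sketch, not part of the kernel: the distilled form of the workaround above.
// Folding every loaded value into a dummy that is only consumed behind a runtime flag
// (always false in practice) gives the loads a use the compiler cannot remove or branch
// around, so the prologue loads stay overlapped with the fp16-to-fp32 conversion.
__device__ __forceinline__ static void keep_loads_alive_ref(const float *data, int n, bool alwaysfalse) {
    int dummy = 0;
    for (int i = 0; i < n; i++) {
        dummy ^= kernels::bit_cast<int>(data[i]); // fake data dependency on each loaded value
    }
    unused_var(dummy, alwaysfalse); // same sink as used above; never observable at runtime
}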
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
__device__ __forceinline__ void
operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
CHECK_NAN(fpsum, "fpsum");
apply_lora_up(
fpsum,
args.lora_act + bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_up + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
args.scales,
args.rank,
args.alwaysfalse
);
apply_lora_up(fpsum,
args.lora_act +
bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_up + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
args.scales,
args.rank,
args.alwaysfalse);
CHECK_NAN(fpsum, "fpsum");
}
......@@ -250,16 +250,16 @@ public:
bool alwaysfalse;
};
__device__ __forceinline__
static void apply_lora_down(fpsum_warp &fpsum, float *act, const packed_fpsum_t *wgt, int rank, bool alwaysfalse) {
__device__ __forceinline__ static void
apply_lora_down(fpsum_warp &fpsum, float *act, const packed_fpsum_t *wgt, int rank, bool alwaysfalse) {
constexpr int NUM_STAGES = 2;
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
lora_wgt_warp lora_wgt[NUM_STAGES]; // 64
lora_wgt_warp lora_wgt[NUM_STAGES]; // 64
#pragma unroll
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
// we have rank > 0
bool pred = k == 0 ? true : k < rank / WARP_R;
......@@ -270,11 +270,11 @@ public:
lora_act_warp lora_act;
lora_act.fill(packed_f32psum_t::zeros());
#pragma unroll
#pragma unroll
for (int m = 0; m < LORA_M_TILES; m++) {
#pragma unroll
#pragma unroll
for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
#pragma unroll
for (int r = 0; r < LORA_R_TILES; r++) {
auto &psum = lora_act[m * LORA_R_TILES + r];
......@@ -294,14 +294,14 @@ public:
int dummy = 0;
for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) {
#pragma unroll
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
if (k1 + k2 >= rank / WARP_R) {
break;
}
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < rank / WARP_R;
if (alwaysfalse) {
......@@ -324,38 +324,33 @@ public:
}
}
#pragma unroll
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
#pragma unroll
#pragma unroll
for (auto &&data : lora_wgt[k]) {
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++) {
dummy ^= kernels::bit_cast<int>(data.data[i]);
}
}
}
unused_var(dummy, alwaysfalse);
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
__device__ __forceinline__ void
operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
apply_lora_down(
fpsum,
args.lora_act + bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_down + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
args.rank,
args.alwaysfalse
);
apply_lora_down(fpsum,
args.lora_act +
bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_down + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
args.rank,
args.alwaysfalse);
}
};
};
}; // namespace nunchaku::kernels
\ No newline at end of file
}; // namespace nunchaku::kernels
......@@ -7,183 +7,169 @@
namespace nunchaku::kernels {
namespace mma_helper {
struct f32 {
static constexpr const char value[] = "f32";
};
struct f16 {
static constexpr const char value[] = "f16";
};
struct bf16 {
static constexpr const char value[] = "bf16";
};
struct s32 {
static constexpr const char value[] = "s32";
};
struct s4 {
static constexpr const char value[] = "s4";
};
struct u4 {
static constexpr const char value[] = "u4";
};
template<bool is_bf16>
using f16bf16 = std::conditional_t<is_bf16, bf16, f16>;
template<bool is_unsigned>
using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
struct f32 {
static constexpr const char value[] = "f32";
};
struct f16 {
static constexpr const char value[] = "f16";
};
struct bf16 {
static constexpr const char value[] = "bf16";
};
struct s32 {
static constexpr const char value[] = "s32";
};
struct s4 {
static constexpr const char value[] = "s4";
};
struct u4 {
static constexpr const char value[] = "u4";
};
template<bool is_bf16>
using f16bf16 = std::conditional_t<is_bf16, bf16, f16>;
template<bool is_unsigned>
using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
}; // namespace mma_helper
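// Illustrative sketch, not part of the kernel: the tag types above carry operand type names
// as strings so they can be spliced into the mma instruction through the "C" constant-string
// constraint, while the conditional aliases pick the tag at compile time.
static_assert(mma_helper::f16bf16<true>::value[0] == 'b', "bf16 tag selected when is_bf16");
static_assert(mma_helper::s4u4<true>::value[0] == 'u', "u4 tag selected for unsigned activations");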
__device__ __forceinline__
static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
__device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
uint2 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%2, %3, %4, %5},"
"{%6, %7},"
"{%8, %9};\n"
:
"=r"(d.x), "=r"(d.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y)
);
asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%2, %3, %4, %5},"
"{%6, %7},"
"{%8, %9};\n"
: "=r"(d.x), "=r"(d.y)
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
#else
asm volatile(
"{"
".reg .b32 tmp0, tmp1;"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{tmp0, tmp1},"
"{%2, %3},"
"{%6},"
"{%8, %9};\n"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%4, %5},"
"{%7},"
"{tmp0, tmp1};"
"}\n"
:
"=r"(d.x), "=r"(d.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y)
);
asm volatile("{"
".reg .b32 tmp0, tmp1;"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{tmp0, tmp1},"
"{%2, %3},"
"{%6},"
"{%8, %9};\n"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%4, %5},"
"{%7},"
"{tmp0, tmp1};"
"}\n"
: "=r"(d.x), "=r"(d.y)
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
#endif
return d;
}
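// Illustrative sketch, not part of the kernel: the scalar semantics of the m16n8k16 MMA
// above. A is a 16x16 row-major tile, B a 16x8 column-major tile, C and D are 16x8
// accumulators; the per-lane fragment layout encoded by uint4/uint2 is omitted here.
__device__ __forceinline__ static void mma_m16n8k16_ref(const float (&A)[16][16],
                                                        const float (&B)[16][8],
                                                        const float (&C)[16][8],
                                                        float (&D)[16][8]) {
    for (int m = 0; m < 16; m++) {
        for (int n = 0; n < 8; n++) {
            float acc = C[m][n];
            for (int k = 0; k < 16; k++) {
                acc += A[m][k] * B[k][n]; // D = A * B + C
            }
            D[m][n] = acc;
        }
    }
}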
template<bool is_bf16>
__device__ __forceinline__
static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) {
__device__ __forceinline__ static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) {
uint4 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.%14.%14.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"C"(mma_helper::f16bf16<is_bf16>::value)
);
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.%14.%14.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x),
"r"(a.y),
"r"(a.z),
"r"(a.w),
"r"(b.x),
"r"(b.y),
"r"(c.x),
"r"(c.y),
"r"(c.z),
"r"(c.w),
"C"(mma_helper::f16bf16<is_bf16>::value));
#else
static_assert(!is_bf16);
asm volatile(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{tmp0, tmp1, tmp2, tmp3},"
"{%4, %5},"
"{%8},"
"{%10, %11, %12, %13};\n"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%6, %7},"
"{%9},"
"{tmp0, tmp1, tmp2, tmp3};"
"}\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
);
asm volatile("{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{tmp0, tmp1, tmp2, tmp3},"
"{%4, %5},"
"{%8},"
"{%10, %11, %12, %13};\n"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%6, %7},"
"{%9},"
"{tmp0, tmp1, tmp2, tmp3};"
"}\n"
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
#endif
return d;
}
template<typename AType, typename BType>
__device__ __forceinline__
static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) {
__device__ __forceinline__ static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) {
uint4 d;
static constexpr int K = (std::is_same_v<AType, mma_helper::s4> || std::is_same_v<AType, mma_helper::u4>) ? 64 : 32;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K),
"C"(AType::value),
"C"(BType::value)
);
asm volatile("mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x),
"r"(a.y),
"r"(a.z),
"r"(a.w),
"r"(b.x),
"r"(b.y),
"r"(c.x),
"r"(c.y),
"r"(c.z),
"r"(c.w),
"n"(K),
"C"(AType::value),
"C"(BType::value));
#else
asm volatile(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{tmp0, tmp1},"
"{%4},"
"{%8},"
"{%10, %11};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{tmp2, tmp3},"
"{%5},"
"{%8},"
"{%12, %13};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{%0, %1},"
"{%6},"
"{%9},"
"{tmp0, tmp1};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{%2, %3},"
"{%7},"
"{%9},"
"{tmp2, tmp3};\n"
"}\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K / 2),
"C"(AType::value),
"C"(BType::value)
);
asm volatile("{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{tmp0, tmp1},"
"{%4},"
"{%8},"
"{%10, %11};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{tmp2, tmp3},"
"{%5},"
"{%8},"
"{%12, %13};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{%0, %1},"
"{%6},"
"{%9},"
"{tmp0, tmp1};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
"{%2, %3},"
"{%7},"
"{%9},"
"{tmp2, tmp3};\n"
"}\n"
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x),
"r"(a.y),
"r"(a.z),
"r"(a.w),
"r"(b.x),
"r"(b.y),
"r"(c.x),
"r"(c.y),
"r"(c.z),
"r"(c.w),
"n"(K / 2),
"C"(AType::value),
"C"(BType::value));
#endif
return d;
}
}; // namespace nunchaku::kernels
\ No newline at end of file
}; // namespace nunchaku::kernels
......@@ -6,156 +6,118 @@
// cuda 12.4- does not support "C" constraint in inline assembly :(
// use explicit specialization for now
namespace nunchaku::kernels {
namespace mma_helper {
struct f32 {
static constexpr const char value[] = "f32";
};
struct f16 {
static constexpr const char value[] = "f16";
};
struct bf16 {
static constexpr const char value[] = "bf16";
};
struct s32 {
static constexpr const char value[] = "s32";
};
struct s4 {
static constexpr const char value[] = "s4";
};
struct u4 {
static constexpr const char value[] = "u4";
};
template<bool is_bf16>
using f16bf16 = std::conditional_t<is_bf16, bf16, f16>;
template<bool is_unsigned>
using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
struct f32 {
static constexpr const char value[] = "f32";
};
struct f16 {
static constexpr const char value[] = "f16";
};
struct bf16 {
static constexpr const char value[] = "bf16";
};
struct s32 {
static constexpr const char value[] = "s32";
};
struct s4 {
static constexpr const char value[] = "s4";
};
struct u4 {
static constexpr const char value[] = "u4";
};
template<bool is_bf16>
using f16bf16 = std::conditional_t<is_bf16, bf16, f16>;
template<bool is_unsigned>
using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
}; // namespace mma_helper
__device__ __forceinline__
static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
__device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
uint2 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%2, %3, %4, %5},"
"{%6, %7},"
"{%8, %9};\n"
:
"=r"(d.x), "=r"(d.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y)
);
asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%2, %3, %4, %5},"
"{%6, %7},"
"{%8, %9};\n"
: "=r"(d.x), "=r"(d.y)
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
#else
asm volatile(
"{"
".reg .b32 tmp0, tmp1;"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{tmp0, tmp1},"
"{%2, %3},"
"{%6},"
"{%8, %9};\n"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%4, %5},"
"{%7},"
"{tmp0, tmp1};"
"}\n"
:
"=r"(d.x), "=r"(d.y)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y)
);
asm volatile("{"
".reg .b32 tmp0, tmp1;"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{tmp0, tmp1},"
"{%2, %3},"
"{%6},"
"{%8, %9};\n"
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0, %1},"
"{%4, %5},"
"{%7},"
"{tmp0, tmp1};"
"}\n"
: "=r"(d.x), "=r"(d.y)
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
#endif
return d;
}
template<bool is_bf16>
__device__ __forceinline__
static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) = delete;
__device__ __forceinline__ static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) = delete;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
__device__ __forceinline__
uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2 b, uint4 c) {
__device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2 b, uint4 c) {
uint4 d;
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
);
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
return d;
}
#endif
template<>
__device__ __forceinline__
uint4 mma_m16n8k16_f32f16f16f32<false>(uint4 a, uint2 b, uint4 c) {
__device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<false>(uint4 a, uint2 b, uint4 c) {
uint4 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
);
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
#else
asm volatile(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{tmp0, tmp1, tmp2, tmp3},"
"{%4, %5},"
"{%8},"
"{%10, %11, %12, %13};\n"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%6, %7},"
"{%9},"
"{tmp0, tmp1, tmp2, tmp3};"
"}\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
);
asm volatile("{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{tmp0, tmp1, tmp2, tmp3},"
"{%4, %5},"
"{%8},"
"{%10, %11, %12, %13};\n"
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%6, %7},"
"{%9},"
"{tmp0, tmp1, tmp2, tmp3};"
"}\n"
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
#endif
return d;
}
template<typename AType, typename BType>
__device__ __forceinline__
static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) = delete;
__device__ __forceinline__ static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) = delete;
template<>
__device__ __forceinline__
uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>(uint4 a, uint2 b, uint4 c) {
__device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>(uint4 a, uint2 b, uint4 c) {
uint4 d;
static constexpr int K = 64;
......@@ -166,54 +128,50 @@ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>(uint4 a, uint2 b, ui
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K)
);
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
#else
asm volatile(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{tmp0, tmp1},"
"{%4},"
"{%8},"
"{%10, %11};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{tmp2, tmp3},"
"{%5},"
"{%8},"
"{%12, %13};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{%0, %1},"
"{%6},"
"{%9},"
"{tmp0, tmp1};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{%2, %3},"
"{%7},"
"{%9},"
"{tmp2, tmp3};\n"
"}\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K / 2)
);
asm volatile("{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{tmp0, tmp1},"
"{%4},"
"{%8},"
"{%10, %11};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{tmp2, tmp3},"
"{%5},"
"{%8},"
"{%12, %13};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{%0, %1},"
"{%6},"
"{%9},"
"{tmp0, tmp1};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
"{%2, %3},"
"{%7},"
"{%9},"
"{tmp2, tmp3};\n"
"}\n"
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x),
"r"(a.y),
"r"(a.z),
"r"(a.w),
"r"(b.x),
"r"(b.y),
"r"(c.x),
"r"(c.y),
"r"(c.z),
"r"(c.w),
"n"(K / 2));
#endif
return d;
}
template<>
__device__ __forceinline__
uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helper::s4>(uint4 a, uint2 b, uint4 c) {
__device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helper::s4>(uint4 a, uint2 b, uint4 c) {
uint4 d;
static constexpr int K = 64;
......@@ -224,50 +182,46 @@ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helper::s4>(uint4 a, uint2 b, ui
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K)
);
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
#else
asm volatile(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{tmp0, tmp1},"
"{%4},"
"{%8},"
"{%10, %11};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{tmp2, tmp3},"
"{%5},"
"{%8},"
"{%12, %13};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{%0, %1},"
"{%6},"
"{%9},"
"{tmp0, tmp1};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{%2, %3},"
"{%7},"
"{%9},"
"{tmp2, tmp3};\n"
"}\n"
:
"=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
"r"(b.x), "r"(b.y),
"r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
"n"(K / 2)
);
asm volatile("{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{tmp0, tmp1},"
"{%4},"
"{%8},"
"{%10, %11};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{tmp2, tmp3},"
"{%5},"
"{%8},"
"{%12, %13};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{%0, %1},"
"{%6},"
"{%9},"
"{tmp0, tmp1};\n"
"mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
"{%2, %3},"
"{%7},"
"{%9},"
"{tmp2, tmp3};\n"
"}\n"
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
: "r"(a.x),
"r"(a.y),
"r"(a.z),
"r"(a.w),
"r"(b.x),
"r"(b.y),
"r"(c.x),
"r"(c.y),
"r"(c.z),
"r"(c.w),
"n"(K / 2));
#endif
return d;
}
}; // namespace nunchaku::kernels
\ No newline at end of file
}; // namespace nunchaku::kernels
......@@ -5,50 +5,55 @@
namespace nunchaku::kernels {
void gemm_w4a4(
Tensor act, // packed act [M, K / 2]
Tensor wgt, // packed act [N, K / 2]
Tensor out, // linear [M, N]
Tensor qout, // packed act [M, N / 2]
Tensor ascales, // packed as [K / 64, M]
Tensor wscales, // packed ws [K / 64, N]
Tensor oscales, // packed as [N / 64, M]
Tensor poolout, // linear [M / PoolSize, N]
Tensor lora_act_in, // packed lora_act [M, R]
Tensor lora_up, // packed lora_wgt [N, R]
Tensor lora_down, // packed lora_wgt [N, R]
Tensor lora_act_out, // packed lora_act [M, R]
Tensor norm_q, // linear [HEAD_DIM]
Tensor norm_k, // linear [HEAD_DIM]
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor bias, // packed ws [N]
Tensor smooth_factor, // packed ws [N], for quantization of the next layer
Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
Tensor out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales,
Tensor out_q, // packed attention [B, H, M, D]
Tensor out_k, // packed attention [B, H, M, D]
Tensor out_v, // packed attention [B, H, M, D]
int attn_tokens
);
void gemm_w4a4(Tensor act, // packed act [M, K / 2]
Tensor wgt, // packed act [N, K / 2]
Tensor out, // linear [M, N]
Tensor qout, // packed act [M, N / 2]
Tensor ascales, // packed as [K / 64, M]
Tensor wscales, // packed ws [K / 64, N]
Tensor oscales, // packed as [N / 64, M]
Tensor poolout, // linear [M / PoolSize, N]
Tensor lora_act_in, // packed lora_act [M, R]
Tensor lora_up, // packed lora_wgt [N, R]
Tensor lora_down, // packed lora_wgt [N, R]
Tensor lora_act_out, // packed lora_act [M, R]
Tensor norm_q, // linear [HEAD_DIM]
Tensor norm_k, // linear [HEAD_DIM]
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor bias, // packed ws [N]
Tensor smooth_factor, // packed ws [N], for quantization of the next layer
Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim]
Tensor out_linearattn, // linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales,
Tensor out_q, // packed attention [B, H, M, D]
Tensor out_k, // packed attention [B, H, M, D]
Tensor out_v, // packed attention [B, H, M, D]
int attn_tokens);
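// Illustrative sketch, not part of the API: the "packed act [M, K / 2]" layout above stores
// two 4-bit values per byte. Which operand sits in the low nibble is an assumption here.
inline uint8_t pack_int4_pair_ref(int lo, int hi) { // lo, hi in [-8, 7]
    return uint8_t((lo & 0xF) | ((hi & 0xF) << 4));
}
inline void unpack_int4_pair_ref(uint8_t byte, int &lo, int &hi) {
    lo = byte & 0xF;
    hi = byte >> 4;
    if (lo >= 8) lo -= 16; // sign-extend the 4-bit values
    if (hi >= 8) hi -= 16;
}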
void linearattn_vk_mul_q(Tensor q, Tensor vk);
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {}, bool fuse_glu = false, bool fp4 = false);
void quantize_w4a4_act_fuse_lora(Tensor input,
Tensor output,
Tensor oscales,
Tensor lora_down,
Tensor lora_act_out,
Tensor smooth = {},
bool fuse_glu = false,
bool fp4 = false);
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
void gemm_w8a8(Tensor act, // [M, K]
Tensor wgt, // [N, K]
Tensor out, // [M, N]
Tensor ascales, // [1, M]
Tensor wscales, // [1, N]
Tensor bias // packed ws [N]
);
void gemm_w8a8(Tensor act, // [M, K]
Tensor wgt, // [N, K]
Tensor out, // [M, N]
Tensor ascales, // [1, M]
Tensor wscales, // [1, N]
Tensor bias // packed ws [N]
);
void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_glu);
......@@ -61,13 +66,11 @@ void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_gl
// Tensor wscales // [1, N]
// );
void attention_fp16(
Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
float scale
);
void attention_fp16(Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
float scale);
// EXPERIMENTAL, for sm_75
void set_faster_i2f_mode(std::string mode);
......@@ -76,4 +79,4 @@ void set_faster_i2f_mode(std::string mode);
void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb);
void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens);
}; // namespace nunchaku::kernels
\ No newline at end of file
}; // namespace nunchaku::kernels
#include "layernorm.h"
#include "kernels/layernorm_kernels.h"
LayerNorm::LayerNorm(int hidden_size, float eps, bool elementwise_affine, Tensor::ScalarType dtype, Device device) :
hidden_size(hidden_size), eps(eps)
{
LayerNorm::LayerNorm(int hidden_size, float eps, bool elementwise_affine, Tensor::ScalarType dtype, Device device)
: hidden_size(hidden_size), eps(eps) {
if (elementwise_affine) {
weight = Tensor::allocate({hidden_size}, dtype, device);
bias = Tensor::allocate({hidden_size}, dtype, device);
bias = Tensor::allocate({hidden_size}, dtype, device);
}
registerParams
(weight, "weight")
(bias, "bias")
;
registerParams(weight, "weight")(bias, "bias");
}
Tensor LayerNorm::forward(Tensor x) {
......@@ -27,10 +23,23 @@ Tensor RMSNorm::forward(Tensor x) {
return out;
}
void RMSNormGeneral::forward_with_act_sum(Tensor x, Tensor quantized_hidden_states_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) {
rms_norm_general_fuse_sum(quantized_hidden_states_buffer, x, this->weight, quantized_sum_buffer, quantized_scale_buffer, variance_epsilon, use_per_token_quant);
void RMSNormGeneral::forward_with_act_sum(Tensor x,
Tensor quantized_hidden_states_buffer,
Tensor quantized_scale_buffer,
Tensor quantized_sum_buffer) {
rms_norm_general_fuse_sum(quantized_hidden_states_buffer,
x,
this->weight,
quantized_sum_buffer,
quantized_scale_buffer,
variance_epsilon,
use_per_token_quant);
}
void RMSNormGeneral::forward_wo_act_sum(Tensor x, Tensor quantized_hidden_states_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) {
rms_norm_general(quantized_hidden_states_buffer, x, this->weight, quantized_scale_buffer, variance_epsilon, use_per_token_quant);
void RMSNormGeneral::forward_wo_act_sum(Tensor x,
Tensor quantized_hidden_states_buffer,
Tensor quantized_scale_buffer,
Tensor quantized_sum_buffer) {
rms_norm_general(
quantized_hidden_states_buffer, x, this->weight, quantized_scale_buffer, variance_epsilon, use_per_token_quant);
}
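// Illustrative sketch, not part of this file: a scalar view of what rms_norm_general
// presumably fuses per token -- RMS normalization followed by per-token quantization.
// The int8 target and max-abs scaling are assumptions, not taken from the kernel.
static void rms_norm_quant_ref(const float *x, const float *weight, int8_t *q, float *scale,
                               int tokens, int hidden, float eps) {
    for (int t = 0; t < tokens; t++) {
        const float *row = x + t * hidden;
        float sumsq = 0.0f;
        for (int i = 0; i < hidden; i++)
            sumsq += row[i] * row[i];
        const float inv_rms = 1.0f / std::sqrt(sumsq / hidden + eps);

        float maxabs = 0.0f;
        for (int i = 0; i < hidden; i++)
            maxabs = std::max(maxabs, std::abs(row[i] * inv_rms * weight[i]));
        scale[t] = maxabs / 127.0f; // one scale per token (use_per_token_quant)

        for (int i = 0; i < hidden; i++) {
            const float y = row[i] * inv_rms * weight[i]; // RMSNorm with learned weight
            q[t * hidden + i] = int8_t(std::lround(y / (scale[t] + 1e-12f)));
        }
    }
}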
......@@ -20,9 +20,8 @@ private:
class RMSNorm : public Module {
public:
RMSNorm(int hidden_size, float eps, bool use_quant, Tensor::ScalarType dtype, Device device) :
use_quant(use_quant), variance_epsilon(eps)
{
RMSNorm(int hidden_size, float eps, bool use_quant, Tensor::ScalarType dtype, Device device)
: use_quant(use_quant), variance_epsilon(eps) {
weight = Tensor::allocate({hidden_size}, dtype, device);
registerParams(weight, "weight");
}
......@@ -36,13 +35,16 @@ public:
class RMSNormGeneral {
friend class LlamaDecoderLayer;
public:
RMSNormGeneral(int hidden_size, bool act_sum, float eps, bool use_per_token_quant, Device device)
: act_sum(act_sum), use_per_token_quant(use_per_token_quant), variance_epsilon(eps)
{
RMSNormGeneral(int hidden_size, bool act_sum, float eps, bool use_per_token_quant, Device device)
: act_sum(act_sum), use_per_token_quant(use_per_token_quant), variance_epsilon(eps) {
this->weight = Tensor::ones({hidden_size}, Tensor::FP32, device);
}
void forward(Tensor x, Tensor quantized_hidden_states_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) {
void forward(Tensor x,
Tensor quantized_hidden_states_buffer,
Tensor quantized_scale_buffer,
Tensor quantized_sum_buffer) {
if (act_sum) {
forward_with_act_sum(x, quantized_hidden_states_buffer, quantized_scale_buffer, quantized_sum_buffer);
} else {
......@@ -51,12 +53,18 @@ public:
}
private:
void forward_with_act_sum(Tensor x, Tensor quantized_hidden_states_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer);
void forward_wo_act_sum(Tensor x, Tensor quantized_hidden_states_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer);
void forward_with_act_sum(Tensor x,
Tensor quantized_hidden_states_buffer,
Tensor quantized_scale_buffer,
Tensor quantized_sum_buffer);
void forward_wo_act_sum(Tensor x,
Tensor quantized_hidden_states_buffer,
Tensor quantized_scale_buffer,
Tensor quantized_sum_buffer);
private:
const bool act_sum;
const bool use_per_token_quant;
const float variance_epsilon;
Tensor weight;
};
\ No newline at end of file
};
......@@ -4,103 +4,106 @@
#include "Tensor.h"
namespace pytorch_compat {
inline void TORCH_CHECK(bool cond, const std::string &msg = "") {
assert (cond);
}
inline void TORCH_CHECK(bool cond, const std::string &msg = "") {
assert(cond);
}
template<typename T>
inline void C10_CUDA_CHECK(T ret) {
return checkCUDA(ret);
}
template<typename T>
inline void C10_CUDA_CHECK(T ret) {
return checkCUDA(ret);
}
namespace at {
using ::Tensor;
constexpr auto kFloat32 = Tensor::FP32;
constexpr auto kFloat = Tensor::FP32;
constexpr auto kFloat16 = Tensor::FP16;
constexpr auto kBFloat16 = Tensor::BF16;
constexpr auto kInt32 = Tensor::INT32;
constexpr auto kInt64 = Tensor::INT64;
struct Generator {
Generator() { throw std::runtime_error("Not implemented"); }
std::mutex mutex_;
};
namespace cuda {
using ::getCurrentDeviceProperties;
struct StreamWrapper {
cudaStream_t st;
cudaStream_t stream() const { return st; }
};
inline StreamWrapper getCurrentCUDAStream() {
return StreamWrapper(::getCurrentCUDAStream());
}
struct CUDAGuard {
int dev;
};
namespace detail {
inline Generator getDefaultCUDAGenerator() {
return Generator();
}
}
}
using CUDAGeneratorImpl = Generator;
template<typename T>
std::unique_ptr<Generator> get_generator_or_default(std::optional<Generator> gen, T gen2) {
throw std::runtime_error("Not implemented");
}
}
namespace at {
using ::Tensor;
namespace torch {
using at::kFloat32;
using at::kFloat;
using at::kFloat16;
using at::kBFloat16;
using at::kInt32;
using at::kInt64;
constexpr Device kCUDA = Device::cuda();
using IntArrayRef = std::vector<int>;
using TensorOptions = Tensor::TensorOptions;
inline Tensor empty_like(const Tensor &tensor) {
return Tensor::empty_like(tensor);
}
inline Tensor empty(TensorShape shape, Tensor::TensorOptions options) {
return Tensor::empty(shape, options.dtype(), options.device());
}
inline Tensor zeros(TensorShape shape, Tensor::TensorOptions options) {
return Tensor::empty(shape, options.dtype(), options.device()).zero_();
}
namespace nn {
namespace functional {
using PadFuncOptions = std::vector<int>;
inline Tensor pad(Tensor x, PadFuncOptions options) {
throw std::runtime_error("Not implemented");
}
}
}
namespace indexing {
constexpr int None = 0;
struct Slice {
int a;
int b;
};
}
constexpr auto kFloat32 = Tensor::FP32;
constexpr auto kFloat = Tensor::FP32;
constexpr auto kFloat16 = Tensor::FP16;
constexpr auto kBFloat16 = Tensor::BF16;
constexpr auto kInt32 = Tensor::INT32;
constexpr auto kInt64 = Tensor::INT64;
struct Generator {
Generator() {
throw std::runtime_error("Not implemented");
}
std::mutex mutex_;
};
namespace cuda {
using ::getCurrentDeviceProperties;
namespace c10 {
using std::optional;
struct StreamWrapper {
cudaStream_t st;
cudaStream_t stream() const {
return st;
}
};
inline StreamWrapper getCurrentCUDAStream() {
return StreamWrapper(::getCurrentCUDAStream());
}
struct CUDAGuard {
int dev;
};
namespace detail {
inline Generator getDefaultCUDAGenerator() {
return Generator();
}
} // namespace detail
} // namespace cuda
using CUDAGeneratorImpl = Generator;
template<typename T>
std::unique_ptr<Generator> get_generator_or_default(std::optional<Generator> gen, T gen2) {
throw std::runtime_error("Not implemented");
}
} // namespace at
namespace torch {
using at::kFloat32;
using at::kFloat;
using at::kFloat16;
using at::kBFloat16;
using at::kInt32;
using at::kInt64;
constexpr Device kCUDA = Device::cuda();
using IntArrayRef = std::vector<int>;
using TensorOptions = Tensor::TensorOptions;
inline Tensor empty_like(const Tensor &tensor) {
return Tensor::empty_like(tensor);
}
inline Tensor empty(TensorShape shape, Tensor::TensorOptions options) {
return Tensor::empty(shape, options.dtype(), options.device());
}
inline Tensor zeros(TensorShape shape, Tensor::TensorOptions options) {
return Tensor::empty(shape, options.dtype(), options.device()).zero_();
}
namespace nn {
namespace functional {
using PadFuncOptions = std::vector<int>;
inline Tensor pad(Tensor x, PadFuncOptions options) {
throw std::runtime_error("Not implemented");
}
} // namespace functional
} // namespace nn
namespace indexing {
constexpr int None = 0;
struct Slice {
int a;
int b;
};
} // namespace indexing
} // namespace torch
namespace c10 {
using std::optional;
}
} // namespace pytorch_compat
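// Illustrative sketch, not part of the header: how the shim above is meant to be consumed.
// Code ported from PyTorch keeps its original calls and resolves them against the in-house
// Tensor; only names declared above are used, and make_buffer_like is a hypothetical example.
using namespace pytorch_compat;

inline Tensor make_buffer_like(const Tensor &x) {
    TORCH_CHECK(true, "a real port would validate shapes/dtypes here"); // maps to assert()
    auto stream = at::cuda::getCurrentCUDAStream().stream();            // unwraps the in-house stream
    (void)stream;                                                       // a real port would launch work on it
    return torch::empty_like(x);                                        // allocates via Tensor::empty_like
}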