Commit 54e6d065 authored by muyangli

[major] support NVFP4; upgrade to 0.1

parent c7f41661
#pragma once
#include "gemm_base.cuh"
// #include "gemm_w4a4_block.cuh"
namespace nunchaku::kernels {
@@ -19,6 +20,369 @@ class GEMM_W4A4<GEMMConfig_W4A4_FP16> : public GEMMBase<GEMMConfig_W4A4_FP16> {
public:
IMPORT_GEMM_BASE(Config);
public:
// micro-scales for FP4 MMA
// each uint32_t is a 4*32 matrix of scales (for MMA of 64*32)
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200
static constexpr bool FP4_AVAILABLE = true;
#else
static constexpr bool FP4_AVAILABLE = false;
#endif
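// __CUDA_ARCH__ >= 1200 corresponds to compute capability 12.0 (Blackwell-class
// consumer GPUs, e.g. sm_120). On older architectures the FP4 kernels below still
// compile, but trap at runtime via trap_no_fp4() instead of issuing FP4 MMAs.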
__device__ __forceinline__
static void trap_no_fp4() {
if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
printf("FP4 is not available on this device\n");
}
__syncthreads();
__nanosleep(1000000);
__trap();
}
static_assert(WARP_N % 32 == 0);
static_assert(WARP_M % 32 == 0);
static constexpr int WMSCALES_PACK_SIZE = clamp(WARP_N / 32, 1, 4);
static constexpr int WMSCALES_NUM_PACKS = ceilDiv(WARP_N / 32, WMSCALES_PACK_SIZE);
static constexpr int WMSCALES_VALID_LANES = WARP_SIZE;
static constexpr int AMSCALES_PACK_SIZE = clamp(WARP_M / 32, 1, 4);
static constexpr int AMSCALES_NUM_PACKS = ceilDiv(WARP_M / 32, AMSCALES_PACK_SIZE);
static constexpr int AMSCALES_VALID_LANES = WARP_SIZE;
struct packed_wmscale_t {
uint32_t data[WMSCALES_PACK_SIZE];
};
struct packed_amscale_t {
uint32_t data[AMSCALES_PACK_SIZE];
};
using amscale_warp = std::array<packed_amscale_t, AMSCALES_NUM_PACKS>;
using wmscale_warp = std::array<packed_wmscale_t, WMSCALES_NUM_PACKS>;
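// Worked example of the packing arithmetic above (illustrative only; the real
// WARP_M / WARP_N come from the GEMM config, which is not part of this hunk).
// Assuming WARP_M = 128 and WARP_N = 128:
//   AMSCALES_PACK_SIZE = clamp(128 / 32, 1, 4) = 4  -> 4 uint32_t per packed_amscale_t
//   AMSCALES_NUM_PACKS = ceilDiv(4, 4)         = 1
//   WMSCALES_PACK_SIZE = clamp(128 / 32, 1, 4) = 4  -> 4 uint32_t per packed_wmscale_t
//   WMSCALES_NUM_PACKS = ceilDiv(4, 4)         = 1
// Each uint32_t packs four ue4m3 scale bytes, so under this assumption one packed
// struct per lane (32 lanes) covers the micro-scales of the whole warp tile for
// one 64-wide K group.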
// amscales: [M / BLOCK_M, K / group size, NUM_WARPS, AMSCALES_NUM_PACKS, WARP_SIZE] of packed_amscale_t
__device__ __forceinline__
static void load_amscale(const packed_amscale_t *ptr, int group, amscale_warp &out, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
for (int i = 0; i < AMSCALES_NUM_PACKS; i++) {
out[i] = load_pred(&ptr[(group * NUM_WARPS + warpId) * AMSCALES_NUM_PACKS * AMSCALES_VALID_LANES + i * AMSCALES_VALID_LANES + laneId], pred);
}
}
// wmscales: [N / BLOCK_N, 1, K / group size, WMSCALES_NUM_PACKS, WMSCALES_VALID_LANES] of packed_wmscale_t
__device__ __forceinline__
static void load_wmscale(const packed_wmscale_t *ptr, int group, wmscale_warp &out, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
#pragma unroll
for (int i = 0; i < WMSCALES_NUM_PACKS; i++) {
out[i] = load_pred(&ptr[(group * WMSCALES_NUM_PACKS + i) * WMSCALES_VALID_LANES + laneId], pred);
}
}
__device__ __forceinline__
static void quantize_w4a4_fp4_from_fpsum_warp(const packed_fpsum_t (&fpsum)[INSN_K / INSN_N], packed_act_t &output, uint32_t &output_scale, int ida) {
constexpr int NUM_GROUPS = 4;
static_assert(NUM_GROUPS == INSN_K / INSN_N);
constexpr float QVALUE_MAX = 6.0f;
constexpr float RECPI_QVALUE_MAX = 1 / QVALUE_MAX;
constexpr float MSCALE_MAX = 448.0f;
const int laneId = threadIdx.x % WARP_SIZE;
// 0 for row 0-7; 1 for row 8-15
// each half2_t represents a 8*8 matrix
half2_t input[2][INSN_K / INSN_N * 2];
#pragma unroll
for (int i = 0; i < INSN_K / INSN_N; i++) {
input[0][i * 2 + 0] = fpsum[i].data[0];
input[0][i * 2 + 1] = fpsum[i].data[2];
input[1][i * 2 + 0] = fpsum[i].data[1];
input[1][i * 2 + 1] = fpsum[i].data[3];
}
auto maxabs = [](half2_t val) ALWAYSINLINE {
val = __habs2(val);
return __hmax(val.x, val.y);
};
// each half_t represents maxvalue in a 8*16 matrix
half_t maxvalue[2][NUM_GROUPS];
#pragma unroll
for (int i = 0; i < NUM_GROUPS; i++) {
maxvalue[0][i] = __hmax(maxabs(input[0][i * 2]), maxabs(input[0][i * 2 + 1]));
maxvalue[1][i] = __hmax(maxabs(input[1][i * 2]), maxabs(input[1][i * 2 + 1]));
}
#pragma unroll
for (int mask = 2; mask > 0; mask /= 2) {
#pragma unroll
for (int i = 0; i < NUM_GROUPS; i++) {
maxvalue[0][i] = __hmax(maxvalue[0][i], __shfl_xor_sync(~0, maxvalue[0][i], mask));
maxvalue[1][i] = __hmax(maxvalue[1][i], __shfl_xor_sync(~0, maxvalue[1][i], mask));
}
}
// lane 0,1,2,3 / 4,5,6,7 / ... should have identical maxvalue now
float scale[2][NUM_GROUPS];
float rscale[2][NUM_GROUPS];
#pragma unroll
for (int i = 0; i < NUM_GROUPS; i++) {
scale[0][i] = fminf(float(maxvalue[0][i]) * RECPI_QVALUE_MAX, MSCALE_MAX);
scale[1][i] = fminf(float(maxvalue[1][i]) * RECPI_QVALUE_MAX, MSCALE_MAX);
// TODO: check whether (1 / scale) or (1 / fp8scale) is better
rscale[0][i] = cuda_frcp(scale[0][i]);
rscale[1][i] = cuda_frcp(scale[1][i]);
}
uint32_t fp8scale[2];
fp8scale[0] = quantize_float4_fp8(make_float4(scale[0][0], scale[0][1], scale[0][2], scale[0][3]));
fp8scale[1] = quantize_float4_fp8(make_float4(scale[1][0], scale[1][1], scale[1][2], scale[1][3]));
/**
* output_scale pack format: (ida=0)
* lane 0 => row 0 if ida==0
* lane 1 => row 8 if ida==0
* lane 2 => row 0 if ida==1
* lane 3 => row 8 if ida==1
* ...
* lane i => quad (i/4) => row (i/4+8*(i%2)) if (i%4/2==ida) => srclane i, index i%2
*/
if (laneId % 4 / 2 == ida) {
output_scale = (laneId % 2 == 0) ? fp8scale[0] : fp8scale[1];
}
uint32_t qpacks[2][INSN_K / INSN_M * 2];
#pragma unroll
for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
for (int j = 0; j < 2; j++) {
float2 fval = half22float2(input[j][i]) * make_float2(rscale[j][i / 2], rscale[j][i / 2]);
qpacks[j][i] = quantize_float2_fp4(fval) << (laneId % 4 * 8);
}
}
#pragma unroll
for (int mask = 1; mask <= 2; mask *= 2) {
#pragma unroll
for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
for (int j = 0; j < 2; j++) {
qpacks[j][i] |= __shfl_xor_sync(~0, qpacks[j][i], mask);
}
}
}
// lane 0,1,2,3 / 4,5,6,7 / ... should have identical qpacks now
#pragma unroll
for (int i = 0; i < 4; i++) {
if (laneId % 4 == i) {
output.x = qpacks[0][0 + i];
output.y = qpacks[1][0 + i];
output.z = qpacks[0][4 + i];
output.w = qpacks[1][4 + i];
}
}
}
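// Scalar reference for the per-group quantization above: a host-side sketch,
// not used by the kernels, that mirrors the math (group max -> scale clamped to
// 448 -> e2m1 values spanning [-6, 6]). Hardware rounding may differ slightly;
// the function name and signature are illustrative only.
static float ref_nvfp4_quantize_group(const float *x, int n, float *q_dequant) {
    auto absf = [](float v) { return v < 0.0f ? -v : v; };
    float maxabs = 0.0f;
    for (int i = 0; i < n; i++) maxabs = maxabs > absf(x[i]) ? maxabs : absf(x[i]);
    float scale = maxabs / 6.0f;                        // QVALUE_MAX = 6
    if (scale > 448.0f) scale = 448.0f;                 // MSCALE_MAX = 448
    float rscale = (scale > 0.0f) ? 1.0f / scale : 0.0f;
    const float e2m1[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
    for (int i = 0; i < n; i++) {
        float v = x[i] * rscale;
        float a = absf(v), best = 0.0f;
        for (int j = 0; j < 8; j++)                     // nearest representable e2m1 magnitude
            if (absf(a - e2m1[j]) < absf(a - best)) best = e2m1[j];
        q_dequant[i] = (v < 0.0f ? -best : best) * scale;   // dequantized value
    }
    return scale;                                       // the kernel stores this as ue4m3
}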
// m16n16k64 MMA
// ida, idb in {0, 1}
__device__ __forceinline__
static packed_f32psum_t mma_fp4(packed_act_t act, packed_wgt_t wgt, packed_f32psum_t psum, uint32_t amscale, uint32_t wmscale, int ida, int idb) {
packed_f32psum_t out;
asm volatile (
"mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%10, %11, %12, %13}, "
"{%14}, {%15, %16}, "
"{%17}, {%18, %19};"
:
"=f"(out.data[0]), "=f"(out.data[1]), "=f"(out.data[2]), "=f"(out.data[3])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.x), "r"(wgt.y),
"f"(psum.data[0]), "f"(psum.data[1]), "f"(psum.data[2]), "f"(psum.data[3]),
"r"(amscale), "n"(0), "h"((short)ida),
"r"(wmscale), "n"(0), "h"((short)(idb * 2))
);
asm volatile (
"mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%10, %11, %12, %13}, "
"{%14}, {%15, %16}, "
"{%17}, {%18, %19};"
:
"=f"(out.data[4]), "=f"(out.data[5]), "=f"(out.data[6]), "=f"(out.data[7])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.z), "r"(wgt.w),
"f"(psum.data[4]), "f"(psum.data[5]), "f"(psum.data[6]), "f"(psum.data[7]),
"r"(amscale), "n"(0), "h"((short)ida),
"r"(wmscale), "n"(0), "h"((short)(idb * 2 + 1))
);
return out;
}
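// Scalar reference for one output element of the block-scaled MMA above
// (illustrative only; operand packing across lanes is omitted). With
// scale_vec::4X and K = 64, every 16 consecutive K elements of an A row /
// B column share one ue4m3 scale factor.
static float ref_block_scaled_dot(const float a_fp4[64], const float a_sf[4],
                                  const float b_fp4[64], const float b_sf[4],
                                  float c) {
    float acc = c;
    for (int k = 0; k < 64; k++) {
        acc += (a_fp4[k] * a_sf[k / 16]) * (b_fp4[k] * b_sf[k / 16]);
    }
    return acc;
}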
__device__ __forceinline__
static void compute_fp4(act_warp A, wgt_warp W, amscale_warp amscale, wmscale_warp wmscale, f32psum_warp &psum) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
psum[i * WARP_N_TILES + j] = mma_fp4(
A[i], W[j], psum[i * WARP_N_TILES + j],
amscale[i / 2 / AMSCALES_PACK_SIZE].data[i / 2 % AMSCALES_PACK_SIZE],
wmscale[j / 2 / WMSCALES_PACK_SIZE].data[j / 2 % WMSCALES_PACK_SIZE],
i % 2, j % 2
);
}
}
}
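// Index sketch for the scale selection above, assuming WARP_M = 128 (eight
// 16-row tiles) so AMSCALES_PACK_SIZE = 4 and AMSCALES_NUM_PACKS = 1:
//   i / 2              -> which packed uint32_t of scales (one per 32 rows)
//   (i / 2) / 4, % 4   -> pack index 0, slot 0..3 inside packed_amscale_t
//   i % 2              -> ida operand: which of the two 16-row MMA tiles
//                         sharing that scale word is being computed
// The same pattern applies to wmscale with j and idb on the N dimension.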
template<typename Epilogue, bool USE_ALPHA>
__device__ __forceinline__
static void gemm_w4a4_fp4_block(
const BlockInfo binfo,
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_amscale_t *ascales,
const packed_wmscale_t *wscales,
float alpha, // per-tensor scale of weight
int M, int N, int K,
Epilogue::Arguments epilogueArgs,
bool alwaysfalse)
{
constexpr int NUM_STAGES = 2;
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
act_warp A[NUM_STAGES]; // 8 * 2
wgt_warp W[NUM_STAGES]; // 32 * 2
amscale_warp amscale[NUM_STAGES]; // 1 * 2
wmscale_warp wmscale[NUM_STAGES]; // 4 * 2
f32psum_warp fpsum; // 128
for (int k = 0; k < NUM_STAGES - 1; k++) {
load_act(act, k, K, A[k], true);
load_wgt(wgt, k, K, W[k], true);
load_amscale(ascales, k, amscale[k], true);
load_wmscale(wscales, k, wmscale[k], true);
}
#pragma unroll
for (auto &pack : fpsum) {
#pragma unroll
for (int i = 0; i < 8; i++) {
pack.data[i] = 0;
}
}
int dummy = 0;
for (int k1 = 0; k1 < K / WARP_K; k1 += NUM_STAGES) {
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < K / WARP_K;
load_act(act, nextk, K, A[idx], pred);
load_wgt(wgt, nextk, K, W[idx], pred);
load_amscale(ascales, nextk, amscale[idx], pred);
load_wmscale(wscales, nextk, wmscale[idx], pred);
// __syncthreads();
// if (alwaysfalse) {
// dummy = clock();
// }
compute_fp4(A[k2], W[k2], amscale[k2], wmscale[k2], fpsum);
if (alwaysfalse) {
dummy = clock();
}
// asm volatile ("membar.cta;");
}
}
unused_var(dummy, alwaysfalse);
if constexpr (USE_ALPHA) {
#pragma unroll
for (auto &pack : fpsum) {
#pragma unroll
for (int i = 0; i < 8; i++) {
pack.data[i] *= alpha;
}
}
}
auto f16psum = packed_fp32_to_fp16(fpsum);
CHECK_NAN(f16psum, "f16psum");
Epilogue()(binfo, f16psum, M, N, K, epilogueArgs);
}
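// Pipeline schedule of the main loop above (NUM_STAGES = 2): the predicated
// loads for chunk k+1 are issued before the MMAs for chunk k, so global-memory
// latency overlaps with tensor-core work.
//   prologue : load(0)
//   iteration: load(k + 1, predicated) -> compute_fp4(k)
//   epilogue : optional alpha scaling -> f32 -> f16 -> Epilogue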
template<typename Epilogue, bool USE_ALPHA>
struct gemm_w4a4_fp4_kernel {
__device__
void operator()(
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_amscale_t *ascales,
const packed_wmscale_t *wscales,
float alpha,
int M, int N, int K,
Epilogue::Arguments epilogueArgs,
bool swapBlockXY,
bool alwaysfalse)
{
BlockInfo binfo = {
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y,
};
if (swapBlockXY) {
std::swap(binfo.bm, binfo.bn);
std::swap(binfo.numBlocksM, binfo.numBlocksN);
}
const int bm = binfo.bm;
const int bn = binfo.bn;
if constexpr (FP4_AVAILABLE) {
gemm_w4a4_fp4_block<Epilogue, USE_ALPHA>(
binfo,
act + bm * (K / WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
wgt + bn * (K / WARP_K) * WARP_N_TILES * WARP_SIZE,
ascales + bm * (K / WARP_K) * NUM_WARPS * AMSCALES_NUM_PACKS * AMSCALES_VALID_LANES,
wscales + bn * (K / WARP_K) * WMSCALES_NUM_PACKS * WMSCALES_VALID_LANES,
alpha,
M, N, K,
epilogueArgs,
alwaysfalse
);
} else {
trap_no_fp4();
}
}
};
public:
template<bool ACT_UNSIGNED>
__device__ __forceinline__
@@ -416,7 +780,7 @@ public:
template<bool ACT_UNSIGNED, typename T>
__device__ __forceinline__
static void compute(act_warp A, wgt_warp W, ascale_warp ascale, wscale_warp wscale, T &fpsum) {
apply_scales<true>([&](int i, int j) {
return mma<ACT_UNSIGNED>(A[i], W[j]);
}, ascale, wscale, fpsum);
}
@@ -530,6 +894,10 @@ public:
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
#if 0
fpsum_warp fpsum;
GEMM_W4A4_Block<Config>()(act, wgt, ascales, wscales, K, fpsum, alwaysfalse);
#else
act_warp A[NUM_STAGES]; // 8
wgt_warp W[NUM_STAGES]; // 32
ascale_warp ascale[NUM_STAGES]; // 1
@@ -591,6 +959,8 @@ public:
unused_var(dummy, alwaysfalse);
#endif
#if 0
auto f16psum = packed_fp32_to_fp16(fpsum);
#else
@@ -602,11 +972,13 @@ public:
Epilogue()(binfo, f16psum, M, N, K, epilogueArgs);
}
template<bool FUSE_GELU, bool USE_UNSIGNED, bool USE_FP4>
struct EpilogueQuantize {
using oscales_t = typename std::conditional_t<USE_FP4, packed_amscale_t, packed_ascale_t>;
struct Arguments {
packed_act_t *qout;
oscales_t *oscales;
half_t shift_value;
const packed_wscale_t *smooth_factor;
@@ -616,7 +988,7 @@ public:
static constexpr int NUM_GROUPS = WARP_N_TILES / NUM_PACKS;
__device__ __forceinline__
void apply_quantize(fpsum_warp fpsum, int M, int N, int K, packed_act_t *qout, oscales_t *oscales, half_t shift_value, const packed_wscale_t *smooth_factor) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
@@ -627,6 +999,8 @@ public:
#pragma unroll
for (int group = 0; group < NUM_GROUPS; group++) {
amscale_warp omscale;
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
packed_fpsum_t tmp[NUM_PACKS];
@@ -652,15 +1026,6 @@ public:
// dst = src;
}
// auto h2div = [](half2_t a, half2_t b) ALWAYSINLINE {
// float2 af = half22float2(a);
// float2 bf = half22float2(b);
// float2 of;
// of.x = __fdividef(af.x, bf.x);
// of.y = __fdividef(af.y, bf.y);
// return float22half2<half2_t>(of);
// };
tmp[j].data[0] = h2div(tmp[j].data[0], ws1);
tmp[j].data[1] = h2div(tmp[j].data[1], ws1);
tmp[j].data[2] = h2div(tmp[j].data[2], ws2);
@@ -668,13 +1033,26 @@ public:
}
packed_act_t qresult;
if constexpr (USE_FP4) {
quantize_w4a4_fp4_from_fpsum_warp(tmp, qresult, omscale[i / 2 / AMSCALES_PACK_SIZE].data[i / 2 % AMSCALES_PACK_SIZE], i % 2);
} else {
quantize_w4a4_from_fpsum_warp<USE_UNSIGNED>(tmp, qresult, &oscale_shmem[warpId][i * INSN_M]);
}
store(&qout[((group * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId], qresult);
}
if constexpr (USE_FP4) {
#pragma unroll
for (int k = 0; k < AMSCALES_NUM_PACKS; k++) {
store(&oscales[((group * NUM_WARPS + warpId) * AMSCALES_NUM_PACKS + k) * AMSCALES_VALID_LANES + laneId], omscale[k]);
}
}
if constexpr (!USE_FP4) {
__syncwarp();
pack_ascales(&oscale_shmem[warpId][0], &oscales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
__syncwarp();
}
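// For NVFP4 the packed per-lane micro-scales are written to global memory
// directly from registers, whereas the INT4 path still stages per-row scales
// in shared memory and repacks them via pack_ascales().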
}
}
@@ -683,13 +1061,18 @@ public:
const int bm = binfo.bm;
const int bn = binfo.bn;
if constexpr (!USE_FP4 || FP4_AVAILABLE) {
apply_quantize(
fpsum, M, N, K,
args.qout + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
args.oscales + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS *
(USE_FP4 ? AMSCALES_NUM_PACKS * AMSCALES_VALID_LANES : ASCALES_NUM_PACKS * ASCALES_VALID_LANES),
args.shift_value,
args.smooth_factor + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES
);
} else {
trap_no_fp4();
}
}
};
// using EpilogueQuantizeFuseGelu = EpilogueQuantize<true>;
@@ -937,8 +1320,10 @@ public:
}
};
template<bool fuse_glu, bool use_fp4>
struct quantize_w4a4_fuse_lora_kernel {
using oscales_t = typename std::conditional_t<use_fp4, packed_amscale_t, packed_ascale_t>;
static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(load_act_to_fpsum<fuse_glu>::SHMEM_SIZE, 128) * 128;
static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;
@@ -946,7 +1331,7 @@ public:
const half_t *input;
const packed_wscale_t *smooth_factor;
packed_act_t *output;
oscales_t *oscales;
const packed_fpsum_t *lora_wgt_down;
float *lora_act;
@@ -999,7 +1384,7 @@ public:
.lora_act = args.lora_act,
});
EpilogueQuantize<false, false, use_fp4>()(binfo, fpsum, args.M, args.N, 0, typename EpilogueQuantize<false, false, use_fp4>::Arguments{
.qout = args.output,
.oscales = args.oscales,
.shift_value = 0,
@@ -1488,7 +1873,6 @@ public:
);
}
};
};
}; // namespace nunchaku::kernels
\ No newline at end of file
@@ -12,6 +12,8 @@ class GEMM_W4A4_Launch {
using packed_wgt_t = typename GEMM::packed_wgt_t;
using packed_ascale_t = typename GEMM::packed_ascale_t;
using packed_wscale_t = typename GEMM::packed_wscale_t;
using packed_amscale_t = typename GEMM::packed_amscale_t;
using packed_wmscale_t = typename GEMM::packed_wmscale_t;
using packed_fpsum_t = typename GEMM::packed_fpsum_t;
using half_t = typename GEMM::half_t;
@@ -38,9 +40,12 @@ public:
Tensor out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales // packed ws [N]
);
static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4);
static void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
static void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
......
@@ -30,7 +30,10 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
Tensor out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales // packed ws [N]
) {
int M = act.numel() / act.shape[-1];
int N = wgt.shape[0];
@@ -68,58 +71,111 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
std::swap(grid.x, grid.y);
}
dispatchBool(fp4, [&]<bool USE_FP4>() {
// test_sizeof<typename Epilogue::Arguments>();
// std::apply([](auto ...args) {
// (test_sizeof<decltype(args)>(), ...);
// }, args);
// constexpr bool FP4_AVAILABLE = __CUDA_ARCH__ >= 1200;
if constexpr (!USE_FP4) {
dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>,
const packed_act_t *,
const packed_wgt_t *,
const packed_ascale_t *,
const packed_wscale_t *,
int, int, int,
typename Epilogue::Arguments,
bool,
bool>;
if (shmem >= 24 * 1024) {
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
}
assert(alpha == 1.0f);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(
act.data_ptr<packed_act_t>(),
wgt.data_ptr<packed_wgt_t>(),
ascales.data_ptr<packed_ascale_t>(),
wscales.data_ptr<packed_wscale_t>(),
M, N, K,
args,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
});
return;
}
if constexpr (USE_FP4) {
dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() {
assert(!act_unsigned);
auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>,
const packed_act_t *,
const packed_wgt_t *,
const packed_amscale_t *,
const packed_wmscale_t *,
float,
int, int, int,
typename Epilogue::Arguments,
bool,
bool>;
if (shmem >= 24 * 1024) {
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
}
assert(ascales.dtype() == Tensor::FP8_E4M3);
assert(wscales.dtype() == Tensor::FP8_E4M3);
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(
act.data_ptr<packed_act_t>(),
wgt.data_ptr<packed_wgt_t>(),
ascales.data_ptr<packed_amscale_t>(),
wscales.data_ptr<packed_wmscale_t>(),
alpha,
M, N, K,
args,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
});
return;
} }
// if constexpr (USE_FP4 && !FP4_AVAILABLE) {
// throw std::runtime_error("FP4 kernel is not available");
// }
});
};
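// Dispatch summary for the launch lambda above: fp4 == false keeps the original
// INT4 W4A4 path (further specialized on act_unsigned; alpha is asserted to be 1),
// while fp4 == true instantiates gemm_w4a4_fp4_kernel, specialized on whether the
// per-tensor alpha is needed (alpha != 1.0f) and asserting act_unsigned is false.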
auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
assert(!bias.valid() || bias.numel() == N);
assert(!wcscales.valid() || wcscales.numel() == N);

dispatchBool(bias.valid(), [&]<bool USE_BIAS>() {
dispatchBool(wcscales.valid(), [&]<bool USE_SCALE>() {
using EpilogueBias = typename GEMM::EpilogueBias<USE_BIAS, USE_SCALE>;
// append EpilogueNop to work around mismatched memory layout of std::tuple between device and host code on Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using Epilogue = typename GEMM::EpilogueCombination<EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
return launch.template operator()<Epilogue>({
typename EpilogueBias::Arguments{
.bias = USE_BIAS ? bias.data_ptr<packed_wscale_t>() : nullptr,
.scale = USE_SCALE ? wcscales.data_ptr<packed_wscale_t>() : nullptr,
},
nextArgs,
{}
});
});
});
};
// auto launch_bias = launch;
@@ -206,29 +262,32 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
static constexpr float SHIFT_GELU = 0.171875f;
dispatchBool(fp4, [&]<bool USE_FP4>() {
constexpr bool USE_UNSIGNED = !USE_FP4;
using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
auto argsQuantize = typename EpilogueQuantize::Arguments{
.qout = qout.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
.shift_value = USE_FP4 ? 0.0f : SHIFT_GELU,
.smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
};
// TODO: check if gelu is needed
if (out.valid()) {
launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
typename GEMM::EpilogueDefault::Arguments{
.out = out.data_ptr<half_t>(),
.actualM = actualM,
.actualN = actualN,
},
argsQuantize
}, {});
} else {
launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
}
});
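// Note on the arguments above: the INT4 path (USE_UNSIGNED) quantizes the GELU
// output as unsigned and re-centers it with SHIFT_GELU, while the NVFP4 path
// keeps the signed e2m1 range and therefore passes shift_value = 0.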
} else if (out_linearattn.valid()) {
assert(out_vk.valid());
@@ -326,7 +385,7 @@ void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
}
template<typename Config>
void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
const int actualM = input.numel() / input.shape[-1];
const int actualN = input.shape[-1];
@@ -338,8 +397,13 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor
assert(output.shape[-1] == N / 2);
// assert(oscales.dtype() == Tensor::FP16);
if (fp4) {
assert(oscales.dtype() == Tensor::FP8_E4M3);
assert(oscales.numel() == M * N / GEMM::WARP_K * 4);
} else {
assert(isTypeMatch<half_t>(oscales.dtype()));
assert(oscales.numel() == M * N / GEMM::WARP_K);
}
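// Size rationale (assuming GEMM::WARP_K == 64): INT4 activation scales are one
// half_t per 64-wide K group, while NVFP4 micro-scales are one FP8 (ue4m3) value
// per 16-wide group, hence 4x as many scale elements for the same M x N.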
const int rank = lora_down.shape[1];
@@ -354,30 +418,32 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor
dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
dispatchBool(fp4, [&]<bool USE_FP4>() {
using Lora = typename GEMM::Lora<RANK>;
using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE>>>(
typename kernel::Arguments{
.input = input.data_ptr<half_t>(),
.smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
.output = output.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename kernel::oscales_t>(),
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
.M = M,
.N = N,
.actualM = actualM,
.actualN = actualN,
}
);
checkCUDA(cudaGetLastError());
});
});
});
}
......
@@ -100,9 +100,9 @@ void gemm_w8a8(Tensor act, // [M, K]
// append EpilogueNop to work around mismatched memory layout of std::tuple between device and host code on Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias<true, false>, NextEpilogue, GEMM::EpilogueNop>;
return launch.template operator()<Epilogue>({
GEMM::EpilogueBias<true, false>::Arguments{
.bias = bias.data_ptr<GEMM::packed_wscale_t>(),
},
nextArgs,
......
@@ -27,11 +27,14 @@ void gemm_w4a4(
Tensor out_linearattn,// linear [B, (M), N / 3]
bool act_unsigned,
std::vector<float> lora_scales, // [R / 16]
bool fuse_silu,
bool fp4,
float alpha,
Tensor wcscales
);
void linearattn_vk_mul_q(Tensor q, Tensor vk);
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {}, bool fuse_glu = false, bool fp4 = false);
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
......