Commit e05256c8 authored by sxtyzhangzk's avatar sxtyzhangzk Committed by Zhekai Zhang
Browse files

Flexible lora ranks

parent 618f3078
#pragma once
#include "gemm_base.cuh"
namespace nunchaku::kernels {
template<typename Config>
class Epilogues;
#ifndef __INTELLISENSE__
template<typename Config>
class Epilogues : public GEMMBase<Config> {
#else
template<>
class Epilogues<GEMMConfig_W4A4_FP16> : public GEMMBase<GEMMConfig_W4A4_FP16> {
using Config = GEMMConfig_W4A4_FP16;
#endif
public:
IMPORT_GEMM_BASE(Config);
public:
struct EpilogueGelu {
struct Arguments { size_t unused; };
// static constexpr float SHIFT_VALUE = 0.171875f;
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll
for (int k = 0; k < 4; k++) {
half2_t &data = fpsum[i * WARP_N_TILES + j].data[k];
data = gelu_half2(data);
// data = __hadd2(data, half2_t(SHIFT_VALUE, SHIFT_VALUE));
}
}
}
}
};
// template<int PoolSize = 128>
struct EpilogueQKVProj {
struct Arguments {
half_t *out;
int actualM, actualN;
half_t *pool_out; // [M / PoolSize, N]
const float *rotary_emb; // [M, HEAD_DIM / 2, ROTARY_EMB_NUM_ELEMENTS]
const half_t *rmsnorm_weight_q; // [HEAD_DIM]
const half_t *rmsnorm_weight_k; // [HEAD_DIM]
float epsilon;
};
static constexpr int HEAD_DIM = 128;
static constexpr int NUM_HEADS_PER_WARP = WARP_N / HEAD_DIM;
static constexpr int PoolSize = 128;
static constexpr int NUM_WARPS_PER_POOL = PoolSize / WARP_M;
static constexpr int NUM_POOLS_PER_BLOCK = BLOCK_M / PoolSize;
static constexpr int ROTARY_EMB_NUM_ELEMENTS = 2; // 1 for theta, 2 for {sin, cos} pair
__device__ __forceinline__
static void apply(fpsum_warp fpsum, half_t *out, int M, int N, int K, half_t *pool_out, const float *rotary_emb, const half_t *rmsnorm_weight, float epsilon, int maxRows) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
constexpr int PACK_SIZE = unpack_fpsum::PACK_SIZE;
using pack_t = unpack_fpsum::pack_t;
using pack_rope_t = std::array<float, PACK_SIZE / 2 * ROTARY_EMB_NUM_ELEMENTS>;
constexpr int LANES_PER_HEAD = HEAD_DIM / PACK_SIZE;
pack_t reduce_tmp;
__shared__ alignas(128) pack_t pool[NUM_WARPS];
// load rmsnorm scales
pack_t rms;
if (laneId < LANES_PER_HEAD) {
rms = load(reinterpret_cast<const pack_t *>(&rmsnorm_weight[laneId * PACK_SIZE]));
}
if constexpr (LANES_PER_HEAD < WARP_SIZE) {
for (int i = 0; i < PACK_SIZE; i++) {
rms[i] = __shfl_sync(~0, rms[i], laneId % LANES_PER_HEAD);
}
}
const float *rotary_emb_base_addr = &rotary_emb[(warpId * WARP_M) * HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS + laneId * PACK_SIZE / 2 * ROTARY_EMB_NUM_ELEMENTS];
CHECK_NAN(fpsum, "fpsum");
unpack_fpsum()(fpsum, out + warpId * WARP_M * N, N, maxRows - warpId * WARP_M, INT_MAX, shmem[warpId], [&](int rowId, pack_t &pack) ALWAYSINLINE {
// load rope
pack_rope_t rope;
if (laneId < LANES_PER_HEAD) {
// freq = load(reinterpret_cast<pack_freq_t *>(&freqs_cis[(warpId * WARP_M + rowId) * HEAD_DIM * 2 + laneId * PACK_SIZE * 2]));
rope = load(reinterpret_cast<const pack_rope_t *>(&rotary_emb_base_addr[rowId * HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS]));
}
if constexpr (LANES_PER_HEAD < WARP_SIZE) {
for (int i = 0; i < rope.size(); i++) {
rope[i] = __shfl_sync(~0, rope[i], laneId % LANES_PER_HEAD);
}
}
// rmsnorm
float sqrsum = 0.0f;
for (int i = 0; i < PACK_SIZE; i++) {
sqrsum += float(pack[i]) * float(pack[i]);
CHECK_NAN(sqrsum, "sqrsum");
}
#pragma unroll
for (int mask = LANES_PER_HEAD / 2; mask > 0; mask /= 2) {
sqrsum += __shfl_xor_sync(~0, sqrsum, mask);
}
sqrsum /= HEAD_DIM;
float coef = cuda_frsqrt(sqrsum + epsilon);
CHECK_NAN(coef, "coef");
for (int i = 0; i < PACK_SIZE; i++) {
pack[i] *= coef * float(rms[i]);
CHECK_NAN(rms[i], "rms.wgt");
CHECK_NAN(pack[i], "rms.out");
}
#if 1
// rope
for (int i = 0; i < PACK_SIZE; i += 2) {
float2 pack2 = half22float2(half2_t(pack[i], pack[i+1]));
CHECK_NAN(freq[i].x, "rope.freq");
CHECK_NAN(freq[i].y, "rope.freq");
CHECK_NAN(freq[i+1].x, "rope.freq");
CHECK_NAN(freq[i+1].y, "rope.freq");
// half2_t tmp = __hmul2(freq[i], pack2);
// tmp = __hfma2(freq[i+1], pack2, tmp);
// pack[i] = tmp.x;
// pack[i+1] = tmp.y;
// printf("block.x=%d block.y=%d warpId=%d rowId=%d (%d) freqs = %f %f %f %f\n",
// blockIdx.x, blockIdx.y, warpId, rowId,
// blockIdx.x * BLOCK_M + warpId * WARP_M + rowId,
// (float)freq[i].x, (float)freq[i].y, (float)freq[i+1].x, (float)freq[i+1].y
// );
// __trap();
// half2_t tmp = __hmul2(half2_t(pack2.x, pack2.x), freq[i]);
// tmp = __hfma2(half2_t(pack2.y, pack2.y), freq[i+1], tmp);
// pack[i] = tmp.x;
// pack[i+1] = tmp.y;
float sin, cos;
if constexpr (ROTARY_EMB_NUM_ELEMENTS == 1) {
sin = cuda_sin(rope[i / 2]);
cos = cuda_cos(rope[i / 2]);
}
if constexpr (ROTARY_EMB_NUM_ELEMENTS == 2) {
sin = rope[i];
cos = rope[i+1];
}
// pack[i] = pack2.x * freq[i].x + pack2.y * freq[i].y;
// pack[i+1] = pack2.x * freq[i+1].x + pack2.y * freq[i+1].y;
pack[i] = half_t(pack2.x * cos - pack2.y * sin);
pack[i+1] = half_t(pack2.x * sin + pack2.y * cos);
CHECK_NAN(pack[i], "rope.out");
CHECK_NAN(pack[i+1], "rope.out");
}
#endif
// mean pool
for (int i = 0; i < PACK_SIZE; i++) {
reduce_tmp[i] += pack[i];
}
});
if (!pool_out) {
return;
}
store<true>(&pool[warpId], reduce_tmp);
__syncthreads();
if (warpId < NUM_POOLS_PER_BLOCK) {
const int row = warpId * NUM_WARPS_PER_POOL;
reduce_tmp = load<true>(&pool[row]);
for (int i = 1; i < NUM_WARPS_PER_POOL; i++) {
pack_t pack = load<true>(&pool[row + i]);
for (int j = 0; j < PACK_SIZE; j++) {
reduce_tmp[j] += pack[j];
}
}
for (int j = 0; j < PACK_SIZE; j++) {
reduce_tmp[j] /= PoolSize;
}
store(reinterpret_cast<pack_t *>(pool_out + warpId * N), reduce_tmp);
}
__syncthreads();
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
assert(binfo.numBlocksN % 3 == 0);
const bool is_q = bn < binfo.numBlocksN / 3;
const bool is_k = !is_q && bn < binfo.numBlocksN / 3 * 2;
assert(!args.pool_out || args.actualM == M);
assert(args.actualN == N);
if (is_q || is_k) {
apply(
fpsum,
args.out + bm * BLOCK_M * args.actualN + bn * BLOCK_N,
M, N, K,
args.pool_out ? args.pool_out + bm * BLOCK_M / PoolSize * N : nullptr,
args.rotary_emb + bm * BLOCK_M * (HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS),
is_q ? args.rmsnorm_weight_q : args.rmsnorm_weight_k,
args.epsilon,
args.actualM - bm * BLOCK_M
);
} else {
EpilogueDefault()(binfo, fpsum, M, N, K, typename EpilogueDefault::Arguments{
.out = args.out,
.actualM = args.actualM,
.actualN = args.actualN,
});
}
}
};
struct EpilogueRMSNormRope {
static constexpr int HEAD_DIM = 128;
static constexpr int NUM_HEADS_PER_WARP = WARP_N / HEAD_DIM;
static constexpr int WARP_N_TILES_PER_HEAD = WARP_N_TILES / NUM_HEADS_PER_WARP;
static constexpr int ROTARY_EMB_NUM_ELEMENTS = 2;
using packed_rotemb_t = float4;
static constexpr int WARP_N_ROTEMB_TILES = WARP_N_TILES / NUM_HEADS_PER_WARP * 2;
using rotemb_warp = std::array<packed_rotemb_t, WARP_M_TILES * WARP_N_ROTEMB_TILES>; // 128 regs
struct Arguments {
// **packed** [M, HEAD_DIM] float => [M // 16, HEAD_DIM // 8, WARP_SIZE] of packed_rotemb_t
// aka [M // BLOCK_M, NUM_WARPS, WARP_M_TILES, WARP_N_TILES // NUM_HEADS_PER_WARP * 2, WARP_SIZE]
const packed_rotemb_t *rotary_emb;
const half_t *rmsnorm_weight_q; // [HEAD_DIM]
const half_t *rmsnorm_weight_k; // [HEAD_DIM]
float epsilon;
};
__device__ __forceinline__
static rotemb_warp load_rotemb(const packed_rotemb_t *ptr_rotemb) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
rotemb_warp rotemb;
const packed_rotemb_t *ptrlane = &ptr_rotemb[warpId * WARP_M_TILES * WARP_N_ROTEMB_TILES * WARP_SIZE + laneId];
unrolled_loop<WARP_M_TILES>([&]<int i>() {
unrolled_loop<WARP_N_ROTEMB_TILES>([&]<int j>() {
constexpr int offset = (i * WARP_N_ROTEMB_TILES + j) * WARP_SIZE;
rotemb[i * WARP_N_ROTEMB_TILES + j] = load(&ptrlane[offset]);
});
});
return rotemb;
}
__device__ __forceinline__
static void load_rmsnorm(const half_t *ptr_rmsnorm_weight, half_t *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
static constexpr int PACK_SIZE = HEAD_DIM / WARP_SIZE;
using packed_t = std::array<half_t, PACK_SIZE>;
packed_t pack = load(reinterpret_cast<const packed_t *>(ptr_rmsnorm_weight + laneId * PACK_SIZE));
store<true>(reinterpret_cast<packed_t *>(shmem + laneId * PACK_SIZE), pack);
}
__device__ __forceinline__
static packed_fpsum_t load_rmsnorm_from_shmem(half_t *shmem, int n) {
const int laneId = threadIdx.x % WARP_SIZE;
const int col = n * INSN_N + laneId / 16 * 8; // lane 0-15: n*16+0, lane 16-31: n*16+8
uint4 tmp;
ldmatrix(shmem + col, tmp);
return kernels::bit_cast<packed_fpsum_t>(tmp);
}
__device__ __forceinline__
static void apply(fpsum_warp &fpsum, const packed_rotemb_t *ptr_rotemb, const half_t *ptr_rmsnorm_weight, float epsilon) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ half_t shmem_rmsnorm[NUM_WARPS][HEAD_DIM];
load_rmsnorm(ptr_rmsnorm_weight, &shmem_rmsnorm[warpId][0]);
__syncwarp();
rotemb_warp rotemb = load_rotemb(ptr_rotemb);
float rmsnorm_coef[NUM_HEADS_PER_WARP][WARP_M_TILES][2];
auto sqr = [](half2_t val) ALWAYSINLINE {
float2 fval = half22float2(val);
return fval.x * fval.x + fval.y * fval.y;
};
#pragma unroll
for (int head = 0; head < NUM_HEADS_PER_WARP; head++) {
const int n_offset = head * WARP_N_TILES_PER_HEAD;
#pragma unroll
for (int m = 0; m < WARP_M_TILES; m++) {
float sqrsum[2] = {0.0f, 0.0f};
#pragma unroll
for (int n = 0; n < WARP_N_TILES_PER_HEAD; n++) {
sqrsum[0] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[0]);
sqrsum[1] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[1]);
sqrsum[0] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[2]);
sqrsum[1] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[3]);
}
#pragma unroll
for (int mask = 1; mask <= 2; mask *= 2) {
sqrsum[0] += __shfl_xor_sync(~0, sqrsum[0], mask);
sqrsum[1] += __shfl_xor_sync(~0, sqrsum[1], mask);
}
rmsnorm_coef[head][m][0] = cuda_frsqrt(sqrsum[0] / HEAD_DIM + epsilon);
rmsnorm_coef[head][m][1] = cuda_frsqrt(sqrsum[1] / HEAD_DIM + epsilon);
}
}
#pragma unroll
for (int head = 0; head < NUM_HEADS_PER_WARP; head++) {
const int n_offset = head * WARP_N_TILES_PER_HEAD;
#pragma unroll
for (int n = 0; n < WARP_N_TILES_PER_HEAD; n++) {
packed_f32psum_t rms = packed_fp16_to_fp32(load_rmsnorm_from_shmem(&shmem_rmsnorm[warpId][0], n));
#pragma unroll
for (int m = 0; m < WARP_M_TILES; m++) {
packed_f32psum_t pack = packed_fp16_to_fp32(fpsum[m * WARP_N_TILES + n + n_offset]);
pack.data[0] *= rmsnorm_coef[head][m][0] * rms.data[0];
pack.data[1] *= rmsnorm_coef[head][m][0] * rms.data[1];
pack.data[2] *= rmsnorm_coef[head][m][1] * rms.data[2];
pack.data[3] *= rmsnorm_coef[head][m][1] * rms.data[3];
pack.data[4] *= rmsnorm_coef[head][m][0] * rms.data[4];
pack.data[5] *= rmsnorm_coef[head][m][0] * rms.data[5];
pack.data[6] *= rmsnorm_coef[head][m][1] * rms.data[6];
pack.data[7] *= rmsnorm_coef[head][m][1] * rms.data[7];
auto rope = [](float &x, float &y, float sin, float cos) ALWAYSINLINE {
float ix = x, iy = y;
x = ix * cos - iy * sin;
y = ix * sin + iy * cos;
};
{
packed_rotemb_t sincos = rotemb[m * WARP_N_ROTEMB_TILES + n * 2];
rope(pack.data[0], pack.data[1], sincos.x, sincos.y);
rope(pack.data[2], pack.data[3], sincos.z, sincos.w);
}
{
packed_rotemb_t sincos = rotemb[m * WARP_N_ROTEMB_TILES + n * 2 + 1];
rope(pack.data[4], pack.data[5], sincos.x, sincos.y);
rope(pack.data[6], pack.data[7], sincos.z, sincos.w);
}
fpsum[m * WARP_N_TILES + n + n_offset] = packed_fp32_to_fp16(pack);
}
}
}
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
assert(binfo.numBlocksN % 3 == 0);
const bool is_q = bn < binfo.numBlocksN / 3;
const bool is_k = !is_q && bn < binfo.numBlocksN / 3 * 2;
if (is_q || is_k) {
apply(
fpsum,
args.rotary_emb + bm * NUM_WARPS * WARP_M_TILES * WARP_N_ROTEMB_TILES * WARP_SIZE,
is_q ? args.rmsnorm_weight_q : args.rmsnorm_weight_k,
args.epsilon
);
}
}
};
struct EpiloguePackQKV {
using attn_half_t = half;
using attn_half2_t = half2;
using packed_qkv_t = uint4;
static constexpr int HEAD_DIM = 128;
static constexpr int INSN_K_QK = 16;
static constexpr int INSN_K_PV = 16;
struct Arguments {
packed_qkv_t *out_q, *out_k, *out_v;
int actualM;
// !!! stride in number of packed_qkv_t !!!
int strideHead_q;
int strideHead_k;
int strideHead_v;
};
__device__ __forceinline__
static attn_half2_t convert_half2(half2_t input) {
if constexpr (std::is_same_v<half2_t, attn_half2_t>) {
return input;
} else {
float2 fval = half22float2(input);
return float22half2<attn_half2_t>(fval);
}
}
__device__ __forceinline__
static packed_qkv_t pack_q(packed_fpsum_t input) {
packed_qkv_t output;
output.x = kernels::bit_cast<int>(convert_half2(input.data[0]));
output.y = kernels::bit_cast<int>(convert_half2(input.data[1]));
output.z = kernels::bit_cast<int>(convert_half2(input.data[2]));
output.w = kernels::bit_cast<int>(convert_half2(input.data[3]));
return output;
}
__device__ __forceinline__
static packed_qkv_t pack_k(packed_fpsum_t input) {
packed_qkv_t output;
output.x = kernels::bit_cast<int>(convert_half2(input.data[0]));
output.y = kernels::bit_cast<int>(convert_half2(input.data[2]));
output.z = kernels::bit_cast<int>(convert_half2(input.data[1]));
output.w = kernels::bit_cast<int>(convert_half2(input.data[3]));
return output;
}
__device__ __forceinline__
static packed_qkv_t pack_v(packed_fpsum_t input) {
packed_qkv_t output;
output.x = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[0])));
output.y = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[1])));
output.z = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[2])));
output.w = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[3])));
return output;
}
__device__ __forceinline__
static void mask(packed_qkv_t &pack, uint32_t maskVal, int m, int maxRows) {
const int laneId = threadIdx.x % WARP_SIZE;
if (m * INSN_M + laneId / 4 >= maxRows) {
pack.x = maskVal;
pack.z = maskVal;
}
if (m * INSN_M + laneId / 4 + 8 >= maxRows) {
pack.y = maskVal;
pack.w = maskVal;
}
}
// qkv: [batch, head, bm, NUM_WARPS, WARP_M_TILES, WARP_N_TILES, WARP_SIZE] of packed_qkv_t
template<typename F>
__device__ __forceinline__
static void apply(fpsum_warp &fpsum, packed_qkv_t *ptr_output, int maxRows, F &&funcPack, attn_half2_t maskVal) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
static_assert(HEAD_DIM == WARP_N);
packed_qkv_t *ptrlane = &ptr_output[((warpId * WARP_M_TILES + 0) * WARP_N_TILES + 0) * WARP_SIZE + laneId];
unrolled_loop<WARP_M_TILES>([&]<int m>() ALWAYSINLINE {
unrolled_loop<WARP_N_TILES>([&]<int n>() ALWAYSINLINE {
packed_qkv_t pack = funcPack(fpsum[m * WARP_N_TILES + n]);
mask(pack, kernels::bit_cast<uint32_t>(maskVal), m, maxRows - warpId * WARP_M);
store(&ptrlane[(m * WARP_N_TILES + n) * WARP_SIZE], pack);
});
});
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
assert(binfo.numBlocksN % 3 == 0);
const int numBlocksQ = binfo.numBlocksN / 3;
const bool is_q = bn < numBlocksQ;
const bool is_k = !is_q && bn < numBlocksQ * 2;
// bn is head_id (assume HEAD_DIM == WARP_N)
int head_id, strideHead;
if (is_q) {
head_id = bn;
strideHead = args.strideHead_q;
} else if (is_k) {
head_id = bn - numBlocksQ;
strideHead = args.strideHead_k;
} else {
head_id = bn - numBlocksQ * 2;
strideHead = args.strideHead_v;
}
int block_offset = head_id * strideHead + bm * NUM_WARPS * WARP_M_TILES * WARP_N_TILES * WARP_SIZE;
int maxRows = args.actualM - bm * BLOCK_M;
// static constexpr float neginf = -std::numeric_limits<float>::infinity();
if (is_q) {
apply(fpsum, args.out_q + block_offset, maxRows, pack_q, attn_half2_t(0.0f, 0.0f));
} else if (is_k) {
apply(fpsum, args.out_k + block_offset, maxRows, pack_k, attn_half2_t(NAN, NAN));
} else {
apply(fpsum, args.out_v + block_offset, maxRows, pack_v, attn_half2_t(0.0f, 0.0f));
}
}
};
struct EpilogueLiteLA {
__device__ __forceinline__
static packed_f32psum_t mma_litela(packed_fpsum_t k, packed_fpsum_t v, packed_f32psum_t psum) {
for (int i = 0; i < 4; i++) {
k.data[i] = movmatrix(k.data[i]);
v.data[i] = movmatrix(v.data[i]);
}
std::swap(v.data[1], v.data[2]);
return mma_f16xf16_f32(v, k, psum);
}
static constexpr int LITELA_HEAD_DIM = 32;
static constexpr int LITELA_K_TILES = LITELA_HEAD_DIM / 16;
static constexpr int LITELA_V_TILES = LITELA_HEAD_DIM / 16;
static constexpr int SHMEM_SIZE = NUM_WARPS * (LITELA_HEAD_DIM + 1) * (LITELA_HEAD_DIM + 8) * sizeof(float);
// out_vk: [batch_size, num_heads, head_dim + 1, head_dim]
__device__ __forceinline__
static void apply_litela(const BlockInfo binfo, fpsum_warp fpsum, float *out_vk, int num_blocks_per_batch) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
using vk_t = float[NUM_WARPS][LITELA_HEAD_DIM + 1][LITELA_HEAD_DIM + 8];
extern __shared__ uint8_t shmem[];
vk_t &shmem_vk = *reinterpret_cast<vk_t *>(shmem);
static_assert(sizeof(vk_t) == SHMEM_SIZE);
static_assert(WARP_N == BLOCK_N);
assert(binfo.numBlocksN % 3 == 0);
const int num_heads = binfo.numBlocksN / 3 * 2 * (WARP_N / (LITELA_HEAD_DIM * 2));
const int batch_id = binfo.bm / num_blocks_per_batch;
for (int head_id = 0; head_id < WARP_N / (LITELA_HEAD_DIM * 2); head_id++) {
const int global_head_id = (binfo.bn - binfo.numBlocksN / 3) * (WARP_N / (LITELA_HEAD_DIM * 2)) + head_id;
float *out_vk_current_head = out_vk + (batch_id * num_heads + global_head_id) * (LITELA_HEAD_DIM + 1) * LITELA_HEAD_DIM;
for (int i = laneId; i < sizeof(shmem_vk) / sizeof(float) / NUM_WARPS; i += WARP_SIZE) {
*((&shmem_vk[warpId][0][0]) + i) = 0;
}
__syncwarp();
for (int tile_v = 0; tile_v < LITELA_V_TILES; tile_v++) {
for (int tile_k = 0; tile_k < LITELA_K_TILES; tile_k++) {
packed_f32psum_t attn_sum = { 0 };
for (int i = 0; i < WARP_M_TILES; i++) {
packed_fpsum_t k = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + tile_k];
packed_fpsum_t v = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + LITELA_HEAD_DIM / 16 + tile_v];
for (int j = 0; j < 4; j++) {
k.data[j] = __hmax2(k.data[j], half2_t(0, 0)); // relu
}
attn_sum = mma_litela(k, v, attn_sum);
}
const int row = tile_v * 16 + laneId / 4;
const int col = tile_k * 16 + laneId % 4 * 2;
shmem_vk[warpId][row + 0][col + 0] = attn_sum.data[0];
shmem_vk[warpId][row + 0][col + 1] = attn_sum.data[1];
shmem_vk[warpId][row + 8][col + 0] = attn_sum.data[2];
shmem_vk[warpId][row + 8][col + 1] = attn_sum.data[3];
shmem_vk[warpId][row + 0][col + 8] = attn_sum.data[4];
shmem_vk[warpId][row + 0][col + 9] = attn_sum.data[5];
shmem_vk[warpId][row + 8][col + 8] = attn_sum.data[6];
shmem_vk[warpId][row + 8][col + 9] = attn_sum.data[7];
}
}
for (int tile_k = 0; tile_k < LITELA_K_TILES; tile_k++) {
packed_f32psum_t attn_sum = { 0 };
for (int i = 0; i < WARP_M_TILES; i++) {
packed_fpsum_t k = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + tile_k];
packed_fpsum_t v = {};
for (int j = 0; j < 4; j++) {
k.data[j] = __hmax2(k.data[j], half2_t(0, 0)); // relu
}
#pragma unroll
for (int i = 0; i < 4; i++) {
v.data[i] = half2_t(1, 1);
}
// if (laneId < 4) {
// v.data[0] = half2_t(1, 1);
// v.data[2] = half2_t(1, 1);
// }
// if (laneId % 4 == 0) {
// v.data[0] = half2_t(1, 0);
// v.data[1] = half2_t(1, 0);
// }
attn_sum = mma_litela(k, v, attn_sum);
}
const int row = LITELA_HEAD_DIM + laneId / 4;
const int col = tile_k * 16 + laneId % 4 * 2;
if (laneId < 4) {
shmem_vk[warpId][row + 0][col + 0] = attn_sum.data[0];
shmem_vk[warpId][row + 0][col + 1] = attn_sum.data[1];
shmem_vk[warpId][row + 0][col + 8] = attn_sum.data[4];
shmem_vk[warpId][row + 0][col + 9] = attn_sum.data[5];
}
}
__syncthreads();
for (int i = warpId; i < LITELA_HEAD_DIM + 1; i += NUM_WARPS) {
for (int j = laneId; j < LITELA_HEAD_DIM; j += WARP_SIZE) {
float sum = 0;
for (int k = 0; k < NUM_WARPS; k++) {
sum += shmem_vk[k][i][j];
}
reduce_add(&out_vk_current_head[i * LITELA_HEAD_DIM + j], sum);
}
}
__syncthreads();
}
}
struct Arguments {
half_t *out_q;
float *out_vk;
int num_blocks_per_batch;
int actualM;
};
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
if (bn < binfo.numBlocksN / 3) {
fpsum = apply_act(fpsum, [](half_t x) { return __hmax(x, 0); }); // relu
return EpilogueDefault()(
binfo,
fpsum,
M, N / 3, K, typename EpilogueDefault::Arguments{
.out = args.out_q,
.actualM = args.actualM,
.actualN = N / 3,
});
}
return apply_litela(binfo, fpsum, args.out_vk, args.num_blocks_per_batch);
}
// each thread block mults BlockSize*HEAD_DIM q and (HEAD_DIM+1)*HEAD_DIM vk, in-place writes back to q
// q: [batch_size, #blocks, block_size, #heads, HEAD_DIM]
// vk: [batch_size, #heads, HEAD_DIM+1, HEAD_DIM]
struct vk_mul_q_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
// FIXME FIXME FIXME
__device__
void operator()(half_t *q, const float *vk, float eps, int num_tokens) {
const int block_id = blockIdx.x;
const int head_id = blockIdx.y;
const int batch_id = blockIdx.z;
const int num_blocks = gridDim.x;
const int num_heads = gridDim.y;
const int block_size = blockDim.x;
bool pred = block_id * block_size + threadIdx.x < num_tokens;
half_t *localq = &q[(((batch_id * num_blocks + block_id) * block_size + threadIdx.x) * num_heads + head_id) * LITELA_HEAD_DIM];
const float *localvk = &vk[(batch_id * num_heads + head_id) * (LITELA_HEAD_DIM + 1) * LITELA_HEAD_DIM];
// half_t *localout = &out[(((batch_id * num_blocks + block_id) * block_size + threadIdx.x) * num_heads + head_id) * LITELA_HEAD_DIM];
using packed_q = std::array<half_t, 8>;
using packed_vk = std::array<float, 4>;
half_t qblock[LITELA_HEAD_DIM];
for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_q) / sizeof(half_t)) {
if (pred) {
*reinterpret_cast<packed_q *>(&qblock[i]) = load(reinterpret_cast<const packed_q *>(&localq[i]));
}
}
float outblock[LITELA_HEAD_DIM + 1];
#pragma unroll
for (int j = 0; j < LITELA_HEAD_DIM + 1; j++) {
outblock[j] = 0;
#pragma unroll
for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_vk) / sizeof(float)) {
packed_vk vkpack = load(reinterpret_cast<const packed_vk *>(&localvk[j * LITELA_HEAD_DIM + i]));
#pragma unroll
for (int k = 0; k < vkpack.size(); k++) {
outblock[j] += (float)qblock[i + k] * vkpack[k];
}
}
}
for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_q) / sizeof(half_t)) {
packed_q opack;
for (int k = 0; k < opack.size(); k++) {
opack[k] = __fdividef(outblock[i + k], outblock[LITELA_HEAD_DIM] + eps);
}
if (pred) {
store(reinterpret_cast<packed_q *>(&localq[i]), opack);
}
}
}
};
};
template<typename Epilogue>
struct test_epilogue_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(Base::template load_act_to_fpsum<false>::SHMEM_SIZE, 128) * 128;
static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;
struct Arguments {
const half_t *input;
half_t *output;
// aligned to BLOCK_M and BLOCK_N
int M, N;
int actualM, actualN;
typename Epilogue::Arguments argsEpilogue;
};
__device__ __forceinline__
void operator()(Arguments args)
{
const BlockInfo binfo = {
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y,
};
const int bm = binfo.bm;
const int bn = binfo.bn;
const int warpId = threadIdx.x / WARP_SIZE;
const int m_offset = bm * BLOCK_M + warpId * WARP_M;
const int n_offset = bn * BLOCK_N;
extern __shared__ uint8_t shmem[];
fpsum_warp fpsum;
Base::template load_act_to_fpsum<false>()(
args.input + m_offset * args.actualN + n_offset,
args.actualN,
args.actualM - m_offset,
args.actualN - n_offset,
fpsum,
shmem + warpId * SHMEM_PER_WARP
);
Epilogue()(binfo, fpsum, args.M, args.N, 0, args.argsEpilogue);
EpilogueDefault()(binfo, fpsum, args.M, args.N, 0, typename EpilogueDefault::Arguments{
.out = args.output,
.actualM = args.actualM,
.actualN = args.actualN,
});
}
};
};
}; // namespace nunchaku::kernels
\ No newline at end of file
......@@ -256,6 +256,16 @@ public:
return results;
}
__device__ __forceinline__
static f32psum_warp packed_fp16_to_fp32(fpsum_warp input) {
f32psum_warp results;
#pragma unroll
for (int i = 0; i < results.size(); i++) {
results[i] = packed_fp16_to_fp32(input[i]);
}
return results;
}
// activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t
__device__ __forceinline__
static void load_act(const packed_act_t *act, int k, int K, act_warp &out, bool pred) {
......@@ -570,6 +580,63 @@ public:
}
};
// loads act of [WARP_M, WARP_N] and stores to fpsum_warp
// [WARP_M, WARP_N * 2] when fuse_glu
template<bool fuse_glu>
struct load_act_to_fpsum {
using matrix_t = half_t[INSN_M][WARP_N + 8];
static constexpr size_t SHMEM_SIZE = sizeof(matrix_t);
__device__ __forceinline__
void operator()(const half_t *input, int stride, int maxRows, int maxCols, fpsum_warp &out, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
constexpr int PACK_SIZE = WARP_N / WARP_SIZE;
using packed_input = std::array<half_t, PACK_SIZE>;
using packed_raw_input = std::array<half2_t, PACK_SIZE>;
#pragma unroll
for (int m = 0; m < WARP_M_TILES; m++) {
#pragma unroll
for (int row = 0; row < INSN_M; row++) {
packed_input pack;
// TODO: numCols not multiples of PACK_SIZE
if constexpr (fuse_glu) {
packed_raw_input raw;
raw.fill(half2_t(0, 0));
bool pred = (m * INSN_M + row) < maxRows && laneId * PACK_SIZE * 2 < maxCols;
if (pred) {
raw = load(reinterpret_cast<const packed_raw_input *>(input + (m * INSN_M + row) * stride + laneId * PACK_SIZE * 2));
}
#pragma unroll
for (int j = 0; j < PACK_SIZE; j++) {
pack[j] = raw[j].x * silu(raw[j].y);
}
} else {
pack.fill(half_t(0));
bool pred = (m * INSN_M + row) < maxRows && laneId * PACK_SIZE < maxCols;
if (pred) {
pack = load(reinterpret_cast<const packed_input *>(input + (m * INSN_M + row) * stride + laneId * PACK_SIZE));
}
}
store<true>(reinterpret_cast<packed_input *>(&mat[row][laneId * PACK_SIZE]), pack);
}
__syncwarp();
for (int n = 0; n < WARP_N_TILES; n++) {
const int row = laneId % 16;
const int col = n * INSN_N + laneId / 16 * 8;
uint4 tmp;
ldmatrix(&mat[row][col], tmp);
*reinterpret_cast<uint4 *>(&out[m * WARP_N_TILES + n]) = tmp;
}
__syncwarp();
}
}
};
template<typename F>
__device__ __forceinline__
......@@ -599,7 +666,7 @@ public:
};
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
......@@ -632,7 +699,7 @@ public:
struct Arguments { size_t unused; };
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
}
};
......@@ -696,7 +763,7 @@ public:
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
const int bn = binfo.bn;
if constexpr (USE_BIAS || USE_SCALE) {
apply_bias(
......@@ -712,7 +779,7 @@ public:
struct Arguments { size_t unused; };
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
fpsum = apply_act(fpsum, [](half_t x) { return silu(x); });
}
};
......@@ -722,7 +789,7 @@ public:
using Arguments = std::tuple<typename Epilogues::Arguments...>;
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
// this function makes intellisense crashes :(
#if __INTELLISENSE__
__trap(); // should not happen when actually compiling
......
......@@ -358,6 +358,14 @@ static void reduce_add(float *addr, float val) {
asm volatile ("red.relaxed.gpu.global.add.f32 [%0], %1;" :: "l"(addr), "f"(val));
}
__device__ __forceinline__
static void reduce_add_pred(float *addr, float val, bool pred) {
asm volatile (
"{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
"}" :: "r"((int)pred), "l"(addr), "f"(val));
}
template<int cnt, typename F>
__device__ __forceinline__
static void unrolled_loop(F &&lambda) {
......
#pragma once
#include "gemm_base.cuh"
#include "lora.cuh"
// #include "gemm_w4a4_block.cuh"
namespace nunchaku::kernels {
......@@ -256,7 +257,7 @@ public:
const packed_wmscale_t *wscales,
float alpha, // per-tensor scale of weight
int M, int N, int K,
Epilogue::Arguments epilogueArgs,
const Epilogue::Arguments &epilogueArgs,
bool alwaysfalse)
{
constexpr int NUM_STAGES = 2;
......@@ -500,64 +501,6 @@ public:
}
}
// loads act of [WARP_M, WARP_N] and stores to fpsum_warp
// [WARP_M, WARP_N * 2] when fuse_glu
template<bool fuse_glu>
struct load_act_to_fpsum {
using matrix_t = half_t[INSN_M][WARP_N + 8];
static constexpr size_t SHMEM_SIZE = sizeof(matrix_t);
__device__ __forceinline__
void operator()(const half_t *input, int stride, int maxRows, int maxCols, fpsum_warp &out, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
constexpr int PACK_SIZE = WARP_N / WARP_SIZE;
using packed_input = std::array<half_t, PACK_SIZE>;
using packed_raw_input = std::array<half2_t, PACK_SIZE>;
#pragma unroll
for (int m = 0; m < WARP_M_TILES; m++) {
#pragma unroll
for (int row = 0; row < INSN_M; row++) {
packed_input pack;
// TODO: numCols not multiples of PACK_SIZE
if constexpr (fuse_glu) {
packed_raw_input raw;
raw.fill(half2_t(0, 0));
bool pred = (m * INSN_M + row) < maxRows && laneId * PACK_SIZE * 2 < maxCols;
if (pred) {
raw = load(reinterpret_cast<const packed_raw_input *>(input + (m * INSN_M + row) * stride + laneId * PACK_SIZE * 2));
}
#pragma unroll
for (int j = 0; j < PACK_SIZE; j++) {
pack[j] = raw[j].x * silu(raw[j].y);
}
} else {
pack.fill(half_t(0));
bool pred = (m * INSN_M + row) < maxRows && laneId * PACK_SIZE < maxCols;
if (pred) {
pack = load(reinterpret_cast<const packed_input *>(input + (m * INSN_M + row) * stride + laneId * PACK_SIZE));
}
}
store<true>(reinterpret_cast<packed_input *>(&mat[row][laneId * PACK_SIZE]), pack);
}
__syncwarp();
for (int n = 0; n < WARP_N_TILES; n++) {
const int row = laneId % 16;
const int col = n * INSN_N + laneId / 16 * 8;
uint4 tmp;
ldmatrix(&mat[row][col], tmp);
*reinterpret_cast<uint4 *>(&out[m * WARP_N_TILES + n]) = tmp;
}
__syncwarp();
}
}
};
/**
* each warp quantizes a INSN_M * INSN_K (16 * 64) matrix
......@@ -883,7 +826,7 @@ public:
// const packed_wscale_t *bias_ptr,
// half_t *out,
int M, int N, int K,
Epilogue::Arguments epilogueArgs,
const Epilogue::Arguments &epilogueArgs,
bool alwaysfalse)
{
constexpr int NUM_STAGES = 2;
......@@ -1057,7 +1000,7 @@ public:
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
......@@ -1077,1045 +1020,6 @@ public:
};
// using EpilogueQuantizeFuseGelu = EpilogueQuantize<true>;
template<int rank = 32>
struct Lora {
static_assert(rank % 16 == 0);
static constexpr int LORA_RANK = rank;
static constexpr int LORA_M_TILES = WARP_M / 16;
static constexpr int LORA_R_TILES = LORA_RANK / 16;
static constexpr int LORA_N_TILES = WARP_N / 16;
static_assert(LORA_M_TILES == WARP_M_TILES);
static_assert(LORA_N_TILES == WARP_N_TILES);
// lora_down: [WARP_M, WARP_N] x [WARP_N, R] (row-wise) = [WARP_M, R]
// lora up: [WARP_M, R] x [WARP_N, R] (col-wise) = [WARP_M, WARP_N]
// we use fp32 for lora activation since there's no bf16 reduction in sm_89 :(
using lora_act_warp = std::array<packed_f32psum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_act16_warp = std::array<packed_fpsum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_wgt_warp = std::array<packed_fpsum_t, LORA_N_TILES * LORA_R_TILES>;
using scale_t = std::array<float, LORA_R_TILES>;
// lora_wgt: [N / 16, LORA_R_TILES, WARP_SIZE] of packed_fpsum_t
__device__ __forceinline__
static lora_wgt_warp load_lora_wgt(const packed_fpsum_t *ptr) {
const int laneId = threadIdx.x % WARP_SIZE;
const packed_fpsum_t *ptr_lane = ptr + laneId;
lora_wgt_warp result;
#if 0
#pragma unroll
for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
for (int r = 0; r < LORA_R_TILES; r++) {
result[n * LORA_R_TILES + r] = load(ptr_lane + (n * LORA_R_TILES + r) * WARP_SIZE);
}
}
#else
unrolled_loop<LORA_N_TILES>([&]<int n>() {
unrolled_loop<LORA_R_TILES>([&]<int r>() {
constexpr int offset = (n * LORA_R_TILES + r) * WARP_SIZE;
result[n * LORA_R_TILES + r] = load(ptr_lane + offset);
});
});
#endif
return result;
}
// lora_act: [M / BLOCK_M, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float
__device__ __forceinline__
static lora_act16_warp load_lora_act(const float *ptr, scale_t scales) {
const int laneId = threadIdx.x % WARP_SIZE;
const float *ptrlane = ptr + laneId;
lora_act16_warp result;
#if 0
#pragma unroll
for (int i = 0; i < LORA_M_TILES * LORA_R_TILES; i++) {
packed_f32psum_t tmp;
#pragma unroll
for (int j = 0; j < 8; j++) {
const int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
tmp.data[j] = ptrlane[offset];
// tmp.data[j] = ptr[i * 8 * WARP_SIZE + j * WARP_SIZE + laneId];
}
CHECK_NAN(tmp, "load_lora_act.tmp");
result[i] = packed_fp32_to_fp16(tmp);
}
#else
unrolled_loop<LORA_M_TILES>([&]<int m>() {
unrolled_loop<LORA_R_TILES>([&]<int r>{
constexpr int i = m * LORA_R_TILES + r;
packed_f32psum_t tmp;
unrolled_loop<8>([&]<int j>() {
constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
tmp.data[j] = ptrlane[offset] * scales[r];
});
CHECK_NAN(tmp, "load_lora_act.tmp");
result[i] = packed_fp32_to_fp16(tmp);
});
});
#endif
return result;
}
// no vector reduction in sm_89 :(
__device__ __forceinline__
static void reduce_lora_act(float *ptr, lora_act_warp val) {
const int laneId = threadIdx.x % WARP_SIZE;
float *ptrlane = ptr + laneId;
// #pragma unroll
// for (int i = 0; i < LORA_M_TILES * LORA_R_TILES; i++) {
// #pragma unroll
// for (int j = 0; j < 8; j++) {
// int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
// reduce_add(&ptrlane[offset], val[i].data[j]);
// }
// }
unrolled_loop<LORA_M_TILES * LORA_R_TILES>([&]<int i>() {
unrolled_loop<8>([&]<int j>() {
constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
reduce_add(&ptrlane[offset], val[i].data[j]);
});
});
}
// __device__ __forceinline__
// static void reduce_lora_act(float *ptr, lora_act_warp val, int m) {
// const int laneId = threadIdx.x % WARP_SIZE;
// float *ptrlane = ptr + laneId + m * LORA_R_TILES * 8 * WARP_SIZE;
// unrolled_loop<LORA_R_TILES>([&]<int r>() {
// unrolled_loop<8>([&]<int j>() {
// constexpr int offset = r * 8 * WARP_SIZE + j * WARP_SIZE;
// reduce_add(&ptrlane[offset], val[m * LORA_R_TILES + r].data[j]);
// });
// });
// }
struct EpilogueLoraUp {
struct Arguments {
const float *lora_act;
const packed_fpsum_t *lora_wgt_up;
scale_t scales;
};
__device__ __forceinline__
static void apply_lora_up(fpsum_warp &fpsum, int M, int N, int K, const float *act, const packed_fpsum_t *wgt, const scale_t scales, const BlockInfo binfo) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
if constexpr (rank > 0) {
lora_act16_warp lora_act = load_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), scales);
lora_wgt_warp lora_wgt = load_lora_wgt(wgt);
for (int m = 0; m < LORA_M_TILES; m++) {
for (int n = 0; n < LORA_N_TILES; n++) {
packed_f32psum_t psum = packed_fp16_to_fp32(fpsum[m * WARP_N_TILES + n]);
for (int r = 0; r < LORA_R_TILES; r++) {
CHECK_NAN(lora_act[m * LORA_R_TILES + r], "lora_act");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "lora_wgt");
psum = mma_f16xf16_f32(lora_act[m * LORA_R_TILES + r], lora_wgt[n * LORA_R_TILES + r], psum);
}
fpsum[m * WARP_N_TILES + n] = packed_fp32_to_fp16(psum);
}
}
}
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
CHECK_NAN(fpsum, "fpsum");
if constexpr (rank == 0) {
return;
}
apply_lora_up(
fpsum, M, N, K,
args.lora_act + bm * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_up + bn * (BLOCK_N / 16) * LORA_R_TILES * WARP_SIZE,
args.scales,
binfo // for debug
);
CHECK_NAN(fpsum, "fpsum");
}
};
struct EpilogueLoraDown {
struct Arguments {
const packed_fpsum_t *lora_wgt_down;
float *lora_act;
};
__device__ __forceinline__
static void apply_lora_down(fpsum_warp &fpsum, int M, int N, int K, float *act, const packed_fpsum_t *wgt) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
if constexpr (rank > 0) {
lora_act_warp lora_act;
lora_act.fill(packed_f32psum_t::zeros());
lora_wgt_warp lora_wgt = load_lora_wgt(wgt);
// clock_t dummy = 0;
#pragma unroll
for (int m = 0; m < LORA_M_TILES; m++) {
#pragma unroll
for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
for (int r = 0; r < LORA_R_TILES; r++) {
auto &psum = lora_act[m * LORA_R_TILES + r];
CHECK_NAN(fpsum[m * WARP_N_TILES + n], "apply_lora_down.fpsum");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "apply_lora_down.lora_wgt");
psum = mma_f16xf16_f32(fpsum[m * WARP_N_TILES + n], lora_wgt[n * LORA_R_TILES + r], psum);
CHECK_NAN(psum, "apply_lora_down.psum");
}
}
// reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act, m);
// if (alwaysfalse) {
// dummy = clock();
// }
}
reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act);
// unused_var(dummy, alwaysfalse);
}
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
if constexpr (rank == 0) {
return;
}
apply_lora_down(
fpsum, M, N, K,
args.lora_act + bm * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_down + bn * (BLOCK_N / 16) * LORA_R_TILES * WARP_SIZE
);
}
};
template<bool fuse_glu, bool use_fp4>
struct quantize_w4a4_fuse_lora_kernel {
using oscales_t = typename std::conditional_t<use_fp4, packed_amscale_t, packed_ascale_t>;
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(load_act_to_fpsum<fuse_glu>::SHMEM_SIZE, 128) * 128;
static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;
struct Arguments {
const half_t *input;
const packed_wscale_t *smooth_factor;
packed_act_t *output;
oscales_t *oscales;
const packed_fpsum_t *lora_wgt_down;
float *lora_act;
// aligned to BLOCK_M and BLOCK_N
int M, N; // N should be the actual K in the next GEMM (needs /2 if fuse_glu)
// the actual M and N (no need to /2 if fuse_glu)
int actualM, actualN;
};
__device__ __forceinline__
void operator()(Arguments args)
{
const BlockInfo binfo = {
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y,
};
const int bm = binfo.bm;
const int bn = binfo.bn;
const int warpId = threadIdx.x / WARP_SIZE;
const int m_offset = bm * BLOCK_M + warpId * WARP_M;
const int n_offset = bn * BLOCK_N * (fuse_glu ? 2 : 1);
extern __shared__ uint8_t shmem[];
fpsum_warp fpsum;
load_act_to_fpsum<fuse_glu>()(
args.input + m_offset * args.actualN + n_offset,
args.actualN,
args.actualM - m_offset,
args.actualN - n_offset,
fpsum,
shmem + warpId * SHMEM_PER_WARP
// args.smooth_factor ? args.smooth_factor + n_offset : nullptr
);
CHECK_NAN(fpsum, "fpsum");
// for (int i = 0; i < 16; i++) {
// printf("bm=%d bn=%d warp=%d lane=%d fpsum[%d][0:1]=%f %f\n",
// bm, bn, warpId, threadIdx.x % WARP_SIZE, i,
// (float)fpsum[i].data[0].x, (float)fpsum[i].data[0].y);
// }
EpilogueLoraDown()(binfo, fpsum, args.M, args.N, 0, typename EpilogueLoraDown::Arguments{
.lora_wgt_down = args.lora_wgt_down,
.lora_act = args.lora_act,
});
EpilogueQuantize<false, false, use_fp4>()(binfo, fpsum, args.M, args.N, 0, typename EpilogueQuantize<false, false, use_fp4>::Arguments{
.qout = args.output,
.oscales = args.oscales,
.shift_value = 0,
.smooth_factor = args.smooth_factor
});
}
};
};
struct EpilogueGelu {
struct Arguments { size_t unused; };
// static constexpr float SHIFT_VALUE = 0.171875f;
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll
for (int k = 0; k < 4; k++) {
half2_t &data = fpsum[i * WARP_N_TILES + j].data[k];
data = gelu_half2(data);
// data = __hadd2(data, half2_t(SHIFT_VALUE, SHIFT_VALUE));
}
}
}
}
};
// template<int PoolSize = 128>
struct EpilogueQKVProj {
struct Arguments {
half_t *out;
int actualM, actualN;
half_t *pool_out; // [M / PoolSize, N]
const float *rotary_emb; // [M, HEAD_DIM / 2, ROTARY_EMB_NUM_ELEMENTS]
const half_t *rmsnorm_weight_q; // [HEAD_DIM]
const half_t *rmsnorm_weight_k; // [HEAD_DIM]
float epsilon;
};
static constexpr int HEAD_DIM = 128;
static constexpr int NUM_HEADS_PER_WARP = WARP_N / HEAD_DIM;
static constexpr int PoolSize = 128;
static constexpr int NUM_WARPS_PER_POOL = PoolSize / WARP_M;
static constexpr int NUM_POOLS_PER_BLOCK = BLOCK_M / PoolSize;
static constexpr int ROTARY_EMB_NUM_ELEMENTS = 2; // 1 for theta, 2 for {sin, cos} pair
__device__ __forceinline__
static void apply(fpsum_warp fpsum, half_t *out, int M, int N, int K, half_t *pool_out, const float *rotary_emb, const half_t *rmsnorm_weight, float epsilon, int maxRows) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
constexpr int PACK_SIZE = unpack_fpsum::PACK_SIZE;
using pack_t = unpack_fpsum::pack_t;
using pack_rope_t = std::array<float, PACK_SIZE / 2 * ROTARY_EMB_NUM_ELEMENTS>;
constexpr int LANES_PER_HEAD = HEAD_DIM / PACK_SIZE;
pack_t reduce_tmp;
__shared__ alignas(128) pack_t pool[NUM_WARPS];
// load rmsnorm scales
pack_t rms;
if (laneId < LANES_PER_HEAD) {
rms = load(reinterpret_cast<const pack_t *>(&rmsnorm_weight[laneId * PACK_SIZE]));
}
if constexpr (LANES_PER_HEAD < WARP_SIZE) {
for (int i = 0; i < PACK_SIZE; i++) {
rms[i] = __shfl_sync(~0, rms[i], laneId % LANES_PER_HEAD);
}
}
const float *rotary_emb_base_addr = &rotary_emb[(warpId * WARP_M) * HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS + laneId * PACK_SIZE / 2 * ROTARY_EMB_NUM_ELEMENTS];
CHECK_NAN(fpsum, "fpsum");
unpack_fpsum()(fpsum, out + warpId * WARP_M * N, N, maxRows - warpId * WARP_M, INT_MAX, shmem[warpId], [&](int rowId, pack_t &pack) ALWAYSINLINE {
// load rope
pack_rope_t rope;
if (laneId < LANES_PER_HEAD) {
// freq = load(reinterpret_cast<pack_freq_t *>(&freqs_cis[(warpId * WARP_M + rowId) * HEAD_DIM * 2 + laneId * PACK_SIZE * 2]));
rope = load(reinterpret_cast<const pack_rope_t *>(&rotary_emb_base_addr[rowId * HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS]));
}
if constexpr (LANES_PER_HEAD < WARP_SIZE) {
for (int i = 0; i < rope.size(); i++) {
rope[i] = __shfl_sync(~0, rope[i], laneId % LANES_PER_HEAD);
}
}
// rmsnorm
float sqrsum = 0.0f;
for (int i = 0; i < PACK_SIZE; i++) {
sqrsum += float(pack[i]) * float(pack[i]);
CHECK_NAN(sqrsum, "sqrsum");
}
#pragma unroll
for (int mask = LANES_PER_HEAD / 2; mask > 0; mask /= 2) {
sqrsum += __shfl_xor_sync(~0, sqrsum, mask);
}
sqrsum /= HEAD_DIM;
float coef = cuda_frsqrt(sqrsum + epsilon);
CHECK_NAN(coef, "coef");
for (int i = 0; i < PACK_SIZE; i++) {
pack[i] *= coef * float(rms[i]);
CHECK_NAN(rms[i], "rms.wgt");
CHECK_NAN(pack[i], "rms.out");
}
#if 1
// rope
for (int i = 0; i < PACK_SIZE; i += 2) {
float2 pack2 = half22float2(half2_t(pack[i], pack[i+1]));
CHECK_NAN(freq[i].x, "rope.freq");
CHECK_NAN(freq[i].y, "rope.freq");
CHECK_NAN(freq[i+1].x, "rope.freq");
CHECK_NAN(freq[i+1].y, "rope.freq");
// half2_t tmp = __hmul2(freq[i], pack2);
// tmp = __hfma2(freq[i+1], pack2, tmp);
// pack[i] = tmp.x;
// pack[i+1] = tmp.y;
// printf("block.x=%d block.y=%d warpId=%d rowId=%d (%d) freqs = %f %f %f %f\n",
// blockIdx.x, blockIdx.y, warpId, rowId,
// blockIdx.x * BLOCK_M + warpId * WARP_M + rowId,
// (float)freq[i].x, (float)freq[i].y, (float)freq[i+1].x, (float)freq[i+1].y
// );
// __trap();
// half2_t tmp = __hmul2(half2_t(pack2.x, pack2.x), freq[i]);
// tmp = __hfma2(half2_t(pack2.y, pack2.y), freq[i+1], tmp);
// pack[i] = tmp.x;
// pack[i+1] = tmp.y;
float sin, cos;
if constexpr (ROTARY_EMB_NUM_ELEMENTS == 1) {
sin = cuda_sin(rope[i / 2]);
cos = cuda_cos(rope[i / 2]);
}
if constexpr (ROTARY_EMB_NUM_ELEMENTS == 2) {
sin = rope[i];
cos = rope[i+1];
}
// pack[i] = pack2.x * freq[i].x + pack2.y * freq[i].y;
// pack[i+1] = pack2.x * freq[i+1].x + pack2.y * freq[i+1].y;
pack[i] = half_t(pack2.x * cos - pack2.y * sin);
pack[i+1] = half_t(pack2.x * sin + pack2.y * cos);
CHECK_NAN(pack[i], "rope.out");
CHECK_NAN(pack[i+1], "rope.out");
}
#endif
// mean pool
for (int i = 0; i < PACK_SIZE; i++) {
reduce_tmp[i] += pack[i];
}
});
if (!pool_out) {
return;
}
store<true>(&pool[warpId], reduce_tmp);
__syncthreads();
if (warpId < NUM_POOLS_PER_BLOCK) {
const int row = warpId * NUM_WARPS_PER_POOL;
reduce_tmp = load<true>(&pool[row]);
for (int i = 1; i < NUM_WARPS_PER_POOL; i++) {
pack_t pack = load<true>(&pool[row + i]);
for (int j = 0; j < PACK_SIZE; j++) {
reduce_tmp[j] += pack[j];
}
}
for (int j = 0; j < PACK_SIZE; j++) {
reduce_tmp[j] /= PoolSize;
}
store(reinterpret_cast<pack_t *>(pool_out + warpId * N), reduce_tmp);
}
__syncthreads();
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
assert(binfo.numBlocksN % 3 == 0);
const bool is_q = bn < binfo.numBlocksN / 3;
const bool is_k = !is_q && bn < binfo.numBlocksN / 3 * 2;
assert(!args.pool_out || args.actualM == M);
assert(args.actualN == N);
if (is_q || is_k) {
apply(
fpsum,
args.out + bm * BLOCK_M * args.actualN + bn * BLOCK_N,
M, N, K,
args.pool_out ? args.pool_out + bm * BLOCK_M / PoolSize * N : nullptr,
args.rotary_emb + bm * BLOCK_M * (HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS),
is_q ? args.rmsnorm_weight_q : args.rmsnorm_weight_k,
args.epsilon,
args.actualM - bm * BLOCK_M
);
} else {
EpilogueDefault()(binfo, fpsum, M, N, K, typename EpilogueDefault::Arguments{
.out = args.out,
.actualM = args.actualM,
.actualN = args.actualN,
});
}
}
};
struct EpilogueRMSNormRope {
static constexpr int HEAD_DIM = 128;
static constexpr int NUM_HEADS_PER_WARP = WARP_N / HEAD_DIM;
static constexpr int WARP_N_TILES_PER_HEAD = WARP_N_TILES / NUM_HEADS_PER_WARP;
static constexpr int ROTARY_EMB_NUM_ELEMENTS = 2;
using packed_rotemb_t = float4;
static constexpr int WARP_N_ROTEMB_TILES = WARP_N_TILES / NUM_HEADS_PER_WARP * 2;
using rotemb_warp = std::array<packed_rotemb_t, WARP_M_TILES * WARP_N_ROTEMB_TILES>; // 128 regs
struct Arguments {
// **packed** [M, HEAD_DIM] float => [M // 16, HEAD_DIM // 8, WARP_SIZE] of packed_rotemb_t
// aka [M // BLOCK_M, NUM_WARPS, WARP_M_TILES, WARP_N_TILES // NUM_HEADS_PER_WARP * 2, WARP_SIZE]
const packed_rotemb_t *rotary_emb;
const half_t *rmsnorm_weight_q; // [HEAD_DIM]
const half_t *rmsnorm_weight_k; // [HEAD_DIM]
float epsilon;
};
__device__ __forceinline__
static rotemb_warp load_rotemb(const packed_rotemb_t *ptr_rotemb) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
rotemb_warp rotemb;
const packed_rotemb_t *ptrlane = &ptr_rotemb[warpId * WARP_M_TILES * WARP_N_ROTEMB_TILES * WARP_SIZE + laneId];
unrolled_loop<WARP_M_TILES>([&]<int i>() {
unrolled_loop<WARP_N_ROTEMB_TILES>([&]<int j>() {
constexpr int offset = (i * WARP_N_ROTEMB_TILES + j) * WARP_SIZE;
rotemb[i * WARP_N_ROTEMB_TILES + j] = load(&ptrlane[offset]);
});
});
return rotemb;
}
__device__ __forceinline__
static void load_rmsnorm(const half_t *ptr_rmsnorm_weight, half_t *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
static constexpr int PACK_SIZE = HEAD_DIM / WARP_SIZE;
using packed_t = std::array<half_t, PACK_SIZE>;
packed_t pack = load(reinterpret_cast<const packed_t *>(ptr_rmsnorm_weight + laneId * PACK_SIZE));
store<true>(reinterpret_cast<packed_t *>(shmem + laneId * PACK_SIZE), pack);
}
__device__ __forceinline__
static packed_fpsum_t load_rmsnorm_from_shmem(half_t *shmem, int n) {
const int laneId = threadIdx.x % WARP_SIZE;
const int col = n * INSN_N + laneId / 16 * 8; // lane 0-15: n*16+0, lane 16-31: n*16+8
uint4 tmp;
ldmatrix(shmem + col, tmp);
return kernels::bit_cast<packed_fpsum_t>(tmp);
}
__device__ __forceinline__
static void apply(fpsum_warp &fpsum, const packed_rotemb_t *ptr_rotemb, const half_t *ptr_rmsnorm_weight, float epsilon) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ half_t shmem_rmsnorm[NUM_WARPS][HEAD_DIM];
load_rmsnorm(ptr_rmsnorm_weight, &shmem_rmsnorm[warpId][0]);
__syncwarp();
rotemb_warp rotemb = load_rotemb(ptr_rotemb);
float rmsnorm_coef[NUM_HEADS_PER_WARP][WARP_M_TILES][2];
auto sqr = [](half2_t val) ALWAYSINLINE {
float2 fval = half22float2(val);
return fval.x * fval.x + fval.y * fval.y;
};
#pragma unroll
for (int head = 0; head < NUM_HEADS_PER_WARP; head++) {
const int n_offset = head * WARP_N_TILES_PER_HEAD;
#pragma unroll
for (int m = 0; m < WARP_M_TILES; m++) {
float sqrsum[2] = {0.0f, 0.0f};
#pragma unroll
for (int n = 0; n < WARP_N_TILES_PER_HEAD; n++) {
sqrsum[0] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[0]);
sqrsum[1] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[1]);
sqrsum[0] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[2]);
sqrsum[1] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[3]);
}
#pragma unroll
for (int mask = 1; mask <= 2; mask *= 2) {
sqrsum[0] += __shfl_xor_sync(~0, sqrsum[0], mask);
sqrsum[1] += __shfl_xor_sync(~0, sqrsum[1], mask);
}
rmsnorm_coef[head][m][0] = cuda_frsqrt(sqrsum[0] / HEAD_DIM + epsilon);
rmsnorm_coef[head][m][1] = cuda_frsqrt(sqrsum[1] / HEAD_DIM + epsilon);
}
}
#pragma unroll
for (int head = 0; head < NUM_HEADS_PER_WARP; head++) {
const int n_offset = head * WARP_N_TILES_PER_HEAD;
#pragma unroll
for (int n = 0; n < WARP_N_TILES_PER_HEAD; n++) {
packed_f32psum_t rms = packed_fp16_to_fp32(load_rmsnorm_from_shmem(&shmem_rmsnorm[warpId][0], n));
#pragma unroll
for (int m = 0; m < WARP_M_TILES; m++) {
packed_f32psum_t pack = packed_fp16_to_fp32(fpsum[m * WARP_N_TILES + n + n_offset]);
pack.data[0] *= rmsnorm_coef[head][m][0] * rms.data[0];
pack.data[1] *= rmsnorm_coef[head][m][0] * rms.data[1];
pack.data[2] *= rmsnorm_coef[head][m][1] * rms.data[2];
pack.data[3] *= rmsnorm_coef[head][m][1] * rms.data[3];
pack.data[4] *= rmsnorm_coef[head][m][0] * rms.data[4];
pack.data[5] *= rmsnorm_coef[head][m][0] * rms.data[5];
pack.data[6] *= rmsnorm_coef[head][m][1] * rms.data[6];
pack.data[7] *= rmsnorm_coef[head][m][1] * rms.data[7];
auto rope = [](float &x, float &y, float sin, float cos) ALWAYSINLINE {
float ix = x, iy = y;
x = ix * cos - iy * sin;
y = ix * sin + iy * cos;
};
{
packed_rotemb_t sincos = rotemb[m * WARP_N_ROTEMB_TILES + n * 2];
rope(pack.data[0], pack.data[1], sincos.x, sincos.y);
rope(pack.data[2], pack.data[3], sincos.z, sincos.w);
}
{
packed_rotemb_t sincos = rotemb[m * WARP_N_ROTEMB_TILES + n * 2 + 1];
rope(pack.data[4], pack.data[5], sincos.x, sincos.y);
rope(pack.data[6], pack.data[7], sincos.z, sincos.w);
}
fpsum[m * WARP_N_TILES + n + n_offset] = packed_fp32_to_fp16(pack);
}
}
}
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
assert(binfo.numBlocksN % 3 == 0);
const bool is_q = bn < binfo.numBlocksN / 3;
const bool is_k = !is_q && bn < binfo.numBlocksN / 3 * 2;
if (is_q || is_k) {
apply(
fpsum,
args.rotary_emb + bm * NUM_WARPS * WARP_M_TILES * WARP_N_ROTEMB_TILES * WARP_SIZE,
is_q ? args.rmsnorm_weight_q : args.rmsnorm_weight_k,
args.epsilon
);
}
}
};
struct EpiloguePackQKV {
using attn_half_t = half;
using attn_half2_t = half2;
using packed_qkv_t = uint4;
static constexpr int HEAD_DIM = 128;
static constexpr int INSN_K_QK = 16;
static constexpr int INSN_K_PV = 16;
struct Arguments {
packed_qkv_t *out_q, *out_k, *out_v;
int actualM;
// !!! stride in number of packed_qkv_t !!!
int strideHead_q;
int strideHead_k;
int strideHead_v;
};
__device__ __forceinline__
static attn_half2_t convert_half2(half2_t input) {
if constexpr (std::is_same_v<half2_t, attn_half2_t>) {
return input;
} else {
float2 fval = half22float2(input);
return float22half2<attn_half2_t>(fval);
}
}
__device__ __forceinline__
static packed_qkv_t pack_q(packed_fpsum_t input) {
packed_qkv_t output;
output.x = kernels::bit_cast<int>(convert_half2(input.data[0]));
output.y = kernels::bit_cast<int>(convert_half2(input.data[1]));
output.z = kernels::bit_cast<int>(convert_half2(input.data[2]));
output.w = kernels::bit_cast<int>(convert_half2(input.data[3]));
return output;
}
__device__ __forceinline__
static packed_qkv_t pack_k(packed_fpsum_t input) {
packed_qkv_t output;
output.x = kernels::bit_cast<int>(convert_half2(input.data[0]));
output.y = kernels::bit_cast<int>(convert_half2(input.data[2]));
output.z = kernels::bit_cast<int>(convert_half2(input.data[1]));
output.w = kernels::bit_cast<int>(convert_half2(input.data[3]));
return output;
}
__device__ __forceinline__
static packed_qkv_t pack_v(packed_fpsum_t input) {
packed_qkv_t output;
output.x = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[0])));
output.y = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[1])));
output.z = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[2])));
output.w = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[3])));
return output;
}
__device__ __forceinline__
static void mask(packed_qkv_t &pack, uint32_t maskVal, int m, int maxRows) {
const int laneId = threadIdx.x % WARP_SIZE;
if (m * INSN_M + laneId / 4 >= maxRows) {
pack.x = maskVal;
pack.z = maskVal;
}
if (m * INSN_M + laneId / 4 + 8 >= maxRows) {
pack.y = maskVal;
pack.w = maskVal;
}
}
// qkv: [batch, head, bm, NUM_WARPS, WARP_M_TILES, WARP_N_TILES, WARP_SIZE] of packed_qkv_t
template<typename F>
__device__ __forceinline__
static void apply(fpsum_warp &fpsum, packed_qkv_t *ptr_output, int maxRows, F &&funcPack, attn_half2_t maskVal) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
static_assert(HEAD_DIM == WARP_N);
packed_qkv_t *ptrlane = &ptr_output[((warpId * WARP_M_TILES + 0) * WARP_N_TILES + 0) * WARP_SIZE + laneId];
unrolled_loop<WARP_M_TILES>([&]<int m>() ALWAYSINLINE {
unrolled_loop<WARP_N_TILES>([&]<int n>() ALWAYSINLINE {
packed_qkv_t pack = funcPack(fpsum[m * WARP_N_TILES + n]);
mask(pack, kernels::bit_cast<uint32_t>(maskVal), m, maxRows - warpId * WARP_M);
store(&ptrlane[(m * WARP_N_TILES + n) * WARP_SIZE], pack);
});
});
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
assert(binfo.numBlocksN % 3 == 0);
const int numBlocksQ = binfo.numBlocksN / 3;
const bool is_q = bn < numBlocksQ;
const bool is_k = !is_q && bn < numBlocksQ * 2;
// bn is head_id (assume HEAD_DIM == WARP_N)
int head_id, strideHead;
if (is_q) {
head_id = bn;
strideHead = args.strideHead_q;
} else if (is_k) {
head_id = bn - numBlocksQ;
strideHead = args.strideHead_k;
} else {
head_id = bn - numBlocksQ * 2;
strideHead = args.strideHead_v;
}
int block_offset = head_id * strideHead + bm * NUM_WARPS * WARP_M_TILES * WARP_N_TILES * WARP_SIZE;
int maxRows = args.actualM - bm * BLOCK_M;
// static constexpr float neginf = -std::numeric_limits<float>::infinity();
if (is_q) {
apply(fpsum, args.out_q + block_offset, maxRows, pack_q, attn_half2_t(0.0f, 0.0f));
} else if (is_k) {
apply(fpsum, args.out_k + block_offset, maxRows, pack_k, attn_half2_t(NAN, NAN));
} else {
apply(fpsum, args.out_v + block_offset, maxRows, pack_v, attn_half2_t(0.0f, 0.0f));
}
}
};
struct EpilogueLiteLA {
__device__ __forceinline__
static packed_f32psum_t mma_litela(packed_fpsum_t k, packed_fpsum_t v, packed_f32psum_t psum) {
for (int i = 0; i < 4; i++) {
k.data[i] = movmatrix(k.data[i]);
v.data[i] = movmatrix(v.data[i]);
}
std::swap(v.data[1], v.data[2]);
return mma_f16xf16_f32(v, k, psum);
}
static constexpr int LITELA_HEAD_DIM = 32;
static constexpr int LITELA_K_TILES = LITELA_HEAD_DIM / 16;
static constexpr int LITELA_V_TILES = LITELA_HEAD_DIM / 16;
static constexpr int SHMEM_SIZE = NUM_WARPS * (LITELA_HEAD_DIM + 1) * (LITELA_HEAD_DIM + 8) * sizeof(float);
// out_vk: [batch_size, num_heads, head_dim + 1, head_dim]
__device__ __forceinline__
static void apply_litela(const BlockInfo binfo, fpsum_warp fpsum, float *out_vk, int num_blocks_per_batch) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
using vk_t = float[NUM_WARPS][LITELA_HEAD_DIM + 1][LITELA_HEAD_DIM + 8];
extern __shared__ uint8_t shmem[];
vk_t &shmem_vk = *reinterpret_cast<vk_t *>(shmem);
static_assert(sizeof(vk_t) == SHMEM_SIZE);
static_assert(WARP_N == BLOCK_N);
assert(binfo.numBlocksN % 3 == 0);
const int num_heads = binfo.numBlocksN / 3 * 2 * (WARP_N / (LITELA_HEAD_DIM * 2));
const int batch_id = binfo.bm / num_blocks_per_batch;
for (int head_id = 0; head_id < WARP_N / (LITELA_HEAD_DIM * 2); head_id++) {
const int global_head_id = (binfo.bn - binfo.numBlocksN / 3) * (WARP_N / (LITELA_HEAD_DIM * 2)) + head_id;
float *out_vk_current_head = out_vk + (batch_id * num_heads + global_head_id) * (LITELA_HEAD_DIM + 1) * LITELA_HEAD_DIM;
for (int i = laneId; i < sizeof(shmem_vk) / sizeof(float) / NUM_WARPS; i += WARP_SIZE) {
*((&shmem_vk[warpId][0][0]) + i) = 0;
}
__syncwarp();
for (int tile_v = 0; tile_v < LITELA_V_TILES; tile_v++) {
for (int tile_k = 0; tile_k < LITELA_K_TILES; tile_k++) {
packed_f32psum_t attn_sum = { 0 };
for (int i = 0; i < WARP_M_TILES; i++) {
packed_fpsum_t k = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + tile_k];
packed_fpsum_t v = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + LITELA_HEAD_DIM / 16 + tile_v];
for (int j = 0; j < 4; j++) {
k.data[j] = __hmax2(k.data[j], half2_t(0, 0)); // relu
}
attn_sum = mma_litela(k, v, attn_sum);
}
const int row = tile_v * 16 + laneId / 4;
const int col = tile_k * 16 + laneId % 4 * 2;
shmem_vk[warpId][row + 0][col + 0] = attn_sum.data[0];
shmem_vk[warpId][row + 0][col + 1] = attn_sum.data[1];
shmem_vk[warpId][row + 8][col + 0] = attn_sum.data[2];
shmem_vk[warpId][row + 8][col + 1] = attn_sum.data[3];
shmem_vk[warpId][row + 0][col + 8] = attn_sum.data[4];
shmem_vk[warpId][row + 0][col + 9] = attn_sum.data[5];
shmem_vk[warpId][row + 8][col + 8] = attn_sum.data[6];
shmem_vk[warpId][row + 8][col + 9] = attn_sum.data[7];
}
}
for (int tile_k = 0; tile_k < LITELA_K_TILES; tile_k++) {
packed_f32psum_t attn_sum = { 0 };
for (int i = 0; i < WARP_M_TILES; i++) {
packed_fpsum_t k = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + tile_k];
packed_fpsum_t v = {};
for (int j = 0; j < 4; j++) {
k.data[j] = __hmax2(k.data[j], half2_t(0, 0)); // relu
}
#pragma unroll
for (int i = 0; i < 4; i++) {
v.data[i] = half2_t(1, 1);
}
// if (laneId < 4) {
// v.data[0] = half2_t(1, 1);
// v.data[2] = half2_t(1, 1);
// }
// if (laneId % 4 == 0) {
// v.data[0] = half2_t(1, 0);
// v.data[1] = half2_t(1, 0);
// }
attn_sum = mma_litela(k, v, attn_sum);
}
const int row = LITELA_HEAD_DIM + laneId / 4;
const int col = tile_k * 16 + laneId % 4 * 2;
if (laneId < 4) {
shmem_vk[warpId][row + 0][col + 0] = attn_sum.data[0];
shmem_vk[warpId][row + 0][col + 1] = attn_sum.data[1];
shmem_vk[warpId][row + 0][col + 8] = attn_sum.data[4];
shmem_vk[warpId][row + 0][col + 9] = attn_sum.data[5];
}
}
__syncthreads();
for (int i = warpId; i < LITELA_HEAD_DIM + 1; i += NUM_WARPS) {
for (int j = laneId; j < LITELA_HEAD_DIM; j += WARP_SIZE) {
float sum = 0;
for (int k = 0; k < NUM_WARPS; k++) {
sum += shmem_vk[k][i][j];
}
reduce_add(&out_vk_current_head[i * LITELA_HEAD_DIM + j], sum);
}
}
__syncthreads();
}
}
struct Arguments {
half_t *out_q;
float *out_vk;
int num_blocks_per_batch;
int actualM;
};
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
if (bn < binfo.numBlocksN / 3) {
fpsum = apply_act(fpsum, [](half_t x) { return __hmax(x, 0); }); // relu
return EpilogueDefault()(
binfo,
fpsum,
M, N / 3, K, typename EpilogueDefault::Arguments{
.out = args.out_q,
.actualM = args.actualM,
.actualN = N / 3,
});
}
return apply_litela(binfo, fpsum, args.out_vk, args.num_blocks_per_batch);
}
// each thread block mults BlockSize*HEAD_DIM q and (HEAD_DIM+1)*HEAD_DIM vk, in-place writes back to q
// q: [batch_size, #blocks, block_size, #heads, HEAD_DIM]
// vk: [batch_size, #heads, HEAD_DIM+1, HEAD_DIM]
struct vk_mul_q_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
// FIXME FIXME FIXME
__device__
void operator()(half_t *q, const float *vk, float eps, int num_tokens) {
const int block_id = blockIdx.x;
const int head_id = blockIdx.y;
const int batch_id = blockIdx.z;
const int num_blocks = gridDim.x;
const int num_heads = gridDim.y;
const int block_size = blockDim.x;
bool pred = block_id * block_size + threadIdx.x < num_tokens;
half_t *localq = &q[(((batch_id * num_blocks + block_id) * block_size + threadIdx.x) * num_heads + head_id) * LITELA_HEAD_DIM];
const float *localvk = &vk[(batch_id * num_heads + head_id) * (LITELA_HEAD_DIM + 1) * LITELA_HEAD_DIM];
// half_t *localout = &out[(((batch_id * num_blocks + block_id) * block_size + threadIdx.x) * num_heads + head_id) * LITELA_HEAD_DIM];
using packed_q = std::array<half_t, 8>;
using packed_vk = std::array<float, 4>;
half_t qblock[LITELA_HEAD_DIM];
for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_q) / sizeof(half_t)) {
if (pred) {
*reinterpret_cast<packed_q *>(&qblock[i]) = load(reinterpret_cast<const packed_q *>(&localq[i]));
}
}
float outblock[LITELA_HEAD_DIM + 1];
#pragma unroll
for (int j = 0; j < LITELA_HEAD_DIM + 1; j++) {
outblock[j] = 0;
#pragma unroll
for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_vk) / sizeof(float)) {
packed_vk vkpack = load(reinterpret_cast<const packed_vk *>(&localvk[j * LITELA_HEAD_DIM + i]));
#pragma unroll
for (int k = 0; k < vkpack.size(); k++) {
outblock[j] += (float)qblock[i + k] * vkpack[k];
}
}
}
for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_q) / sizeof(half_t)) {
packed_q opack;
for (int k = 0; k < opack.size(); k++) {
opack[k] = __fdividef(outblock[i + k], outblock[LITELA_HEAD_DIM] + eps);
}
if (pred) {
store(reinterpret_cast<packed_q *>(&localq[i]), opack);
}
}
}
};
};
template<typename Epilogue, bool ACT_UNSIGNED>
struct gemm_w4a4_kernel {
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
......@@ -2167,21 +1071,31 @@ public:
}
};
template<typename Epilogue>
struct test_epilogue_kernel {
template<bool fuse_glu, bool use_fp4>
struct quantize_w4a4_fuse_lora_kernel {
using oscales_t = typename std::conditional_t<use_fp4, packed_amscale_t, packed_ascale_t>;
static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(load_act_to_fpsum<false>::SHMEM_SIZE, 128) * 128;
static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(Base::template load_act_to_fpsum<fuse_glu>::SHMEM_SIZE, 128) * 128;
static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;
struct Arguments {
const half_t *input;
half_t *output;
const packed_wscale_t *smooth_factor;
packed_act_t *output;
oscales_t *oscales;
const packed_fpsum_t *lora_wgt_down;
float *lora_act;
int lora_rank;
// aligned to BLOCK_M and BLOCK_N
int M, N;
int M, N; // N should be the actual K in the next GEMM (needs /2 if fuse_glu)
// the actual M and N (no need to /2 if fuse_glu)
int actualM, actualN;
typename Epilogue::Arguments argsEpilogue;
bool alwaysfalse;
};
__device__ __forceinline__
......@@ -2199,30 +1113,48 @@ public:
const int warpId = threadIdx.x / WARP_SIZE;
const int m_offset = bm * BLOCK_M + warpId * WARP_M;
const int n_offset = bn * BLOCK_N;
const int n_offset = bn * BLOCK_N * (fuse_glu ? 2 : 1);
extern __shared__ uint8_t shmem[];
fpsum_warp fpsum;
load_act_to_fpsum<false>()(
Base::template load_act_to_fpsum<fuse_glu>()(
args.input + m_offset * args.actualN + n_offset,
args.actualN,
args.actualM - m_offset,
args.actualN - n_offset,
fpsum,
shmem + warpId * SHMEM_PER_WARP
// args.smooth_factor ? args.smooth_factor + n_offset : nullptr
);
Epilogue()(binfo, fpsum, args.M, args.N, 0, args.argsEpilogue);
CHECK_NAN(fpsum, "fpsum");
// for (int i = 0; i < 16; i++) {
// printf("bm=%d bn=%d warp=%d lane=%d fpsum[%d][0:1]=%f %f\n",
// bm, bn, warpId, threadIdx.x % WARP_SIZE, i,
// (float)fpsum[i].data[0].x, (float)fpsum[i].data[0].y);
// }
using EpilogueLoraDown = typename Lora<Config>::EpilogueLoraDown;
EpilogueLoraDown()(binfo, fpsum, args.M, args.N, 0, typename EpilogueLoraDown::Arguments{
.lora_wgt_down = args.lora_wgt_down,
.lora_act = args.lora_act,
.rank = args.lora_rank,
.alwaysfalse = args.alwaysfalse,
});
EpilogueDefault()(binfo, fpsum, args.M, args.N, 0, typename EpilogueDefault::Arguments{
.out = args.output,
.actualM = args.actualM,
.actualN = args.actualN,
EpilogueQuantize<false, false, use_fp4>()(binfo, fpsum, args.M, args.N, 0, typename EpilogueQuantize<false, false, use_fp4>::Arguments{
.qout = args.output,
.oscales = args.oscales,
.shift_value = 0,
.smooth_factor = args.smooth_factor
});
}
};
};
}; // namespace nunchaku::kernels
\ No newline at end of file
#include "gemm_w4a4.cuh"
#include "epilogues.cuh"
namespace nunchaku::kernels {
template<typename Config, bool USE_FP4>
class GEMM_W4A4_Launch {
using GEMM = GEMM_W4A4<Config>;
// using LoraRanks = std::integer_sequence<int, 0, 32>;
// using LoraRanks = std::integer_sequence<int, 0, 32, 48, 64, 80, 96, 112, 128, 160, 176, 224>;
using LoraRanks = std::integer_sequence<int, 0, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224>;
// using LoraRanks = std::integer_sequence<int,
// 0, 32, 48, 64, 80, 96, 112, 128, 144, 160,
// 176, 192, 208, 224, 240, 256, 272, 288, 304, 320,
// 336, 352, 368, 384, 400, 416, 432, 448, 464, 480,
// 496, 512>;
using Epilogues = Epilogues<Config>;
using Lora = Lora<Config>;
using packed_act_t = typename GEMM::packed_act_t;
using packed_wgt_t = typename GEMM::packed_wgt_t;
......
......@@ -191,76 +191,82 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert(lora_up.valid() == lora_act_in.valid());
assert(lora_down.valid() == lora_act_out.valid());
if (!lora_up.valid()) {
assert(!lora_down.valid());
const int rank_up = lora_up.valid() ? lora_up.shape[1] : 0;
const int rank_down = lora_down.valid() ? lora_down.shape[1] : 0;
if (rank_up == 0) {
assert(rank_down == 0);
return launch_bias.template operator()<typename GEMM::EpilogueCombination<MidEpilogue, NextEpilogue>>({midArgs, nextArgs});
}
const int rank_up = lora_up.shape[1];
assert(rank_up % 16 == 0);
assert(lora_up.shape[0] == N);
// assert(lora_up.shape[1] == Lora::LORA_RANK);
assert(lora_act_in.shape[0] == M);
assert(lora_act_in.shape[1] == rank_up);
dispatchVal(rank_up, LoraRanks(), [&]<int RANK_UP>() {
using LoraUp = typename GEMM::Lora<RANK_UP>;
using scale_t = typename LoraUp::scale_t;
using LoraUp = Lora;
using scale_t = typename LoraUp::scale_t;
scale_t scales;
if constexpr (scales.size() > 0) {
assert(lora_scales.size() >= scales.size());
for (size_t i = 0; i < scales.size(); i++) {
scales[i] = lora_scales[i];
}
}
if (!lora_down.valid()) {
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue, typename GEMM::EpilogueNop>;
return launch_bias.template operator()<Epilogue>({
typename LoraUp::EpilogueLoraUp::Arguments{
.lora_act = lora_act_in.data_ptr<float>(),
.lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
.scales = scales,
},
midArgs,
nextArgs,
{}
});
scale_t scales;
if constexpr (scales.size() > 0) {
for (size_t i = 0; i < scales.size(); i++) {
scales[i] = i < lora_scales.size() ? lora_scales[i] : 0.0f;
}
}
const int rank_down = lora_down.shape[1];
assert(rank_down == rank_up);
assert(lora_down.shape[0] == N);
// assert(lora_down.shape[1] == Lora::LORA_RANK);
assert(lora_act_out.shape[0] == M);
assert(lora_act_out.shape[1] == rank_down);
lora_act_out.zero_();
// dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
using LoraDown = LoraUp; // GEMM::Lora<RANK_DOWN>;
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue, typename GEMM::EpilogueNop>;
if (rank_down == 0) {
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue, typename GEMM::EpilogueNop>;
return launch_bias.template operator()<Epilogue>({
typename LoraUp::EpilogueLoraUp::Arguments{
.lora_act = lora_act_in.data_ptr<float>(),
.lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
.rank = rank_up,
.scales = scales,
.alwaysfalse = false,
},
midArgs,
typename LoraDown::EpilogueLoraDown::Arguments{
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
},
nextArgs,
{}
});
}
// });
// assert(rank_down == rank_up);
assert(rank_down % 16 == 0);
assert(lora_down.shape[0] == N);
// assert(lora_down.shape[1] == Lora::LORA_RANK);
assert(lora_act_out.shape[0] == M);
assert(lora_act_out.shape[1] == rank_down);
lora_act_out.zero_();
// dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
using LoraDown = LoraUp; // GEMM::Lora<RANK_DOWN>;
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue, typename GEMM::EpilogueNop>;
return launch_bias.template operator()<Epilogue>({
typename LoraUp::EpilogueLoraUp::Arguments{
.lora_act = lora_act_in.data_ptr<float>(),
.lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
.rank = rank_up,
.scales = scales,
.alwaysfalse = false,
},
midArgs,
typename LoraDown::EpilogueLoraDown::Arguments{
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
.rank = rank_down,
.alwaysfalse = false,
},
nextArgs,
{}
});
// });
};
if (qout.valid() && oscales.valid()) {
......@@ -280,7 +286,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
// TODO: check if gelu is needed
if (out.valid()) {
launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename Epilogues::EpilogueGelu>({
typename GEMM::EpilogueDefault::Arguments{
.out = out.data_ptr<half_t>(),
.actualM = actualM,
......@@ -289,7 +295,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
argsQuantize
}, {});
} else {
launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
launch_lora.template operator()<EpilogueQuantize, typename Epilogues::EpilogueGelu>(argsQuantize, {});
}
......@@ -297,7 +303,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert(out_vk.valid());
using Epilogue = typename GEMM::EpilogueLiteLA;
using Epilogue = typename Epilogues::EpilogueLiteLA;
assert(out_vk.dtype() == Tensor::FP32);
assert(out_vk.ndims() == 4);
......@@ -334,7 +340,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert(rotary_emb.scalar_type() == Tensor::FP32);
assert(rotary_emb.ndims() == 3);
assert(rotary_emb.shape[0] * rotary_emb.shape[1] == M);
assert(rotary_emb.shape[2] == GEMM::EpilogueRMSNormRope::HEAD_DIM);
assert(rotary_emb.shape[2] == Epilogues::EpilogueRMSNormRope::HEAD_DIM);
// assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
// launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
......@@ -348,8 +354,8 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
// .epsilon = 1e-6,
// }, {});
using EpilogueRope = typename GEMM::EpilogueRMSNormRope;
auto argsRope = typename GEMM::EpilogueRMSNormRope::Arguments{
using EpilogueRope = typename Epilogues::EpilogueRMSNormRope;
auto argsRope = typename Epilogues::EpilogueRMSNormRope::Arguments{
.rotary_emb = rotary_emb.data_ptr<typename EpilogueRope::packed_rotemb_t>(),
.rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
.rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
......@@ -357,16 +363,16 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
};
if (out_q.valid()) {
launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpiloguePackQKV>, typename GEMM::EpilogueNop>({
launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename Epilogues::EpiloguePackQKV>, typename GEMM::EpilogueNop>({
argsRope,
typename GEMM::EpiloguePackQKV::Arguments{
.out_q = out_q.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
.out_k = out_k.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
.out_v = out_v.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
typename Epilogues::EpiloguePackQKV::Arguments{
.out_q = out_q.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
.out_k = out_k.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
.out_v = out_v.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
.actualM = attn_tokens,
.strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
.strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
.strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
}
}, {});
} else {
......@@ -401,7 +407,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
using Epilogue = typename GEMM::EpilogueLiteLA;
using Epilogue = typename Epilogues::EpilogueLiteLA;
int batch_size = vk.shape[0];
int num_heads = vk.shape[1];
......@@ -449,6 +455,8 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
const int rank = lora_down.shape[1];
assert(rank % 16 == 0);
assert(lora_down.shape[0] == N);
// assert(lora_down.shape[1] == Lora::LORA_RANK);
assert(lora_act_out.shape[0] == M);
......@@ -458,34 +466,36 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
using Lora = typename GEMM::Lora<RANK>;
using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<half_t>(),
.smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
.output = output.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename kernel::oscales_t>(),
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
.M = M,
.N = N,
.actualM = actualM,
.actualN = actualN,
}
);
checkCUDA(cudaGetLastError());
});
// dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
// using Lora = typename GEMM::Lora<RANK>;
using kernel = typename GEMM::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<half_t>(),
.smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
.output = output.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename kernel::oscales_t>(),
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
.lora_rank = rank,
.M = M,
.N = N,
.actualM = actualM,
.actualN = actualN,
.alwaysfalse = false,
}
);
checkCUDA(cudaGetLastError());
});
// });
}
template<typename Config, bool USE_FP4>
......
#include "zgemm.h"
#include "gemm_w4a4.cuh"
#include "epilogues.cuh"
namespace nunchaku::kernels {
......@@ -10,7 +11,7 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
assert(input.shape.dataExtent == output.shape.dataExtent);
assert(input.scalar_type() == Tensor::FP16);
using GEMM = GEMM_W4A4<GEMMConfig_W4A4_FP16>;
using GEMM = Epilogues<GEMMConfig_W4A4_FP16>;
using Epilogue = GEMM::EpilogueRMSNormRope;
assert(M % GEMM::BLOCK_M == 0);
......@@ -51,7 +52,7 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
Tensor output = Tensor::empty_like(input);
using GEMM = GEMM_W4A4<GEMMConfig_W4A4_FP16>;
using GEMM = Epilogues<GEMMConfig_W4A4_FP16>;
using Epilogue = GEMM::EpiloguePackQKV;
assert(M % GEMM::BLOCK_M == 0);
......
#pragma once
#include "gemm_base.cuh"
namespace nunchaku::kernels {
template<typename Config>
class Lora;
#ifndef __INTELLISENSE__
template<typename Config>
class Lora : public GEMMBase<Config> {
#else
template<>
class Lora<GEMMConfig_W4A4_FP16> : public GEMMBase<GEMMConfig_W4A4_FP16> {
using Config = GEMMConfig_W4A4_FP16;
#endif
public:
IMPORT_GEMM_BASE(Config);
public:
static constexpr int MAX_RANK = 1024;
static constexpr int WARP_R = 16;
// static constexpr int LORA_RANK = rank;
static constexpr int LORA_M_TILES = WARP_M / 16;
static constexpr int LORA_R_TILES = WARP_R / 16;
static constexpr int LORA_N_TILES = WARP_N / 16;
static_assert(LORA_M_TILES == WARP_M_TILES);
static_assert(LORA_N_TILES == WARP_N_TILES);
// lora_down: [WARP_M, WARP_N] x [WARP_N, R] (row-wise) = [WARP_M, R]
// lora up: [WARP_M, R] x [WARP_N, R] (col-wise) = [WARP_M, WARP_N]
// we use fp32 for lora activation since there's no bf16 reduction in sm_89 :(
using lora_act_warp = std::array<packed_f32psum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_act16_warp = std::array<packed_fpsum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_wgt_warp = std::array<packed_fpsum_t, LORA_N_TILES * LORA_R_TILES>;
using scale_t = std::array<float, MAX_RANK / 16>;
// lora_wgt: [N / 16, rank / WARP_R, LORA_R_TILES, WARP_SIZE] of packed_fpsum_t
// [N / 16, rank / 16, WARP_SIZE]
__device__ __forceinline__
static void load_lora_wgt(const packed_fpsum_t *ptr, int rtile, int rank, lora_wgt_warp &result, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
const packed_fpsum_t *ptr_lane = &ptr[rtile * LORA_R_TILES * WARP_SIZE + laneId];
const int stride_ntile = rank / 16 * WARP_SIZE;
unrolled_loop<LORA_N_TILES>([&]<int n>() {
unrolled_loop<LORA_R_TILES>([&]<int r>() {
constexpr int roffset = r * WARP_SIZE;
const int noffset = n * stride_ntile;
result[n * LORA_R_TILES + r] = load_pred(ptr_lane + noffset + roffset, pred);
});
});
}
// lora_act: [M / BLOCK_M, rank / WARP_R, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float
__device__ __forceinline__
static void load_lora_act(const float *ptr, int rtile, lora_act_warp &result, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
const float *ptrlane = &ptr[(rtile * NUM_WARPS + warpId) * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE + laneId];
unrolled_loop<LORA_M_TILES>([&]<int m>() {
unrolled_loop<LORA_R_TILES>([&]<int r>{
constexpr int i = m * LORA_R_TILES + r;
unrolled_loop<8>([&]<int j>() {
constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
result[i].data[j] = load_pred(ptrlane + offset, pred); // * scales[rtile * LORA_R_TILES + r];
});
// CHECK_NAN(tmp, "load_lora_act.tmp");
});
});
}
// no vector reduction in sm_89 :(
__device__ __forceinline__
static void reduce_lora_act(float *ptr, int rtile, lora_act_warp val, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
float *ptrlane = &ptr[(rtile * NUM_WARPS + warpId) * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE + laneId];
unrolled_loop<LORA_M_TILES * LORA_R_TILES>([&]<int i>() {
unrolled_loop<8>([&]<int j>() {
constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
reduce_add_pred(&ptrlane[offset], val[i].data[j], pred);
});
});
}
// __device__ __forceinline__
// static void reduce_lora_act(float *ptr, lora_act_warp val, int m) {
// const int laneId = threadIdx.x % WARP_SIZE;
// float *ptrlane = ptr + laneId + m * LORA_R_TILES * 8 * WARP_SIZE;
// unrolled_loop<LORA_R_TILES>([&]<int r>() {
// unrolled_loop<8>([&]<int j>() {
// constexpr int offset = r * 8 * WARP_SIZE + j * WARP_SIZE;
// reduce_add(&ptrlane[offset], val[m * LORA_R_TILES + r].data[j]);
// });
// });
// }
struct EpilogueLoraUp {
struct Arguments {
const float *lora_act;
const packed_fpsum_t *lora_wgt_up;
int rank;
scale_t scales;
bool alwaysfalse;
};
__device__ __forceinline__
static void apply_lora_up(fpsum_warp &fpsum, const float *act, const packed_fpsum_t *wgt, const scale_t &scales, int rank, bool alwaysfalse) {
constexpr int NUM_STAGES = 2;
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
lora_act_warp lora_act[NUM_STAGES]; // 32
lora_wgt_warp lora_wgt[NUM_STAGES]; // 64
int dummy = 0;
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
// we have rank > 0
const bool pred = k == 0 ? true : k < rank / WARP_R;
load_lora_act(act, 0, lora_act[k], pred);
load_lora_wgt(wgt, 0, rank, lora_wgt[k], pred);
}
f32psum_warp f32psum = packed_fp16_to_fp32(fpsum); // 128
auto compute = [&scales](lora_act_warp A, lora_wgt_warp W, f32psum_warp &f32psum, int rtile) ALWAYSINLINE {
lora_act16_warp A_fp16;
for (int m = 0; m < LORA_M_TILES; m++) {
for (int r = 0; r < LORA_R_TILES; r++) {
packed_f32psum_t pack = A[m * LORA_R_TILES + r];
#pragma unroll
for (int j = 0; j < 8; j++) {
pack.data[j] *= scales[rtile * LORA_R_TILES + r];
}
A_fp16[m * LORA_R_TILES + r] = packed_fp32_to_fp16(pack);
}
}
for (int m = 0; m < LORA_M_TILES; m++) {
for (int n = 0; n < LORA_N_TILES; n++) {
for (int r = 0; r < LORA_R_TILES; r++) {
CHECK_NAN(lora_act[m * LORA_R_TILES + r], "lora_act");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "lora_wgt");
f32psum[m * WARP_N_TILES + n] = mma_f16xf16_f32(A_fp16[m * LORA_R_TILES + r], W[n * LORA_R_TILES + r], f32psum[m * WARP_N_TILES + n]);
}
}
}
};
for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) {
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
if (k1 + k2 >= rank / WARP_R) {
break;
}
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < rank / WARP_R;
if (alwaysfalse) {
act += kernels::bit_cast<int>(lora_act[k2][0].data[0]);
}
if (alwaysfalse) {
dummy = clock();
}
load_lora_act(act, nextk, lora_act[idx], pred);
load_lora_wgt(wgt, nextk, rank, lora_wgt[idx], pred);
compute(lora_act[k2], lora_wgt[k2], f32psum, k1 + k2);
}
}
// NVCC does not know rank > 0 :(
// it will generate a branch instruction to skip the initial load
// the branch splits the basic blocks and prevents the overlap of memory access and computing (packed_fp16_to_fp32)
// add fake dependency of loaded data so NVCC will not skip the load
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
#pragma unroll
for (auto &&data : lora_act[k]) {
#pragma unroll
for (int i = 0; i < 8; i++) {
dummy ^= kernels::bit_cast<int>(data.data[i]);
}
}
#pragma unroll
for (auto &&data : lora_wgt[k]) {
#pragma unroll
for (int i = 0; i < 4; i++) {
dummy ^= kernels::bit_cast<int>(data.data[i]);
}
}
}
unused_var(dummy, alwaysfalse);
fpsum = packed_fp32_to_fp16(f32psum);
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
CHECK_NAN(fpsum, "fpsum");
apply_lora_up(
fpsum,
args.lora_act + bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_up + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
args.scales,
args.rank,
args.alwaysfalse
);
CHECK_NAN(fpsum, "fpsum");
}
};
struct EpilogueLoraDown {
struct Arguments {
const packed_fpsum_t *lora_wgt_down;
float *lora_act;
int rank;
bool alwaysfalse;
};
__device__ __forceinline__
static void apply_lora_down(fpsum_warp &fpsum, float *act, const packed_fpsum_t *wgt, int rank, bool alwaysfalse) {
constexpr int NUM_STAGES = 2;
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
lora_wgt_warp lora_wgt[NUM_STAGES]; // 64
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
// we have rank > 0
bool pred = k == 0 ? true : k < rank / WARP_R;
load_lora_wgt(wgt, 0, rank, lora_wgt[k], pred);
}
auto compute = [](lora_wgt_warp W, fpsum_warp fpsum) -> lora_act_warp {
lora_act_warp lora_act;
lora_act.fill(packed_f32psum_t::zeros());
#pragma unroll
for (int m = 0; m < LORA_M_TILES; m++) {
#pragma unroll
for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
for (int r = 0; r < LORA_R_TILES; r++) {
auto &psum = lora_act[m * LORA_R_TILES + r];
CHECK_NAN(fpsum[m * WARP_N_TILES + n], "apply_lora_down.fpsum");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "apply_lora_down.lora_wgt");
psum = mma_f16xf16_f32(fpsum[m * WARP_N_TILES + n], W[n * LORA_R_TILES + r], psum);
CHECK_NAN(psum, "apply_lora_down.psum");
}
}
}
return lora_act;
};
int dummy = 0;
for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) {
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
if (k1 + k2 >= rank / WARP_R) {
break;
}
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < rank / WARP_R;
if (alwaysfalse) {
wgt += kernels::bit_cast<int>(lora_wgt[k2][0].data[0]);
}
if (alwaysfalse) {
dummy = clock();
}
load_lora_wgt(wgt, nextk, rank, lora_wgt[idx], pred);
if (alwaysfalse) {
dummy = clock();
}
lora_act_warp lora_act = compute(lora_wgt[k2], fpsum);
reduce_lora_act(act, k1 + k2, lora_act, true);
}
}
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
#pragma unroll
for (auto &&data : lora_wgt[k]) {
#pragma unroll
for (int i = 0; i < 4; i++) {
dummy ^= kernels::bit_cast<int>(data.data[i]);
}
}
}
unused_var(dummy, alwaysfalse);
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
apply_lora_down(
fpsum,
args.lora_act + bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_down + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
args.rank,
args.alwaysfalse
);
}
};
};
}; // namespace nunchaku::kernels
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment