Commit e05256c8 authored by sxtyzhangzk, committed by Zhekai Zhang

Flexible lora ranks

parent 618f3078
......@@ -256,6 +256,16 @@ public:
return results;
}
__device__ __forceinline__
static f32psum_warp packed_fp16_to_fp32(fpsum_warp input) {
f32psum_warp results;
#pragma unroll
for (int i = 0; i < results.size(); i++) {
results[i] = packed_fp16_to_fp32(input[i]);
}
return results;
}
// activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t
__device__ __forceinline__
static void load_act(const packed_act_t *act, int k, int K, act_warp &out, bool pred) {
......@@ -570,6 +580,63 @@ public:
}
};
// loads act of [WARP_M, WARP_N] and stores to fpsum_warp
// [WARP_M, WARP_N * 2] when fuse_glu
template<bool fuse_glu>
struct load_act_to_fpsum {
using matrix_t = half_t[INSN_M][WARP_N + 8];
static constexpr size_t SHMEM_SIZE = sizeof(matrix_t);
__device__ __forceinline__
void operator()(const half_t *input, int stride, int maxRows, int maxCols, fpsum_warp &out, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
constexpr int PACK_SIZE = WARP_N / WARP_SIZE;
using packed_input = std::array<half_t, PACK_SIZE>;
using packed_raw_input = std::array<half2_t, PACK_SIZE>;
#pragma unroll
for (int m = 0; m < WARP_M_TILES; m++) {
#pragma unroll
for (int row = 0; row < INSN_M; row++) {
packed_input pack;
// TODO: handle maxCols not being a multiple of PACK_SIZE
if constexpr (fuse_glu) {
packed_raw_input raw;
raw.fill(half2_t(0, 0));
bool pred = (m * INSN_M + row) < maxRows && laneId * PACK_SIZE * 2 < maxCols;
if (pred) {
raw = load(reinterpret_cast<const packed_raw_input *>(input + (m * INSN_M + row) * stride + laneId * PACK_SIZE * 2));
}
#pragma unroll
for (int j = 0; j < PACK_SIZE; j++) {
pack[j] = raw[j].x * silu(raw[j].y);
}
} else {
pack.fill(half_t(0));
bool pred = (m * INSN_M + row) < maxRows && laneId * PACK_SIZE < maxCols;
if (pred) {
pack = load(reinterpret_cast<const packed_input *>(input + (m * INSN_M + row) * stride + laneId * PACK_SIZE));
}
}
store<true>(reinterpret_cast<packed_input *>(&mat[row][laneId * PACK_SIZE]), pack);
}
__syncwarp();
for (int n = 0; n < WARP_N_TILES; n++) {
const int row = laneId % 16;
const int col = n * INSN_N + laneId / 16 * 8;
uint4 tmp;
ldmatrix(&mat[row][col], tmp);
*reinterpret_cast<uint4 *>(&out[m * WARP_N_TILES + n]) = tmp;
}
__syncwarp();
}
}
};
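For reference, the fuse_glu path above halves the row width by combining adjacent value pairs. A minimal host-side sketch of that math, with hypothetical names silu_ref and glu_reference (not the kernel's data layout):

// Hedged host-side reference for the fuse_glu path: a row of 2*N half values laid out
// as pairs (x0, y0, x1, y1, ...) is reduced to N values x_j * silu(y_j),
// matching pack[j] = raw[j].x * silu(raw[j].y) in the loader above.
#include <cmath>
#include <vector>

static float silu_ref(float v) { return v / (1.0f + std::exp(-v)); }  // silu(v) = v * sigmoid(v)

static std::vector<float> glu_reference(const std::vector<float> &row_pairs) {
    std::vector<float> out(row_pairs.size() / 2);
    for (size_t j = 0; j < out.size(); j++) {
        out[j] = row_pairs[2 * j] * silu_ref(row_pairs[2 * j + 1]);
    }
    return out;
}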
template<typename F>
__device__ __forceinline__
......@@ -599,7 +666,7 @@ public:
};
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
......@@ -632,7 +699,7 @@ public:
struct Arguments { size_t unused; };
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
}
};
......@@ -696,7 +763,7 @@ public:
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
const int bn = binfo.bn;
if constexpr (USE_BIAS || USE_SCALE) {
apply_bias(
......@@ -712,7 +779,7 @@ public:
struct Arguments { size_t unused; };
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
fpsum = apply_act(fpsum, [](half_t x) { return silu(x); });
}
};
......@@ -722,7 +789,7 @@ public:
using Arguments = std::tuple<typename Epilogues::Arguments...>;
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
// this function makes intellisense crash :(
#if __INTELLISENSE__
__trap(); // should not happen when actually compiling
......
......@@ -358,6 +358,14 @@ static void reduce_add(float *addr, float val) {
asm volatile ("red.relaxed.gpu.global.add.f32 [%0], %1;" :: "l"(addr), "f"(val));
}
__device__ __forceinline__
static void reduce_add_pred(float *addr, float val, bool pred) {
asm volatile (
"{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
"@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
"}" :: "r"((int)pred), "l"(addr), "f"(val));
}
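For readers less familiar with the PTX: reduce_add_pred performs a predicated, device-scope, relaxed atomic add without a branch. A functionally equivalent sketch using plain CUDA (the inline-asm version exists to keep the predication in a single instruction):

// Functionally equivalent sketch of reduce_add_pred using a branch plus atomicAdd.
// The PTX version above folds the predicate into the red instruction instead of branching.
__device__ __forceinline__
static void reduce_add_pred_ref(float *addr, float val, bool pred) {
    if (pred) {
        atomicAdd(addr, val);  // same accumulation effect as red.relaxed.gpu.global.add.f32
    }
}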
template<int cnt, typename F>
__device__ __forceinline__
static void unrolled_loop(F &&lambda) {
......
#include "gemm_w4a4.cuh"
#include "epilogues.cuh"
namespace nunchaku::kernels {
template<typename Config, bool USE_FP4>
class GEMM_W4A4_Launch {
using GEMM = GEMM_W4A4<Config>;
// using LoraRanks = std::integer_sequence<int, 0, 32>;
// using LoraRanks = std::integer_sequence<int, 0, 32, 48, 64, 80, 96, 112, 128, 160, 176, 224>;
using LoraRanks = std::integer_sequence<int, 0, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224>;
// using LoraRanks = std::integer_sequence<int,
// 0, 32, 48, 64, 80, 96, 112, 128, 144, 160,
// 176, 192, 208, 224, 240, 256, 272, 288, 304, 320,
// 336, 352, 368, 384, 400, 416, 432, 448, 464, 480,
// 496, 512>;
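For context, the runtime LoRA rank was previously mapped to one of the compile-time values in LoraRanks through dispatchVal; this commit instead threads the rank through kernel arguments. A rough sketch of what an integer_sequence-based dispatcher like dispatchVal might look like (assumed implementation, hypothetical name dispatch_val_sketch; the project's real helper may differ):

// Hypothetical sketch of a dispatchVal-style helper: it invokes
// f.template operator()<V>() for the compile-time V matching the runtime value.
#include <stdexcept>
#include <utility>

template<typename F, int... Vals>
void dispatch_val_sketch(int value, std::integer_sequence<int, Vals...>, F &&f) {
    const bool matched = ((value == Vals ? (f.template operator()<Vals>(), true) : false) || ...);
    if (!matched) {
        throw std::invalid_argument("value not in the dispatch list");  // e.g. an unsupported LoRA rank
    }
}

// usage sketch:  dispatch_val_sketch(rank, LoraRanks(), [&]<int RANK>() { /* RANK is constexpr here */ });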
using Epilogues = Epilogues<Config>;
using Lora = Lora<Config>;
using packed_act_t = typename GEMM::packed_act_t;
using packed_wgt_t = typename GEMM::packed_wgt_t;
......
......@@ -191,76 +191,82 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert(lora_up.valid() == lora_act_in.valid());
assert(lora_down.valid() == lora_act_out.valid());
if (!lora_up.valid()) {
assert(!lora_down.valid());
const int rank_up = lora_up.valid() ? lora_up.shape[1] : 0;
const int rank_down = lora_down.valid() ? lora_down.shape[1] : 0;
if (rank_up == 0) {
assert(rank_down == 0);
return launch_bias.template operator()<typename GEMM::EpilogueCombination<MidEpilogue, NextEpilogue>>({midArgs, nextArgs});
}
const int rank_up = lora_up.shape[1];
assert(rank_up % 16 == 0);
assert(lora_up.shape[0] == N);
// assert(lora_up.shape[1] == Lora::LORA_RANK);
assert(lora_act_in.shape[0] == M);
assert(lora_act_in.shape[1] == rank_up);
dispatchVal(rank_up, LoraRanks(), [&]<int RANK_UP>() {
using LoraUp = typename GEMM::Lora<RANK_UP>;
using scale_t = typename LoraUp::scale_t;
using LoraUp = Lora;
using scale_t = typename LoraUp::scale_t;
scale_t scales;
if constexpr (scales.size() > 0) {
assert(lora_scales.size() >= scales.size());
for (size_t i = 0; i < scales.size(); i++) {
scales[i] = lora_scales[i];
}
}
if (!lora_down.valid()) {
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue, typename GEMM::EpilogueNop>;
return launch_bias.template operator()<Epilogue>({
typename LoraUp::EpilogueLoraUp::Arguments{
.lora_act = lora_act_in.data_ptr<float>(),
.lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
.scales = scales,
},
midArgs,
nextArgs,
{}
});
scale_t scales;
if constexpr (scales.size() > 0) {
for (size_t i = 0; i < scales.size(); i++) {
scales[i] = i < lora_scales.size() ? lora_scales[i] : 0.0f;
}
}
const int rank_down = lora_down.shape[1];
assert(rank_down == rank_up);
assert(lora_down.shape[0] == N);
// assert(lora_down.shape[1] == Lora::LORA_RANK);
assert(lora_act_out.shape[0] == M);
assert(lora_act_out.shape[1] == rank_down);
lora_act_out.zero_();
// dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
using LoraDown = LoraUp; // GEMM::Lora<RANK_DOWN>;
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue, typename GEMM::EpilogueNop>;
if (rank_down == 0) {
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue, typename GEMM::EpilogueNop>;
return launch_bias.template operator()<Epilogue>({
typename LoraUp::EpilogueLoraUp::Arguments{
.lora_act = lora_act_in.data_ptr<float>(),
.lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
.rank = rank_up,
.scales = scales,
.alwaysfalse = false,
},
midArgs,
typename LoraDown::EpilogueLoraDown::Arguments{
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
},
nextArgs,
{}
});
}
// });
// assert(rank_down == rank_up);
assert(rank_down % 16 == 0);
assert(lora_down.shape[0] == N);
// assert(lora_down.shape[1] == Lora::LORA_RANK);
assert(lora_act_out.shape[0] == M);
assert(lora_act_out.shape[1] == rank_down);
lora_act_out.zero_();
// dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
using LoraDown = LoraUp; // GEMM::Lora<RANK_DOWN>;
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue, typename GEMM::EpilogueNop>;
return launch_bias.template operator()<Epilogue>({
typename LoraUp::EpilogueLoraUp::Arguments{
.lora_act = lora_act_in.data_ptr<float>(),
.lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
.rank = rank_up,
.scales = scales,
.alwaysfalse = false,
},
midArgs,
typename LoraDown::EpilogueLoraDown::Arguments{
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
.rank = rank_down,
.alwaysfalse = false,
},
nextArgs,
{}
});
// });
};
if (qout.valid() && oscales.valid()) {
......@@ -280,7 +286,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
// TODO: check if gelu is needed
if (out.valid()) {
launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename Epilogues::EpilogueGelu>({
typename GEMM::EpilogueDefault::Arguments{
.out = out.data_ptr<half_t>(),
.actualM = actualM,
......@@ -289,7 +295,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
argsQuantize
}, {});
} else {
launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
launch_lora.template operator()<EpilogueQuantize, typename Epilogues::EpilogueGelu>(argsQuantize, {});
}
......@@ -297,7 +303,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert(out_vk.valid());
using Epilogue = typename GEMM::EpilogueLiteLA;
using Epilogue = typename Epilogues::EpilogueLiteLA;
assert(out_vk.dtype() == Tensor::FP32);
assert(out_vk.ndims() == 4);
......@@ -334,7 +340,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
assert(rotary_emb.scalar_type() == Tensor::FP32);
assert(rotary_emb.ndims() == 3);
assert(rotary_emb.shape[0] * rotary_emb.shape[1] == M);
assert(rotary_emb.shape[2] == GEMM::EpilogueRMSNormRope::HEAD_DIM);
assert(rotary_emb.shape[2] == Epilogues::EpilogueRMSNormRope::HEAD_DIM);
// assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
// launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
......@@ -348,8 +354,8 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
// .epsilon = 1e-6,
// }, {});
using EpilogueRope = typename GEMM::EpilogueRMSNormRope;
auto argsRope = typename GEMM::EpilogueRMSNormRope::Arguments{
using EpilogueRope = typename Epilogues::EpilogueRMSNormRope;
auto argsRope = typename Epilogues::EpilogueRMSNormRope::Arguments{
.rotary_emb = rotary_emb.data_ptr<typename EpilogueRope::packed_rotemb_t>(),
.rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
.rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
......@@ -357,16 +363,16 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
};
if (out_q.valid()) {
launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpiloguePackQKV>, typename GEMM::EpilogueNop>({
launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename Epilogues::EpiloguePackQKV>, typename GEMM::EpilogueNop>({
argsRope,
typename GEMM::EpiloguePackQKV::Arguments{
.out_q = out_q.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
.out_k = out_k.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
.out_v = out_v.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
typename Epilogues::EpiloguePackQKV::Arguments{
.out_q = out_q.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
.out_k = out_k.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
.out_v = out_v.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
.actualM = attn_tokens,
.strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
.strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
.strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
.strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
}
}, {});
} else {
......@@ -401,7 +407,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
using Epilogue = typename GEMM::EpilogueLiteLA;
using Epilogue = typename Epilogues::EpilogueLiteLA;
int batch_size = vk.shape[0];
int num_heads = vk.shape[1];
......@@ -449,6 +455,8 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
const int rank = lora_down.shape[1];
assert(rank % 16 == 0);
assert(lora_down.shape[0] == N);
// assert(lora_down.shape[1] == Lora::LORA_RANK);
assert(lora_act_out.shape[0] == M);
......@@ -458,34 +466,36 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
using Lora = typename GEMM::Lora<RANK>;
using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<half_t>(),
.smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
.output = output.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename kernel::oscales_t>(),
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
.M = M,
.N = N,
.actualM = actualM,
.actualN = actualN,
}
);
checkCUDA(cudaGetLastError());
});
// dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
// using Lora = typename GEMM::Lora<RANK>;
using kernel = typename GEMM::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
typename kernel::Arguments{
.input = input.data_ptr<half_t>(),
.smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
.output = output.data_ptr<packed_act_t>(),
.oscales = oscales.data_ptr<typename kernel::oscales_t>(),
.lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
.lora_rank = rank,
.M = M,
.N = N,
.actualM = actualM,
.actualN = actualN,
.alwaysfalse = false,
}
);
checkCUDA(cudaGetLastError());
});
// });
}
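The launch above first raises the kernel's dynamic shared-memory limit to kernel::SHMEM_SIZE (required when it exceeds the default 48 KB) and then passes the same size at launch. A stripped-down sketch of that pattern with placeholder names (my_kernel, MyArgs, shmem_bytes are illustrative):

// Minimal sketch of the opt-in large dynamic shared memory launch pattern used above.
struct MyArgs { const void *in; void *out; int n; };

__global__ void my_kernel(MyArgs args) {
    extern __shared__ char smem[];  // dynamic shared memory, sized at launch
    // ...
}

void launch_my_kernel(dim3 grid, dim3 block, size_t shmem_bytes, MyArgs args, cudaStream_t stream) {
    // Kernels requesting more dynamic shared memory than the default must opt in first.
    cudaFuncSetAttribute(my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)shmem_bytes);
    my_kernel<<<grid, block, shmem_bytes, stream>>>(args);
}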
template<typename Config, bool USE_FP4>
......
#include "zgemm.h"
#include "gemm_w4a4.cuh"
#include "epilogues.cuh"
namespace nunchaku::kernels {
......@@ -10,7 +11,7 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
assert(input.shape.dataExtent == output.shape.dataExtent);
assert(input.scalar_type() == Tensor::FP16);
using GEMM = GEMM_W4A4<GEMMConfig_W4A4_FP16>;
using GEMM = Epilogues<GEMMConfig_W4A4_FP16>;
using Epilogue = GEMM::EpilogueRMSNormRope;
assert(M % GEMM::BLOCK_M == 0);
......@@ -51,7 +52,7 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
Tensor output = Tensor::empty_like(input);
using GEMM = GEMM_W4A4<GEMMConfig_W4A4_FP16>;
using GEMM = Epilogues<GEMMConfig_W4A4_FP16>;
using Epilogue = GEMM::EpiloguePackQKV;
assert(M % GEMM::BLOCK_M == 0);
......
#pragma once
#include "gemm_base.cuh"
namespace nunchaku::kernels {
template<typename Config>
class Lora;
#ifndef __INTELLISENSE__
template<typename Config>
class Lora : public GEMMBase<Config> {
#else
template<>
class Lora<GEMMConfig_W4A4_FP16> : public GEMMBase<GEMMConfig_W4A4_FP16> {
using Config = GEMMConfig_W4A4_FP16;
#endif
public:
IMPORT_GEMM_BASE(Config);
public:
static constexpr int MAX_RANK = 1024;
static constexpr int WARP_R = 16;
// static constexpr int LORA_RANK = rank;
static constexpr int LORA_M_TILES = WARP_M / 16;
static constexpr int LORA_R_TILES = WARP_R / 16;
static constexpr int LORA_N_TILES = WARP_N / 16;
static_assert(LORA_M_TILES == WARP_M_TILES);
static_assert(LORA_N_TILES == WARP_N_TILES);
// lora_down: [WARP_M, WARP_N] x [WARP_N, R] (row-wise) = [WARP_M, R]
// lora up: [WARP_M, R] x [WARP_N, R] (col-wise) = [WARP_M, WARP_N]
// we use fp32 for lora activation since there's no bf16 reduction in sm_89 :(
using lora_act_warp = std::array<packed_f32psum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_act16_warp = std::array<packed_fpsum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_wgt_warp = std::array<packed_fpsum_t, LORA_N_TILES * LORA_R_TILES>;
using scale_t = std::array<float, MAX_RANK / 16>;
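// Worked example (illustrative): for rank R = 48, WARP_R = 16 gives R / WARP_R = 3
// rank-tile steps; each step does LORA_M_TILES x LORA_N_TILES x LORA_R_TILES
// 16x16x16 mma tiles, and scales[0..2] hold one scale per 16-wide slice of R
// (scale_t is sized for MAX_RANK / 16 = 64 slices).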
// lora_wgt: [N / 16, rank / WARP_R, LORA_R_TILES, WARP_SIZE] of packed_fpsum_t
// [N / 16, rank / 16, WARP_SIZE]
__device__ __forceinline__
static void load_lora_wgt(const packed_fpsum_t *ptr, int rtile, int rank, lora_wgt_warp &result, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
const packed_fpsum_t *ptr_lane = &ptr[rtile * LORA_R_TILES * WARP_SIZE + laneId];
const int stride_ntile = rank / 16 * WARP_SIZE;
unrolled_loop<LORA_N_TILES>([&]<int n>() {
unrolled_loop<LORA_R_TILES>([&]<int r>() {
constexpr int roffset = r * WARP_SIZE;
const int noffset = n * stride_ntile;
result[n * LORA_R_TILES + r] = load_pred(ptr_lane + noffset + roffset, pred);
});
});
}
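// (For reference: the flat element index loaded above is
//  n * (rank / 16) * WARP_SIZE + (rtile * LORA_R_TILES + r) * WARP_SIZE + laneId,
//  i.e. the [N / 16, rank / 16, WARP_SIZE] layout from the comment above,
//  with the n-stride depending on the runtime rank.)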
// lora_act: [M / BLOCK_M, rank / WARP_R, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float
__device__ __forceinline__
static void load_lora_act(const float *ptr, int rtile, lora_act_warp &result, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
const float *ptrlane = &ptr[(rtile * NUM_WARPS + warpId) * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE + laneId];
unrolled_loop<LORA_M_TILES>([&]<int m>() {
unrolled_loop<LORA_R_TILES>([&]<int r>{
constexpr int i = m * LORA_R_TILES + r;
unrolled_loop<8>([&]<int j>() {
constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
result[i].data[j] = load_pred(ptrlane + offset, pred); // * scales[rtile * LORA_R_TILES + r];
});
// CHECK_NAN(tmp, "load_lora_act.tmp");
});
});
}
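// (For reference: the element read above sits at
//  (rtile * NUM_WARPS + warpId) * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE
//  + (m * LORA_R_TILES + r) * 8 * WARP_SIZE + j * WARP_SIZE + laneId,
//  matching the lora_act layout comment above; reduce_lora_act below stores back
//  with the same addressing.)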
// no vector reduction in sm_89 :(
__device__ __forceinline__
static void reduce_lora_act(float *ptr, int rtile, lora_act_warp val, bool pred) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
float *ptrlane = &ptr[(rtile * NUM_WARPS + warpId) * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE + laneId];
unrolled_loop<LORA_M_TILES * LORA_R_TILES>([&]<int i>() {
unrolled_loop<8>([&]<int j>() {
constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
reduce_add_pred(&ptrlane[offset], val[i].data[j], pred);
});
});
}
// __device__ __forceinline__
// static void reduce_lora_act(float *ptr, lora_act_warp val, int m) {
// const int laneId = threadIdx.x % WARP_SIZE;
// float *ptrlane = ptr + laneId + m * LORA_R_TILES * 8 * WARP_SIZE;
// unrolled_loop<LORA_R_TILES>([&]<int r>() {
// unrolled_loop<8>([&]<int j>() {
// constexpr int offset = r * 8 * WARP_SIZE + j * WARP_SIZE;
// reduce_add(&ptrlane[offset], val[m * LORA_R_TILES + r].data[j]);
// });
// });
// }
struct EpilogueLoraUp {
struct Arguments {
const float *lora_act;
const packed_fpsum_t *lora_wgt_up;
int rank;
scale_t scales;
bool alwaysfalse;
};
__device__ __forceinline__
static void apply_lora_up(fpsum_warp &fpsum, const float *act, const packed_fpsum_t *wgt, const scale_t &scales, int rank, bool alwaysfalse) {
constexpr int NUM_STAGES = 2;
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
lora_act_warp lora_act[NUM_STAGES]; // 32
lora_wgt_warp lora_wgt[NUM_STAGES]; // 64
int dummy = 0;
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
// we have rank > 0
const bool pred = k == 0 ? true : k < rank / WARP_R;
load_lora_act(act, 0, lora_act[k], pred);
load_lora_wgt(wgt, 0, rank, lora_wgt[k], pred);
}
f32psum_warp f32psum = packed_fp16_to_fp32(fpsum); // 128
auto compute = [&scales](lora_act_warp A, lora_wgt_warp W, f32psum_warp &f32psum, int rtile) ALWAYSINLINE {
lora_act16_warp A_fp16;
for (int m = 0; m < LORA_M_TILES; m++) {
for (int r = 0; r < LORA_R_TILES; r++) {
packed_f32psum_t pack = A[m * LORA_R_TILES + r];
#pragma unroll
for (int j = 0; j < 8; j++) {
pack.data[j] *= scales[rtile * LORA_R_TILES + r];
}
A_fp16[m * LORA_R_TILES + r] = packed_fp32_to_fp16(pack);
}
}
for (int m = 0; m < LORA_M_TILES; m++) {
for (int n = 0; n < LORA_N_TILES; n++) {
for (int r = 0; r < LORA_R_TILES; r++) {
CHECK_NAN(lora_act[m * LORA_R_TILES + r], "lora_act");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "lora_wgt");
f32psum[m * WARP_N_TILES + n] = mma_f16xf16_f32(A_fp16[m * LORA_R_TILES + r], W[n * LORA_R_TILES + r], f32psum[m * WARP_N_TILES + n]);
}
}
}
};
for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) {
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
if (k1 + k2 >= rank / WARP_R) {
break;
}
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < rank / WARP_R;
if (alwaysfalse) {
act += kernels::bit_cast<int>(lora_act[k2][0].data[0]);
}
if (alwaysfalse) {
dummy = clock();
}
load_lora_act(act, nextk, lora_act[idx], pred);
load_lora_wgt(wgt, nextk, rank, lora_wgt[idx], pred);
compute(lora_act[k2], lora_wgt[k2], f32psum, k1 + k2);
}
}
// NVCC does not know that rank > 0 :(
// it will generate a branch instruction to skip the initial load
// the branch splits the basic blocks and prevents overlapping the memory accesses with the computation (packed_fp16_to_fp32)
// add a fake dependency on the loaded data so NVCC will not skip the load
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
#pragma unroll
for (auto &&data : lora_act[k]) {
#pragma unroll
for (int i = 0; i < 8; i++) {
dummy ^= kernels::bit_cast<int>(data.data[i]);
}
}
#pragma unroll
for (auto &&data : lora_wgt[k]) {
#pragma unroll
for (int i = 0; i < 4; i++) {
dummy ^= kernels::bit_cast<int>(data.data[i]);
}
}
}
unused_var(dummy, alwaysfalse);
fpsum = packed_fp32_to_fp16(f32psum);
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
CHECK_NAN(fpsum, "fpsum");
apply_lora_up(
fpsum,
args.lora_act + bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_up + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
args.scales,
args.rank,
args.alwaysfalse
);
CHECK_NAN(fpsum, "fpsum");
}
};
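In scalar terms, EpilogueLoraUp adds the low-rank product back into the tile: fpsum[m][n] += sum_r act[m][r] * scales[r / 16] * wgt_up[n][r]. A hedged host-side reference with illustrative names (the kernel operates on mma fragments rather than plain arrays):

// Illustrative scalar reference for EpilogueLoraUp, not the device data layout:
// out[M][N] += (act[M][R], scaled per 16-wide rank slice) * wgt_up[N][R]^T
void lora_up_reference(float *out, const float *act, const float *wgt_up,
                       const float *scales /* R/16 entries */, int M, int N, int R) {
    for (int m = 0; m < M; m++)
        for (int n = 0; n < N; n++) {
            float acc = out[m * N + n];
            for (int r = 0; r < R; r++)
                acc += act[m * R + r] * scales[r / 16] * wgt_up[n * R + r];
            out[m * N + n] = acc;
        }
}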
struct EpilogueLoraDown {
struct Arguments {
const packed_fpsum_t *lora_wgt_down;
float *lora_act;
int rank;
bool alwaysfalse;
};
__device__ __forceinline__
static void apply_lora_down(fpsum_warp &fpsum, float *act, const packed_fpsum_t *wgt, int rank, bool alwaysfalse) {
constexpr int NUM_STAGES = 2;
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
lora_wgt_warp lora_wgt[NUM_STAGES]; // 64
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
// we have rank > 0
bool pred = k == 0 ? true : k < rank / WARP_R;
load_lora_wgt(wgt, 0, rank, lora_wgt[k], pred);
}
auto compute = [](lora_wgt_warp W, fpsum_warp fpsum) -> lora_act_warp {
lora_act_warp lora_act;
lora_act.fill(packed_f32psum_t::zeros());
#pragma unroll
for (int m = 0; m < LORA_M_TILES; m++) {
#pragma unroll
for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
for (int r = 0; r < LORA_R_TILES; r++) {
auto &psum = lora_act[m * LORA_R_TILES + r];
CHECK_NAN(fpsum[m * WARP_N_TILES + n], "apply_lora_down.fpsum");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "apply_lora_down.lora_wgt");
psum = mma_f16xf16_f32(fpsum[m * WARP_N_TILES + n], W[n * LORA_R_TILES + r], psum);
CHECK_NAN(psum, "apply_lora_down.psum");
}
}
}
return lora_act;
};
int dummy = 0;
for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) {
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
if (k1 + k2 >= rank / WARP_R) {
break;
}
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < rank / WARP_R;
if (alwaysfalse) {
wgt += kernels::bit_cast<int>(lora_wgt[k2][0].data[0]);
}
if (alwaysfalse) {
dummy = clock();
}
load_lora_wgt(wgt, nextk, rank, lora_wgt[idx], pred);
if (alwaysfalse) {
dummy = clock();
}
lora_act_warp lora_act = compute(lora_wgt[k2], fpsum);
reduce_lora_act(act, k1 + k2, lora_act, true);
}
}
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
#pragma unroll
for (auto &&data : lora_wgt[k]) {
#pragma unroll
for (int i = 0; i < 4; i++) {
dummy ^= kernels::bit_cast<int>(data.data[i]);
}
}
}
unused_var(dummy, alwaysfalse);
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
apply_lora_down(
fpsum,
args.lora_act + bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_down + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
args.rank,
args.alwaysfalse
);
}
};
};
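EpilogueLoraDown is the transpose-side projection: it accumulates lora_act[m][r] += sum_n fpsum[m][n] * wgt_down[n][r] into global memory with predicated atomic adds, which is why the launcher zeroes lora_act_out first. A scalar reference sketch with illustrative names:

// Illustrative scalar reference for EpilogueLoraDown (the kernel accumulates with
// reduce_add_pred across warps/blocks instead of this serial loop):
void lora_down_reference(float *lora_act, const float *fpsum, const float *wgt_down,
                         int M, int N, int R) {
    for (int m = 0; m < M; m++)
        for (int r = 0; r < R; r++) {
            float acc = 0.0f;
            for (int n = 0; n < N; n++)
                acc += fpsum[m * N + n] * wgt_down[n * R + r];
            lora_act[m * R + r] += acc;  // device code uses atomic reduction for this accumulation
        }
}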
}; // namespace nunchaku::kernels
\ No newline at end of file