#pragma once

#include "gemm_base.cuh"

namespace nunchaku::kernels {

template<typename Config>
class Lora;

#ifndef __INTELLISENSE__
template<typename Config>
class Lora : public GEMMBase<Config> {
#else
template<>
class Lora<GEMMConfig_W4A4_FP16> : public GEMMBase<GEMMConfig_W4A4_FP16> {
    using Config = GEMMConfig_W4A4_FP16;
#endif
public:
    IMPORT_GEMM_BASE(Config);

public:
    static constexpr int MAX_RANK = 1024;
    static constexpr int WARP_R = 16;

    // static constexpr int LORA_RANK = rank;
    static constexpr int LORA_M_TILES = WARP_M / 16;
    static constexpr int LORA_R_TILES = WARP_R / 16;
    static constexpr int LORA_N_TILES = WARP_N / 16;

    static_assert(LORA_M_TILES == WARP_M_TILES);
    static_assert(LORA_N_TILES == WARP_N_TILES);

    // lora_down: [WARP_M, WARP_N] x [WARP_N, R] (row-wise) = [WARP_M, R]
    // lora_up:   [WARP_M, R] x [WARP_N, R] (col-wise)      = [WARP_M, WARP_N]

    // we use fp32 for lora activation since there's no bf16 reduction in sm_89 :(
    using lora_act_warp   = std::array<packed_f32psum_t, LORA_M_TILES * LORA_R_TILES>;
    using lora_act16_warp = std::array<packed_fpsum_t, LORA_M_TILES * LORA_R_TILES>;
    using lora_wgt_warp   = std::array<packed_fpsum_t, LORA_N_TILES * LORA_R_TILES>;

    using scale_t = std::array<float, MAX_RANK / WARP_R>;

    // lora_wgt: [N / 16, rank / WARP_R, LORA_R_TILES, WARP_SIZE] of packed_fpsum_t
    //         = [N / 16, rank / 16, WARP_SIZE]
    __device__ __forceinline__
    static void load_lora_wgt(const packed_fpsum_t *ptr, int rtile, int rank, lora_wgt_warp &result, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;

        const packed_fpsum_t *ptr_lane = &ptr[rtile * LORA_R_TILES * WARP_SIZE + laneId];
        const int stride_ntile = rank / 16 * WARP_SIZE;

        unrolled_loop<LORA_N_TILES>([&]<int n>() {
            unrolled_loop<LORA_R_TILES>([&]<int r>() {
                constexpr int roffset = r * WARP_SIZE;
                const int noffset = n * stride_ntile;
                result[n * LORA_R_TILES + r] = load_pred(ptr_lane + noffset + roffset, pred);
            });
        });
    }

    // lora_act: [M / BLOCK_M, rank / WARP_R, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float
    __device__ __forceinline__
    static void load_lora_act(const float *ptr, int rtile, lora_act_warp &result, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        const float *ptrlane = &ptr[(rtile * NUM_WARPS + warpId) * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE + laneId];

        unrolled_loop<LORA_M_TILES>([&]<int m>() {
            unrolled_loop<LORA_R_TILES>([&]<int r>() {
                constexpr int i = m * LORA_R_TILES + r;
                unrolled_loop<8>([&]<int j>() {
                    constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
                    result[i].data[j] = load_pred(ptrlane + offset, pred); // * scales[rtile * LORA_R_TILES + r];
                });
                // CHECK_NAN(tmp, "load_lora_act.tmp");
            });
        });
    }

    // no vector reduction in sm_89 :(
    __device__ __forceinline__
    static void reduce_lora_act(float *ptr, int rtile, lora_act_warp val, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        float *ptrlane = &ptr[(rtile * NUM_WARPS + warpId) * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE + laneId];

        unrolled_loop<LORA_M_TILES * LORA_R_TILES>([&]<int i>() {
            unrolled_loop<8>([&]<int j>() {
                constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
                reduce_add_pred(&ptrlane[offset], val[i].data[j], pred);
            });
        });
    }

    // __device__ __forceinline__
    // static void reduce_lora_act(float *ptr, lora_act_warp val, int m) {
    //     const int laneId = threadIdx.x % WARP_SIZE;
    //     float *ptrlane = ptr + laneId + m * LORA_R_TILES * 8 * WARP_SIZE;
    //     unrolled_loop<LORA_R_TILES>([&]<int r>() {
    //         unrolled_loop<8>([&]<int j>() {
    //             constexpr int offset = r * 8 * WARP_SIZE + j * WARP_SIZE;
    //             reduce_add(&ptrlane[offset], val[m * LORA_R_TILES + r].data[j]);
    //         });
    //     });
    // }

    struct EpilogueLoraUp {
        struct Arguments {
            const float *lora_act;
            const packed_fpsum_t *lora_wgt_up;
            int rank;
            scale_t scales;
            bool alwaysfalse;
        };

        __device__ __forceinline__
        static void apply_lora_up(fpsum_warp &fpsum,
                                  const float *act,
                                  const packed_fpsum_t *wgt,
                                  const scale_t &scales,
                                  int rank,
                                  bool alwaysfalse) {
            constexpr int NUM_STAGES = 2;

            const int laneId = threadIdx.x % WARP_SIZE;
            const int warpId = threadIdx.x / WARP_SIZE;

            lora_act_warp lora_act[NUM_STAGES]; // 32
            lora_wgt_warp lora_wgt[NUM_STAGES]; // 64

            int dummy = 0;

            // pipeline prologue: issue the loads for the first stage(s)
#pragma unroll
            for (int k = 0; k < NUM_STAGES - 1; k++) {
                // we have rank > 0
                const bool pred = k == 0 ? true : k < rank / WARP_R;
                load_lora_act(act, 0, lora_act[k], pred);
                load_lora_wgt(wgt, 0, rank, lora_wgt[k], pred);
            }

            f32psum_warp f32psum = packed_fp16_to_fp32(fpsum); // 128

            auto compute = [&scales](lora_act_warp A, lora_wgt_warp W, f32psum_warp &f32psum, int rtile) ALWAYSINLINE {
                lora_act16_warp A_fp16;
                // apply the per-rank scale and convert the fp32 activation to fp16 for the MMA
                for (int m = 0; m < LORA_M_TILES; m++) {
                    for (int r = 0; r < LORA_R_TILES; r++) {
                        packed_f32psum_t pack = A[m * LORA_R_TILES + r];
#pragma unroll
                        for (int j = 0; j < 8; j++) {
                            pack.data[j] *= scales[rtile * LORA_R_TILES + r];
                        }
                        A_fp16[m * LORA_R_TILES + r] = packed_fp32_to_fp16(pack);
                    }
                }
                for (int m = 0; m < LORA_M_TILES; m++) {
                    for (int n = 0; n < LORA_N_TILES; n++) {
                        for (int r = 0; r < LORA_R_TILES; r++) {
                            CHECK_NAN(A[m * LORA_R_TILES + r], "lora_act");
                            CHECK_NAN(W[n * LORA_R_TILES + r], "lora_wgt");
                            f32psum[m * WARP_N_TILES + n] = mma_f16xf16_f32(
                                A_fp16[m * LORA_R_TILES + r],
                                W[n * LORA_R_TILES + r],
                                f32psum[m * WARP_N_TILES + n]);
                        }
                    }
                }
            };

            for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) {
#pragma unroll
                for (int k2 = 0; k2 < NUM_STAGES; k2++) {
                    if (k1 + k2 >= rank / WARP_R) {
                        break;
                    }

                    int nextk = k1 + k2 + NUM_STAGES - 1;
                    int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
                    bool pred = nextk < rank / WARP_R;

                    if (alwaysfalse) {
                        act += kernels::bit_cast<int>(lora_act[k2][0].data[0]);
                    }
                    if (alwaysfalse) {
                        dummy = clock();
                    }

                    // prefetch the next stage while computing the current one
                    load_lora_act(act, nextk, lora_act[idx], pred);
                    load_lora_wgt(wgt, nextk, rank, lora_wgt[idx], pred);

                    compute(lora_act[k2], lora_wgt[k2], f32psum, k1 + k2);
                }
            }

            // NVCC does not know rank > 0 :(
            // it will generate a branch instruction to skip the initial load
            // the branch splits the basic blocks and prevents the overlap of memory access and computing (packed_fp16_to_fp32)
            // add fake dependency of loaded data so NVCC will not skip the load
#pragma unroll
            for (int k = 0; k < NUM_STAGES - 1; k++) {
#pragma unroll
                for (auto &&data : lora_act[k]) {
#pragma unroll
                    for (int i = 0; i < 8; i++) {
                        dummy ^= kernels::bit_cast<int>(data.data[i]);
                    }
                }
#pragma unroll
                for (auto &&data : lora_wgt[k]) {
#pragma unroll
                    for (int i = 0; i < 4; i++) {
                        dummy ^= kernels::bit_cast<int>(data.data[i]);
                    }
                }
            }
            unused_var(dummy, alwaysfalse);

            fpsum = packed_fp32_to_fp16(f32psum);
        }

        __device__ __forceinline__
        void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
            const int bm = binfo.bm;
            const int bn = binfo.bn;

            CHECK_NAN(fpsum, "fpsum");

            apply_lora_up(
                fpsum,
                args.lora_act + bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
                args.lora_wgt_up + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
                args.scales,
                args.rank,
                args.alwaysfalse);

            CHECK_NAN(fpsum, "fpsum");
        }
    };

    struct EpilogueLoraDown {
        struct Arguments {
            const packed_fpsum_t *lora_wgt_down;
            float *lora_act;
            int rank;
            bool alwaysfalse;
        };

        __device__ __forceinline__
        static void apply_lora_down(fpsum_warp &fpsum, float *act, const packed_fpsum_t *wgt, int rank, bool alwaysfalse) {
            constexpr int NUM_STAGES = 2;

            const int laneId = threadIdx.x % WARP_SIZE;
            const int warpId = threadIdx.x / WARP_SIZE;

            lora_wgt_warp lora_wgt[NUM_STAGES]; // 64

            // pipeline prologue: issue the loads for the first stage(s)
#pragma unroll
            for (int k = 0; k < NUM_STAGES - 1; k++) {
                // we have rank > 0
                bool pred = k == 0 ? true : k < rank / WARP_R;
                load_lora_wgt(wgt, 0, rank, lora_wgt[k], pred);
            }

            auto compute = [](lora_wgt_warp W, fpsum_warp fpsum) -> lora_act_warp {
                lora_act_warp lora_act;
                lora_act.fill(packed_f32psum_t::zeros());

#pragma unroll
                for (int m = 0; m < LORA_M_TILES; m++) {
#pragma unroll
                    for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
                        for (int r = 0; r < LORA_R_TILES; r++) {
                            auto &psum = lora_act[m * LORA_R_TILES + r];
                            CHECK_NAN(fpsum[m * WARP_N_TILES + n], "apply_lora_down.fpsum");
                            CHECK_NAN(W[n * LORA_R_TILES + r], "apply_lora_down.lora_wgt");
                            psum = mma_f16xf16_f32(fpsum[m * WARP_N_TILES + n], W[n * LORA_R_TILES + r], psum);
                            CHECK_NAN(psum, "apply_lora_down.psum");
                        }
                    }
                }
                return lora_act;
            };

            int dummy = 0;

            for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) {
#pragma unroll
                for (int k2 = 0; k2 < NUM_STAGES; k2++) {
                    if (k1 + k2 >= rank / WARP_R) {
                        break;
                    }

                    int nextk = k1 + k2 + NUM_STAGES - 1;
                    int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
                    bool pred = nextk < rank / WARP_R;

                    if (alwaysfalse) {
                        wgt += kernels::bit_cast<int>(lora_wgt[k2][0].data[0]);
                    }
                    if (alwaysfalse) {
                        dummy = clock();
                    }

                    // prefetch the next stage while computing the current one
                    load_lora_wgt(wgt, nextk, rank, lora_wgt[idx], pred);

                    if (alwaysfalse) {
                        dummy = clock();
                    }

                    lora_act_warp lora_act = compute(lora_wgt[k2], fpsum);
                    reduce_lora_act(act, k1 + k2, lora_act, true);
                }
            }

            // add fake dependency of loaded data so NVCC will not skip the initial load
#pragma unroll
            for (int k = 0; k < NUM_STAGES - 1; k++) {
#pragma unroll
                for (auto &&data : lora_wgt[k]) {
#pragma unroll
                    for (int i = 0; i < 4; i++) {
                        dummy ^= kernels::bit_cast<int>(data.data[i]);
                    }
                }
            }
            unused_var(dummy, alwaysfalse);
        }

        __device__ __forceinline__
        void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
            const int bm = binfo.bm;
            const int bn = binfo.bn;

            apply_lora_down(
                fpsum,
                args.lora_act + bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
                args.lora_wgt_down + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
                args.rank,
                args.alwaysfalse);
        }
    };
};

};  // namespace nunchaku::kernels
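
// Usage note (illustrative sketch, not part of this header's API): the float scratch
// buffer that EpilogueLoraDown reduces into and EpilogueLoraUp reads back follows the
// layout documented at load_lora_act, i.e.
//   [M / BLOCK_M, rank / WARP_R, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float.
// Assuming hypothetical host-side names (ceil_div, num_tokens are not defined here),
// its element count would be derived as:
//
//   size_t lora_act_numel = ceil_div(num_tokens, BLOCK_M)      // one slice per M block
//                         * (rank / Lora<Config>::WARP_R)      // rank tiles of 16
//                         * NUM_WARPS
//                         * Lora<Config>::LORA_M_TILES
//                         * Lora<Config>::LORA_R_TILES
//                         * 8 * WARP_SIZE;
//
// Since reduce_lora_act accumulates via reduce_add_pred, the buffer presumably has to be
// zero-initialized before the lora_down epilogue runs.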