#include "gemm_w4a4_launch.cuh" namespace nunchaku::kernels { #ifndef __INTELLISENSE__ template void GEMM_W4A4_Launch::gemm_w4a4( #else template<> void GEMM_W4A4_Launch::gemm_w4a4( #endif Tensor act, // packed act [M, K / 2] Tensor wgt, // packed act [N, K / 2] Tensor out, // linear [M, N] Tensor qout, // packed act [M, N / 2] Tensor ascales, // packed as [K / 64, M] Tensor wscales, // packed ws [K / 64, N] Tensor oscales, // packed as [N / 64, M] Tensor poolout, // linear [M / PoolSize, N] Tensor lora_act_in, // packed lora_act [M, R] Tensor lora_up, // packed lora_wgt [N, R] Tensor lora_down, // packed lora_wgt [N, R] Tensor lora_act_out, // packed lora_act [M, R] Tensor norm_q, // linear [HEAD_DIM] Tensor norm_k, // linear [HEAD_DIM] Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2] Tensor bias, // packed ws [N] Tensor smooth_factor, // packed ws [N], for quantization of the next layer Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim] Tensor out_linearattn,// linear [B, (M), N / 3] bool act_unsigned, std::vector lora_scales, // [R / 16] bool fuse_silu, bool fp4, float alpha, Tensor wcscales // packed ws [N] ) { int M = act.numel() / act.shape[-1]; int N = wgt.shape[0]; int K = act.shape[-1] * 2; assert(K == wgt.shape[1] * 2); int actualM = 0; int actualN = 0; if (out.valid()) { actualM = out.numel() / out.shape[-1]; actualN = out.shape[-1]; assert(actualM <= M && M - actualM < GEMM::BLOCK_M); assert(actualN <= N && N - actualN < GEMM::BLOCK_N); } spdlog::trace("gemm_w4a4: M={} N={} K={}", M, N, K); spdlog::trace("act at {}", act.data_ptr()); spdlog::trace("wgt at {}", wgt.data_ptr()); spdlog::trace("ascales at {}", ascales.data_ptr()); spdlog::trace("wscales at {}", wscales.data_ptr()); if (bias.valid()) { spdlog::trace("bias at {}", bias.data_ptr()); } int shmem = 0; auto launch = [&](Epilogue::Arguments args) { assert(M % GEMM::BLOCK_M == 0); assert(N % GEMM::BLOCK_N == 0); dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N); bool swapBlockMN = M > N * 2; if (swapBlockMN) { std::swap(grid.x, grid.y); } dispatchBool(fp4, [&]() { // test_sizeof(); // std::apply([](auto ...args) { // (test_sizeof(), ...); // }, args); // constexpr bool FP4_AVAILABLE = __CUDA_ARCH__ >= 1200; if constexpr (!USE_FP4) { dispatchBool(act_unsigned, [&]() { auto func = invoke_kernel, const packed_act_t *, const packed_wgt_t *, const packed_ascale_t *, const packed_wscale_t *, int, int, int, typename Epilogue::Arguments, bool, bool>; if (shmem >= 24 * 1024) { checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); } assert(alpha == 1.0f); func<<>>( act.data_ptr(), wgt.data_ptr(), ascales.data_ptr(), wscales.data_ptr(), M, N, K, args, swapBlockMN, false ); checkCUDA(cudaGetLastError()); }); return; } if constexpr (USE_FP4) { dispatchBool(alpha != 1.0f, [&]() { assert(!act_unsigned); auto func = invoke_kernel, const packed_act_t *, const packed_wgt_t *, const packed_amscale_t *, const packed_wmscale_t *, float, int, int, int, typename Epilogue::Arguments, bool, bool>; if (shmem >= 24 * 1024) { checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); } assert(ascales.dtype() == Tensor::FP8_E4M3); assert(wscales.dtype() == Tensor::FP8_E4M3); func<<>>( act.data_ptr(), wgt.data_ptr(), ascales.data_ptr(), wscales.data_ptr(), alpha, M, N, K, args, swapBlockMN, false ); checkCUDA(cudaGetLastError()); }); return; } // if constexpr (USE_FP4 && !FP4_AVAILABLE) { // throw std::runtime_error("FP4 kernel is not available"); // } }); }; auto 
    auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
        assert(!bias.valid() || bias.numel() == N);
        assert(!wcscales.valid() || wcscales.numel() == N);

        dispatchBool(bias.valid(), [&]<bool USE_BIAS>() {
            dispatchBool(wcscales.valid(), [&]<bool USE_SCALE>() {
                using EpilogueBias = typename GEMM::EpilogueBias<USE_BIAS, USE_SCALE>;
                // append EpilogueNop to work around mismatched memory layout of std::tuple between device and host code on Windows
                // ** sizeof(std::tuple<std::tuple<>>) == 8 on device **
                using Epilogue = typename GEMM::EpilogueCombination<EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
                return launch.template operator()<Epilogue>({
                    typename EpilogueBias::Arguments{
                        .bias  = USE_BIAS ? bias.data_ptr<packed_wscale_t>() : nullptr,
                        .scale = USE_SCALE ? wcscales.data_ptr<packed_wscale_t>() : nullptr,
                    },
                    nextArgs,
                    {}
                });
            });
        });
    };
    // auto launch_bias = launch;

    auto launch_lora = [&]<typename NextEpilogue, typename MidEpilogue>(NextEpilogue::Arguments nextArgs, MidEpilogue::Arguments midArgs) {
        assert(lora_up.valid() == lora_act_in.valid());
        assert(lora_down.valid() == lora_act_out.valid());

        if (!lora_up.valid()) {
            assert(!lora_down.valid());
            return launch_bias.template operator()<typename GEMM::EpilogueCombination<MidEpilogue, NextEpilogue>>({midArgs, nextArgs});
        }

        const int rank_up = lora_up.shape[1];

        assert(lora_up.shape[0] == N);
        // assert(lora_up.shape[1] == Lora::LORA_RANK);
        assert(lora_act_in.shape[0] == M);
        assert(lora_act_in.shape[1] == rank_up);

        dispatchVal(rank_up, LoraRanks(), [&]<int RANK>() {
            using LoraUp  = typename GEMM::Lora<RANK>;
            using scale_t = typename LoraUp::scale_t;

            scale_t scales;
            if constexpr (scales.size() > 0) {
                assert(lora_scales.size() >= scales.size());
                for (size_t i = 0; i < scales.size(); i++) {
                    scales[i] = lora_scales[i];
                }
            }

            if (!lora_down.valid()) {
                using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue, typename GEMM::EpilogueNop>;
                return launch_bias.template operator()<Epilogue>({
                    typename LoraUp::EpilogueLoraUp::Arguments{
                        .lora_act    = lora_act_in.data_ptr<float>(),
                        .lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
                        .scales      = scales,
                    },
                    midArgs,
                    nextArgs,
                    {}
                });
            }

            const int rank_down = lora_down.shape[1];
            assert(rank_down == rank_up);

            assert(lora_down.shape[0] == N);
            // assert(lora_down.shape[1] == Lora::LORA_RANK);
            assert(lora_act_out.shape[0] == M);
            assert(lora_act_out.shape[1] == rank_down);

            lora_act_out.zero_();

            // dispatchVal(rank_down, std::integer_sequence(), [&]() {
                using LoraDown = LoraUp;    // GEMM::Lora;
                using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue, typename GEMM::EpilogueNop>;
                return launch_bias.template operator()<Epilogue>({
                    typename LoraUp::EpilogueLoraUp::Arguments{
                        .lora_act    = lora_act_in.data_ptr<float>(),
                        .lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
                        .scales      = scales,
                    },
                    midArgs,
                    typename LoraDown::EpilogueLoraDown::Arguments{
                        .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                        .lora_act      = lora_act_out.data_ptr<float>(),
                    },
                    nextArgs,
                    {}
                });
            // });
        });
    };
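
    // Pick the terminal epilogue for this GEMM: fused quantization of the next
    // layer's input (optionally alongside the unquantized output), LiteLA
    // linear attention, fused QKV projection with RMSNorm + RoPE, or a plain
    // linear output.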
    if (qout.valid() && oscales.valid()) {
        // dispatchBool(qout_unsigned, [&]() {
        static constexpr float SHIFT_GELU = 0.171875f;

        dispatchBool(fp4, [&]<bool USE_FP4>() {
            constexpr bool USE_UNSIGNED = !USE_FP4;
            using EpilogueQuantize = typename GEMM::EpilogueQuantize<USE_UNSIGNED, USE_FP4>;
            auto argsQuantize = typename EpilogueQuantize::Arguments{
                .qout          = qout.data_ptr<packed_act_t>(),
                .oscales       = oscales.data_ptr<packed_ascale_t>(),
                .shift_value   = USE_FP4 ? 0.0f : SHIFT_GELU,
                .smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
            };

            // TODO: check if gelu is needed
            if (out.valid()) {
                launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
                    typename GEMM::EpilogueDefault::Arguments{
                        .out     = out.data_ptr<half_t>(),
                        .actualM = actualM,
                        .actualN = actualN,
                    },
                    argsQuantize
                }, {});
            } else {
                launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
            }
        });

    } else if (out_linearattn.valid()) {
        assert(out_vk.valid());

        using Epilogue = typename GEMM::EpilogueLiteLA;

        assert(out_vk.dtype() == Tensor::FP32);
        assert(out_vk.ndims() == 4);
        assert(out_vk.shape[2] == Epilogue::LITELA_HEAD_DIM + 1);
        assert(out_vk.shape[3] == Epilogue::LITELA_HEAD_DIM);
        assert(out_vk.shape[1] * Epilogue::LITELA_HEAD_DIM * 3 == N);
        int batch_size = out_vk.shape[0];
        int num_heads  = out_vk.shape[1];

        assert(isTypeMatch<half_t>(out_linearattn.dtype()));
        assert(out_linearattn.ndims() == 3);
        assert(out_linearattn.shape[0] == batch_size);
        assert(out_linearattn.shape[2] * 3 == N);
        int num_tokens = out_linearattn.shape[1];
        assert(num_tokens % GEMM::BLOCK_M == 0);

        int num_blocks_per_batch = ceilDiv(num_tokens, GEMM::BLOCK_M);

        shmem = std::max(shmem, Epilogue::SHMEM_SIZE);

        out_vk.zero_();
        launch_lora.template operator()<Epilogue, typename GEMM::EpilogueNop>(typename Epilogue::Arguments{
            .out_q                = out_linearattn.data_ptr<half_t>(),
            .out_vk               = out_vk.data_ptr<float>(),
            .num_blocks_per_batch = num_blocks_per_batch,
            .actualM              = M,
        }, {});

    } else if (rotary_emb.valid()) {
        assert(norm_q.valid());
        assert(norm_k.valid());
        // assert(isTypeMatch(rotary_emb.scalar_type()));
        assert(rotary_emb.scalar_type() == Tensor::FP32);
        assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);

        launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
            .out              = out.data_ptr<half_t>(),
            .actualM          = actualM,
            .actualN          = actualN,
            .pool_out         = poolout.valid() ? poolout.data_ptr<half_t>() : nullptr,
            .rotary_emb       = rotary_emb.data_ptr<float>(),
            .rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
            .rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
            .epsilon          = 1e-6,
        }, {});

    } else if (out.valid()) {
        using Epilogue = typename GEMM::EpilogueDefault;
        typename Epilogue::Arguments args{
            .out     = out.data_ptr<half_t>(),
            .actualM = actualM,
            .actualN = actualN,
        };
        if (fuse_silu) {
            launch_lora.template operator()<Epilogue, typename GEMM::EpilogueSilu>(args, {});
        } else {
            launch_lora.template operator()<Epilogue, typename GEMM::EpilogueNop>(args, {});
        }
    } else {
        assert(false);
    }
}

template<typename Config>
void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
    using Epilogue = typename GEMM::EpilogueLiteLA;

    int batch_size = vk.shape[0];
    int num_heads  = vk.shape[1];
    int num_tokens = q.shape[1];

    assert(isTypeMatch<half_t>(q.scalar_type()));
    assert(vk.scalar_type() == Tensor::FP32);

    int BLOCK_SIZE;
    if (num_tokens % 256 == 0) {
        BLOCK_SIZE = 256;
    } else {
        BLOCK_SIZE = 128;
    }

    invoke_kernel<typename Epilogue::vk_mul_q_kernel><<<dim3(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size), BLOCK_SIZE>>>(
        q.data_ptr<half_t>(),
        vk.data_ptr<float>(),
        1e-6f,
        num_tokens
    );
    checkCUDA(cudaGetLastError());
}
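
// Standalone activation quantization fused with the LoRA down projection:
// pads the input to multiples of BLOCK_M x BLOCK_N, optionally applies GLU,
// writes packed 4-bit activations plus per-group scales, and accumulates the
// LoRA down-projection activations in the same pass.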
template<typename Config>
void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
    const int actualM = input.numel() / input.shape[-1];
    const int actualN = input.shape[-1];

    const int M = ceilDiv(actualM, GEMM::BLOCK_M) * GEMM::BLOCK_M;
    const int N = ceilDiv(actualN / (fuse_glu ? 2 : 1), GEMM::BLOCK_N) * GEMM::BLOCK_N;

    assert(output.dtype() == Tensor::INT8);
    assert(output.numel() / output.shape[-1] == M);
    assert(output.shape[-1] == N / 2);

    // assert(oscales.dtype() == Tensor::FP16);
    if (fp4) {
        assert(oscales.dtype() == Tensor::FP8_E4M3);
        assert(oscales.numel() == M * N / GEMM::WARP_K * 4);
    } else {
        assert(isTypeMatch<half_t>(oscales.dtype()));
        assert(oscales.numel() == M * N / GEMM::WARP_K);
    }

    const int rank = lora_down.shape[1];

    assert(lora_down.shape[0] == N);
    // assert(lora_down.shape[1] == Lora::LORA_RANK);
    assert(lora_act_out.shape[0] == M);
    assert(lora_act_out.shape[1] == rank);

    lora_act_out.zero_();

    dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

    dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
        dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
            dispatchBool(fp4, [&]<bool USE_FP4>() {
                using Lora   = typename GEMM::Lora<RANK>;
                using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;

                auto func = invoke_kernel<kernel, typename kernel::Arguments>;

                checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));

                // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));

                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE>>>(
                    typename kernel::Arguments{
                        .input         = input.data_ptr<half_t>(),
                        .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
                        .output        = output.data_ptr<packed_act_t>(),
                        .oscales       = oscales.data_ptr<packed_ascale_t>(),
                        .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                        .lora_act      = lora_act_out.data_ptr<float>(),
                        .M = M,
                        .N = N,
                        .actualM = actualM,
                        .actualN = actualN,
                    }
                );
                checkCUDA(cudaGetLastError());
            });
        });
    });
}

template<typename Config>
void GEMM_W4A4_Launch<Config>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
    int M = input.numel() / input.shape[-1];
    int K = input.shape[-1];

    assert(output.dtype() == Tensor::INT8);
    assert(output.numel() / output.shape[-1] == M);
    assert(output.shape[-1] == K / 2);

    // assert(oscales.dtype() == Tensor::FP16);
    assert(isTypeMatch<half_t>(oscales.dtype()));
    assert(oscales.numel() == M * K / GEMM::WARP_K);

    dim3 grid(M / GEMM::WARP_M, K / GEMM::WARP_K);
    invoke_kernel<typename GEMM::quantize_w4a4_act_kernel><<<grid, GEMM::WARP_SIZE>>>(
        input.data_ptr<half_t>(),
        output.data_ptr<packed_act_t>(),
        oscales.data_ptr<packed_ascale_t>(),
        K
    );
    checkCUDA(cudaGetLastError());
}

template<typename Config>
void GEMM_W4A4_Launch<Config>::quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
    int N = input.numel() / input.shape[-1];
    int K = input.shape[-1];

    assert(output.dtype() == Tensor::INT8);
    assert(output.ndims() == 2);
    assert(output.shape[0] == N);
    assert(output.shape[1] == K / 2);

    assert(isTypeMatch<half_t>(oscales.dtype()));
    // assert(oscales.dtype() == Tensor::FP16);
    assert(oscales.numel() == N * K / GEMM::WARP_K);

    dim3 grid(N / GEMM::WARP_N, K / GEMM::WARP_K);
    invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel><<<grid, GEMM::WARP_SIZE>>>(
        input.data_ptr<half_t>(),
        output.data_ptr<packed_wgt_t>(),
        oscales.data_ptr<packed_wscale_t>(),
        K
    );
    checkCUDA(cudaGetLastError());
}

}; // namespace nunchaku::kernels