#include "zgemm.h" #include "gemm_w4a4_launch.cuh" namespace nunchaku::kernels { // for sm_75 only struct FasterI2FMode { enum Mode { Disabled = 0, Enabled, Always, }; inline static Mode mode = Disabled; static bool check(bool act_unsigned); }; template static void invoke_launch(Tensor::ScalarType dtype, bool use_fp4, bool fasterI2F, F &&launch) { if (fasterI2F && dtype == Tensor::FP16) { launch.template operator()(); } else { dispatchBool(use_fp4, [&]() { if (dtype == Tensor::FP16) { launch.template operator()(); } else if (dtype == Tensor::BF16) { launch.template operator()(); } else { assert(false); } }); } } void gemm_w4a4( Tensor act, // packed act [M, K / 2] Tensor wgt, // packed act [N, K / 2] Tensor out, // linear [M, N] Tensor qout, // packed act [M, N / 2] Tensor ascales, // packed as [K / 64, M] Tensor wscales, // packed ws [K / 64, N] Tensor oscales, // packed as [N / 64, M] Tensor poolout, // linear [M / PoolSize, N] Tensor lora_act_in, // packed lora_act [M, R] Tensor lora_up, // packed lora_wgt [N, R] Tensor lora_down, // packed lora_wgt [N, R] Tensor lora_act_out, // packed lora_act [M, R] Tensor norm_q, // linear [HEAD_DIM] Tensor norm_k, // linear [HEAD_DIM] Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2] Tensor bias, // packed ws [N] Tensor smooth_factor, // packed ws [N], for quantization of the next layer Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim] Tensor out_linearattn,// linear [B, (M), N / 3] bool act_unsigned, std::vector lora_scales, // [R / 16] bool fuse_silu, bool fp4, float alpha, Tensor wcscales, Tensor out_q, // packed attention [B, H, M, D] Tensor out_k, // packed attention [B, H, M, D] Tensor out_v, // packed attention [B, H, M, D] int attn_tokens ) { Tensor::ScalarType dtype = Tensor::INVALID_SCALAR_TYPE; if (!fp4) { dtype = ascales.dtype(); } else { for (auto tensor : {out, bias, lora_up, lora_down, poolout, wcscales}) { if (tensor.valid()) { assert(dtype == Tensor::INVALID_SCALAR_TYPE || dtype == tensor.dtype()); dtype = tensor.dtype(); } } } invoke_launch(dtype, fp4, FasterI2FMode::check(act_unsigned), [&]() { GEMM_W4A4_Launch::gemm_w4a4( act, wgt, out, qout, ascales, wscales, oscales, poolout, lora_act_in, lora_up, lora_down, lora_act_out, norm_q, norm_k, rotary_emb, bias, smooth_factor, out_vk, out_linearattn, act_unsigned, lora_scales, fuse_silu, fp4, alpha, wcscales, out_q, out_k, out_v, attn_tokens ); }); } void linearattn_vk_mul_q(Tensor q, Tensor vk) { invoke_launch(q.dtype(), false, false, [&]() { GEMM_W4A4_Launch::linearattn_vk_mul_q(q, vk); }); } void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) { invoke_launch(input.dtype(), fp4, false, [&]() { GEMM_W4A4_Launch::quantize_w4a4_act_fuse_lora( input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4 ); }); } void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) { invoke_launch(input.dtype(), false, false, [&]() { GEMM_W4A4_Launch::quantize_w4a4_act( input, output, oscales ); }); } void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) { invoke_launch(input.dtype(), false, false, [&]() { GEMM_W4A4_Launch::quantize_w4a4_wgt( input, output, oscales ); }); } bool FasterI2FMode::check(bool act_unsigned) { auto *prop = getCurrentDeviceProperties(); if (prop->major != 7 || prop->minor != 5) { return false; } if (mode == Always) { return true; } else if (mode == Enabled && !act_unsigned) { return true; } else { return false; } } void 
void set_faster_i2f_mode(std::string mode) {
    static const std::map<std::string, FasterI2FMode::Mode> mapping = {
        {"disabled", FasterI2FMode::Disabled},
        {"enabled", FasterI2FMode::Enabled},
        {"always", FasterI2FMode::Always},
    };
    FasterI2FMode::mode = mapping.at(mode);
}

} // namespace nunchaku::kernels
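// Usage sketch for the sm_75 fast int-to-float path (call sites below are
// illustrative, not from this file). set_faster_i2f_mode() accepts exactly
// "disabled", "enabled", or "always"; any other string makes std::map::at()
// throw std::out_of_range. FasterI2FMode::check() then gates the FP16
// FasterI2F kernel per call:
//
//   set_faster_i2f_mode("disabled"); // never use the fast path (the default)
//   set_faster_i2f_mode("enabled");  // use it, except when act_unsigned is set
//   set_faster_i2f_mode("always");   // use it even for unsigned activations
//
// On devices other than sm_75, check() always returns false, so the setting
// has no effect there.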