#include "zgemm.h" #include "gemm_w4a4_launch.cuh" namespace nunchaku::kernels { template static void invoke_launch(Tensor::ScalarType dtype, F &&launch) { if (dtype == Tensor::FP16) { launch.template operator()(); } else if (dtype == Tensor::BF16) { launch.template operator()(); } else { assert(false); } } void gemm_w4a4( Tensor act, // packed act [M, K / 2] Tensor wgt, // packed act [N, K / 2] Tensor out, // linear [M, N] Tensor qout, // packed act [M, N / 2] Tensor ascales, // packed as [K / 64, M] Tensor wscales, // packed ws [K / 64, N] Tensor oscales, // packed as [N / 64, M] Tensor poolout, // linear [M / PoolSize, N] Tensor lora_act_in, // packed lora_act [M, R] Tensor lora_up, // packed lora_wgt [N, R] Tensor lora_down, // packed lora_wgt [N, R] Tensor lora_act_out, // packed lora_act [M, R] Tensor norm_q, // linear [HEAD_DIM] Tensor norm_k, // linear [HEAD_DIM] Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2] Tensor bias, // packed ws [N] Tensor smooth_factor, // packed ws [N], for quantization of the next layer Tensor out_vk, // linear [B, num_heads, head_dim + 1, head_dim] Tensor out_linearattn,// linear [B, (M), N / 3] bool act_unsigned, std::vector lora_scales, // [R / 16] bool fuse_silu ) { invoke_launch(ascales.dtype(), [&]() { GEMM_W4A4_Launch::gemm_w4a4( act, wgt, out, qout, ascales, wscales, oscales, poolout, lora_act_in, lora_up, lora_down, lora_act_out, norm_q, norm_k, rotary_emb, bias, smooth_factor, out_vk, out_linearattn, act_unsigned, lora_scales, fuse_silu ); }); } void linearattn_vk_mul_q(Tensor q, Tensor vk) { invoke_launch(q.dtype(), [&]() { GEMM_W4A4_Launch::linearattn_vk_mul_q(q, vk); }); } void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu) { invoke_launch(input.dtype(), [&]() { GEMM_W4A4_Launch::quantize_w4a4_act_fuse_lora( input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu ); }); } void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) { invoke_launch(input.dtype(), [&]() { GEMM_W4A4_Launch::quantize_w4a4_act( input, output, oscales ); }); } void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) { invoke_launch(input.dtype(), [&]() { GEMM_W4A4_Launch::quantize_w4a4_wgt( input, output, oscales ); }); } };