#pragma once

#include <string>
#include <vector>

#include "common.h"
#include "Tensor.h"

namespace nunchaku::kernels {

void gemm_w4a4(
    Tensor act,            // packed act      [M, K / 2]
    Tensor wgt,            // packed wgt      [N, K / 2]
    Tensor out,            // linear          [M, N]
    Tensor qout,           // packed act      [M, N / 2]
    Tensor ascales,        // packed as       [K / 64, M]
    Tensor wscales,        // packed ws       [K / 64, N]
    Tensor oscales,        // packed as       [N / 64, M]
    Tensor poolout,        // linear          [M / PoolSize, N]
    Tensor lora_act_in,    // packed lora_act [M, R]
    Tensor lora_up,        // packed lora_wgt [N, R]
    Tensor lora_down,      // packed lora_wgt [N, R]
    Tensor lora_act_out,   // packed lora_act [M, R]
    Tensor norm_q,         // linear          [HEAD_DIM]
    Tensor norm_k,         // linear          [HEAD_DIM]
    Tensor rotary_emb,     // linear          [M, HEAD_DIM / 2, 2, 2]
    Tensor bias,           // packed ws       [N]
    Tensor smooth_factor,  // packed ws       [N], for quantization of the next layer
    Tensor out_vk,         // linear          [B, num_heads, head_dim + 1, head_dim]
    Tensor out_linearattn, // linear          [B, (M), N / 3]
    bool act_unsigned,
    std::vector<float> lora_scales, // [R / 16]
    bool fuse_silu,
    bool fp4,
    float alpha,
    Tensor wcscales,
    Tensor out_q,          // packed attention [B, H, M, D]
    Tensor out_k,          // packed attention [B, H, M, D]
    Tensor out_v,          // packed attention [B, H, M, D]
    int attn_tokens
);

void linearattn_vk_mul_q(Tensor q, Tensor vk);

void quantize_w4a4_act_fuse_lora(Tensor input,
                                 Tensor output,
                                 Tensor oscales,
                                 Tensor lora_down,
                                 Tensor lora_act_out,
                                 Tensor smooth = {},
                                 bool fuse_glu = false,
                                 bool fp4 = false);
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);

void gemm_w8a8(
    Tensor act,     // [M, K]
    Tensor wgt,     // [N, K]
    Tensor out,     // [M, N]
    Tensor ascales, // [1, M]
    Tensor wscales, // [1, N]
    Tensor bias     // packed ws [N]
);

void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_glu);

// void gemm_w8a8_fuse_litela(
//     Tensor act,     // [B, (M), K]
//     Tensor wgt,     // [N, K]
//     Tensor out_q,   // [B, (M), N / 3]
//     Tensor out_vk,  // [B, num_heads, head_dim + 1, head_dim]
//     Tensor ascales, // [1, M]
//     Tensor wscales  // [1, N]
// );

void attention_fp16(
    Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
    Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
    Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
    Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
    float scale
);

// EXPERIMENTAL, for sm_75
void set_faster_i2f_mode(std::string mode);

// FOR TEST ONLY
void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb);
void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens);

} // namespace nunchaku::kernels
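
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only; not part of the API above). Assumes the
// caller has allocated every tensor with the layouts documented in the
// gemm_w4a4 signature, and that an empty (default-constructed) Tensor marks
// an unused fused input/output, as the `Tensor smooth = {}` default in
// quantize_w4a4_act_fuse_lora suggests.
//
//   using namespace nunchaku::kernels;
//
//   // 1. Quantize fp16 activations [M, K] into packed int4 [M, K / 2]
//   //    with per-group scales [K / 64, M].
//   quantize_w4a4_act(act_fp16, act_q, ascales);
//
//   // 2. Plain W4A4 GEMM into a linear output [M, N]; every optional
//   //    fusion (LoRA, RMSNorm/RoPE, pooling, linear attention, QKV
//   //    packing) is disabled by passing empty tensors.
//   gemm_w4a4(act_q, wgt_q, out, /*qout=*/{}, ascales, wscales,
//             /*oscales=*/{}, /*poolout=*/{},
//             /*lora_act_in=*/{}, /*lora_up=*/{}, /*lora_down=*/{},
//             /*lora_act_out=*/{}, /*norm_q=*/{}, /*norm_k=*/{},
//             /*rotary_emb=*/{}, /*bias=*/{}, /*smooth_factor=*/{},
//             /*out_vk=*/{}, /*out_linearattn=*/{},
//             /*act_unsigned=*/false, /*lora_scales=*/{},
//             /*fuse_silu=*/false, /*fp4=*/false, /*alpha=*/1.0f,
//             /*wcscales=*/{}, /*out_q=*/{}, /*out_k=*/{}, /*out_v=*/{},
//             /*attn_tokens=*/0);
// ---------------------------------------------------------------------------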
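
// ---------------------------------------------------------------------------
// The W8A8 path follows the same quantize-then-GEMM pattern. A minimal
// sketch, assuming per-token activation scales [1, M] and per-channel
// weight scales [1, N] as documented above, with GLU fusion disabled:
//
//   quantize_w8a8_act(act_fp16, act_q8, ascales8, /*fuse_glu=*/false);
//   gemm_w8a8(act_q8, wgt_q8, out, ascales8, wscales8, /*bias=*/{});
//
// For attention_fp16, `scale` is the softmax scaling factor applied to
// Q * K^T; the conventional choice is 1 / sqrt(HEAD_DIM):
//
//   attention_fp16(q, k, v, o, 1.0f / std::sqrt(float(HEAD_DIM)));
// ---------------------------------------------------------------------------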