Unverified Commit a256e8d9 authored by suss, committed by GitHub

add mha_kvcache (#261)

* add mha_kvcache

* fix GQA API bug
parent 6ab9ee22
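
This commit routes the paged-attention decode path through FlashAttention-2's kvcache kernel (flash::mha_fwd_kvcache) when the FlashAttn backend is selected, falling back to the existing paged_attention_ op otherwise. As orientation, a minimal sketch of the new call follows; shapes are taken from the comments in the diff below, the argument order is inferred from this diff, and the variable names (q, k_cache, v_cache, seqlens, tables) are illustrative, not the library's documented API:

    // Sketch of the FA2 decode-path call added by this commit (shapes per the
    // diff comments; names other than infinicore::op::mha_kvcache are assumed).
    //   q:        [batch, 1, num_heads, head_dim]        one query token per sequence
    //   k_cache:  [num_blocks, block_size, num_kv_heads, head_dim]
    //   v_cache:  [num_blocks, block_size, num_kv_heads, head_dim]
    //   seqlens:  [batch] int32, current length of each sequence
    //   tables:   [batch, max_num_blocks_per_seq] int32
    auto out = infinicore::op::mha_kvcache(
        q, k_cache, v_cache,
        seqlens, tables,
        std::nullopt,   // optional argument, left unset in this diff
        scaling);       // softmax scale, typically 1/sqrt(head_dim)
    // out: [batch, 1, num_heads, head_dim]; num_kv_heads may be smaller than
    // num_heads (GQA), which the kvcache kernel is expected to broadcast over.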
@@ -4,6 +4,7 @@
#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/ops/mha_kvcache.hpp"
#include "infinicore/ops/mha_varlen.hpp"
#include "infinicore/ops/mul.hpp"
@@ -331,16 +332,35 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
scaling_);
}
} else {
-        infinicore::op::paged_attention_(
-            attn_output,
-            q_reshaped,
-            k_total,
-            v_total,
-            block_tables.value(),
-            total_sequence_lengths.value(),
-            std::nullopt,
-            scaling_);
+        if (attention_backend_ == backends::AttentionBackend::FlashAttn) {
+            // FA2 decode path: flash::mha_fwd_kvcache
+            // In paged-attn mode, seq_len = actual batch_size (one query token per sequence).
+            // q_reshaped: [seq_len, num_heads, head_dim] → [seq_len, 1, num_heads, head_dim]
+            // k/v cache: [num_blocks, num_kv_heads, block_size, head_dim]
+            //   → permute {0,2,1,3} → [num_blocks, block_size, num_kv_heads, head_dim]
+            auto q_for_fa = q_reshaped->view({seq_len, 1, num_attention_heads_, head_dim_});
+            auto attn_out_4d = infinicore::op::mha_kvcache(
+                q_for_fa,
+                k_total->permute({0, 2, 1, 3}), // [num_blocks, block_size, num_kv_heads, head_dim]
+                v_total->permute({0, 2, 1, 3}),
+                total_sequence_lengths.value(), // [seq_len] int32 (one entry per sequence)
+                block_tables.value(),           // [seq_len, max_num_blocks_per_seq] int32
+                std::nullopt,
+                scaling_);
+            attn_output = attn_out_4d->view({seq_len, num_attention_heads_, head_dim_});
+        } else {
+            infinicore::op::paged_attention_(
+                attn_output,
+                q_reshaped,
+                k_total,
+                v_total,
+                block_tables.value(),
+                total_sequence_lengths.value(),
+                std::nullopt,
+                scaling_);
+        }
}
// 7. Project output
attn_output
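
For reference, the shape bookkeeping around the new FA2 call, reconstructed from the comments above (dimension names are the diff's own):

    q_reshaped               [seq_len, num_heads, head_dim]
      -> view                [seq_len, 1, num_heads, head_dim]      // seqlen_q == 1 in decode
    k_total / v_total        [num_blocks, num_kv_heads, block_size, head_dim]
      -> permute {0,2,1,3}   [num_blocks, block_size, num_kv_heads, head_dim]
    attn_out_4d              [seq_len, 1, num_heads, head_dim]
      -> view                [seq_len, num_attention_heads_, head_dim_]

In most tensor libraries permute returns a strided view rather than a copy, so the cache layout in memory is unchanged; only the view handed to the kernel differs.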