jerrrrry / infinilm · Commits

Commit a256e8d9 (unverified) · authored Mar 11, 2026 by suss, committed by GitHub on Mar 11, 2026
parent 6ab9ee22

add mha_kvcache (#261)

* add mha_kvcache
* repair gqa-api bug
Showing 1 changed file with 29 additions and 9 deletions.

csrc/models/llama/llama_attention.cpp (+29 / -9)
@@ -4,6 +4,7 @@
 #include "infinicore/nn/linear.hpp"
 #include "infinicore/nn/rope.hpp"
 #include "infinicore/ops.hpp"
+#include "infinicore/ops/mha_kvcache.hpp"
 #include "infinicore/ops/mha_varlen.hpp"
 #include "infinicore/ops/mul.hpp"
@@ -330,6 +331,23 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
             std::nullopt,
             std::nullopt,
             scaling_);
         }
     } else {
+        if (attention_backend_ == backends::AttentionBackend::FlashAttn) {
+            // FA2 decode path: flash::mha_fwd_kvcache
+            // In paged-attn mode, seq_len = actual batch_size (one query token per sequence).
+            // q_reshaped: [seq_len, num_heads, head_dim] → [seq_len, 1, num_heads, head_dim]
+            // k/v cache: [num_blocks, num_kv_heads, block_size, head_dim]
+            //            → permute {0,2,1,3} → [num_blocks, block_size, num_kv_heads, head_dim]
+            auto q_for_fa = q_reshaped->view({seq_len, 1, num_attention_heads_, head_dim_});
+            auto attn_out_4d = infinicore::op::mha_kvcache(
+                q_for_fa,
+                k_total->permute({0, 2, 1, 3}), // [num_blocks, block_size, num_kv_heads, head_dim]
+                v_total->permute({0, 2, 1, 3}),
+                total_sequence_lengths.value(), // [seq_len] int32 (one entry per sequence)
+                block_tables.value(),           // [seq_len, max_num_blocks_per_seq] int32
+                std::nullopt,
+                scaling_);
+            attn_output = attn_out_4d->view({seq_len, num_attention_heads_, head_dim_});
+        } else {
             infinicore::op::paged_attention_(
                 attn_output,
@@ -341,6 +359,8 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
                 std::nullopt,
                 std::nullopt,
                 scaling_);
+        }
     }

     // 7. Project output
     attn_output
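For readers unfamiliar with the new op, the call site above pins down its shape contract. The declaration below is a sketch reverse-engineered from this diff alone; the real header is infinicore/ops/mha_kvcache.hpp and may differ in parameter names, the meaning of the optional argument, and return-type details.

    // Hypothetical declaration inferred from the call site; Tensor is
    // InfiniCore's tensor handle type, not defined here.
    #include <optional>

    namespace infinicore::op {

    Tensor mha_kvcache(
        const Tensor &q,            // [batch, 1, num_heads, head_dim]; decode issues one query token per sequence
        const Tensor &k_cache,      // [num_blocks, block_size, num_kv_heads, head_dim] (after the permute)
        const Tensor &v_cache,      // same layout as k_cache
        const Tensor &seq_lens,     // [batch] int32: total KV length of each sequence
        const Tensor &block_tables, // [batch, max_num_blocks_per_seq] int32
        std::optional<Tensor> opt,  // passed as std::nullopt here; purpose not visible in this diff
        float softmax_scale);       // scaling_, typically 1/sqrt(head_dim)

    } // namespace infinicore::op

Under these assumptions the result is 4-D ([batch, 1, num_heads, head_dim]), which is why the caller views it back down to [seq_len, num_attention_heads_, head_dim_] before the output projection.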
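As a plain-C++ reference for what this decode path computes per query head, the sketch below walks one decode step over a paged KV cache: the block table maps each logical position t to a physical block, and with GQA several query heads share one KV head (query head h reads KV head h / (num_heads / num_kv_heads)). Everything here is illustrative; none of these names are InfiniCore APIs, and the real cache also carries a num_kv_heads dimension that is dropped here.

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Reference decode attention for ONE query head over a paged KV cache.
    // Cache layout is simplified to [num_blocks][block_size][head_dim],
    // flattened, for a single KV head.
    std::vector<float> paged_decode_one_head(
        const std::vector<float> &q,          // [head_dim], the single new query token
        const std::vector<float> &k_cache,    // flattened [num_blocks][block_size][head_dim]
        const std::vector<float> &v_cache,    // same layout as k_cache
        const std::vector<int> &block_table,  // logical block index -> physical block id
        int seq_len, int block_size, int head_dim, float scale) {
        std::vector<float> s(seq_len);
        float m = -1e30f;
        for (int t = 0; t < seq_len; ++t) {
            // Block-table lookup: position t lives in physical block
            // block_table[t / block_size] at in-block offset t % block_size.
            int off = (block_table[t / block_size] * block_size + t % block_size) * head_dim;
            float dot = 0.f;
            for (int d = 0; d < head_dim; ++d) dot += q[d] * k_cache[off + d];
            s[t] = dot * scale;
            m = std::max(m, s[t]);
        }
        // Numerically stable softmax over all cached positions.
        float denom = 0.f;
        for (int t = 0; t < seq_len; ++t) { s[t] = std::exp(s[t] - m); denom += s[t]; }
        // Weighted sum of the cached V vectors.
        std::vector<float> out(head_dim, 0.f);
        for (int t = 0; t < seq_len; ++t) {
            int off = (block_table[t / block_size] * block_size + t % block_size) * head_dim;
            for (int d = 0; d < head_dim; ++d) out[d] += (s[t] / denom) * v_cache[off + d];
        }
        return out;
    }

For example, with block_size = 16 and seq_len = 40, position t = 37 resolves to physical block block_table[2] at offset 5. Note that no causal mask is needed at decode time: the query is the last token of its sequence, so all seq_len cached positions are legitimately visible to it.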