jerrrrry / infinilm · Commits · a256e8d9

Unverified commit a256e8d9, authored Mar 11, 2026 by suss; committed by GitHub on Mar 11, 2026.

add mha_kvcache (#261)

* add mha_kvcache
* repair gqa-api bug
Parent: 6ab9ee22

1 changed file, 29 additions and 9 deletions:

csrc/models/llama/llama_attention.cpp (+29, -9)
```diff
@@ -4,6 +4,7 @@
 #include "infinicore/nn/linear.hpp"
 #include "infinicore/nn/rope.hpp"
 #include "infinicore/ops.hpp"
+#include "infinicore/ops/mha_kvcache.hpp"
 #include "infinicore/ops/mha_varlen.hpp"
 #include "infinicore/ops/mul.hpp"
@@ -331,16 +332,35 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
             scaling_
         );
     } else {
-        infinicore::op::paged_attention_(
-            attn_output,
-            q_reshaped,
-            k_total,
-            v_total,
-            block_tables.value(),
-            total_sequence_lengths.value(),
-            std::nullopt,
-            scaling_);
+        if (attention_backend_ == backends::AttentionBackend::FlashAttn) {
+            // FA2 decode path: flash::mha_fwd_kvcache
+            // In paged-attn mode, seq_len = actual batch_size (one query token per sequence).
+            // q_reshaped: [seq_len, num_heads, head_dim] → [seq_len, 1, num_heads, head_dim]
+            // k/v cache: [num_blocks, num_kv_heads, block_size, head_dim]
+            //   → permute {0,2,1,3} → [num_blocks, block_size, num_kv_heads, head_dim]
+            auto q_for_fa = q_reshaped->view({seq_len, 1, num_attention_heads_, head_dim_});
+            auto attn_out_4d = infinicore::op::mha_kvcache(
+                q_for_fa,
+                k_total->permute({0, 2, 1, 3}), // [num_blocks, block_size, num_kv_heads, head_dim]
+                v_total->permute({0, 2, 1, 3}),
+                total_sequence_lengths.value(), // [seq_len] int32 (one entry per sequence)
+                block_tables.value(),           // [seq_len, max_num_blocks_per_seq] int32
+                std::nullopt,
+                scaling_
+            );
+            attn_output = attn_out_4d->view({seq_len, num_attention_heads_, head_dim_});
+        } else {
+            infinicore::op::paged_attention_(
+                attn_output,
+                q_reshaped,
+                k_total,
+                v_total,
+                block_tables.value(),
+                total_sequence_lengths.value(),
+                std::nullopt,
+                scaling_
+            );
+        }
     }
     // 7. Project output
     attn_output
```