Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
0f90515c
Commit
0f90515c
authored
Mar 09, 2026
by
wooway777
Browse files
issue/1065 - fix mha kv cache interface
parent
456ee3e1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
10 deletions
+19
-10
src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
+19
-10
No files found.
src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
View file @
0f90515c
...
...
@@ -36,7 +36,7 @@ void run(void *planned_meta) {
c10
::
cuda
::
CUDAStreamGuard
guard
(
infinicore
::
adaptor
::
get_cuda_stream
());
auto
*
p
=
reinterpret_cast
<
PlannedMeta
*>
(
planned_meta
);
auto
out
=
std
::
optional
<
at
::
T
ensor
>
(
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
out
)
)
;
auto
out
_t
ensor
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
out
);
auto
q
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
q
);
auto
k_cache
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
k_cache
);
auto
v_cache
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
v_cache
);
...
...
@@ -46,7 +46,6 @@ void run(void *planned_meta) {
?
std
::
optional
<
at
::
Tensor
>
(
infinicore
::
adaptor
::
to_aten_tensor
(
*
p
->
alibi_slopes
))
:
std
::
nullopt
;
// No new KV tokens to append (pure decode, KV already written to cache).
std
::
optional
<
const
at
::
Tensor
>
k_new
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
v_new
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
rotary_cos
=
std
::
nullopt
;
...
...
@@ -54,7 +53,14 @@ void run(void *planned_meta) {
std
::
optional
<
const
at
::
Tensor
>
cache_batch_idx
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
leftpad_k
=
std
::
nullopt
;
flash
::
mha_fwd_kvcache
(
const
bool
use_dynamic_out
=
q
.
dim
()
==
4
&&
k_cache
.
dim
()
==
4
&&
q
.
size
(
1
)
==
1
&&
q
.
size
(
2
)
>
k_cache
.
size
(
2
)
&&
q
.
size
(
3
)
%
8
==
0
&&
!
alibi_slopes
.
has_value
();
auto
out
=
use_dynamic_out
?
std
::
optional
<
at
::
Tensor
>
(
std
::
nullopt
)
:
std
::
optional
<
at
::
Tensor
>
(
out_tensor
);
auto
result
=
flash
::
mha_fwd_kvcache
(
q
,
k_cache
,
v_cache
,
...
...
@@ -69,13 +75,16 @@ void run(void *planned_meta) {
alibi_slopes
,
out
,
p
->
scale
,
true
,
// is_causal
-
1
,
// window_size_left (-1 = no sliding window)
-
1
,
// window_size_right
0.0
f
,
// softcap
false
,
// is_rotary_interleaved
0
// num_splits (0 = auto)
);
true
,
-
1
,
-
1
,
0.0
f
,
false
,
0
);
if
(
use_dynamic_out
)
{
out_tensor
.
copy_
(
result
[
0
]);
}
#else
throw
std
::
runtime_error
(
"FlashAttention is not enabled in this build"
);
#endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment