Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
665f383b
Commit
665f383b
authored
Mar 08, 2026
by
suss
Committed by
wooway777
Mar 11, 2026
Browse files
issue/1065 - add mha_kvcache
parent
21c6af2d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
194 additions
and
0 deletions
+194
-0
include/infinicore/ops/mha_kvcache.hpp
include/infinicore/ops/mha_kvcache.hpp
+51
-0
src/infinicore/ops/mha_kvcache/mha_kvcache.cc
src/infinicore/ops/mha_kvcache/mha_kvcache.cc
+58
-0
src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
+85
-0
No files found.
include/infinicore/ops/mha_kvcache.hpp
0 → 100644
View file @
665f383b
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
#include <optional>
namespace
infinicore
::
op
{
// Flash Attention KV-cache decode op.
//
// Wraps flash::mha_fwd_kvcache for single-step (decode) attention over a
// paged KV cache.
//
// Tensor shapes:
// out : [batch_size, seqlen_q, num_heads, head_size]
// q : [batch_size, seqlen_q, num_heads, head_size]
// k_cache : [num_blocks, block_size, num_heads_k, head_size] (paged layout)
// v_cache : [num_blocks, block_size, num_heads_k, head_size] (paged layout)
// seqlens_k : [batch_size] int32 — total KV length per request
// block_table : [batch_size, max_num_blocks_per_seq] int32
INFINICORE_GRAPH_OP_CLASS
(
MhaKVCache
,
Tensor
,
// out
const
Tensor
&
,
// q
const
Tensor
&
,
// k_cache
const
Tensor
&
,
// v_cache
const
Tensor
&
,
// seqlens_k
const
Tensor
&
,
// block_table
std
::
optional
<
Tensor
>
,
// alibi_slopes
float
);
// scale
Tensor
mha_kvcache
(
const
Tensor
&
q
,
const
Tensor
&
k_cache
,
const
Tensor
&
v_cache
,
const
Tensor
&
seqlens_k
,
const
Tensor
&
block_table
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
);
void
mha_kvcache_
(
Tensor
out
,
const
Tensor
&
q
,
const
Tensor
&
k_cache
,
const
Tensor
&
v_cache
,
const
Tensor
&
seqlens_k
,
const
Tensor
&
block_table
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
);
}
// namespace infinicore::op
src/infinicore/ops/mha_kvcache/mha_kvcache.cc
0 → 100644
View file @
665f383b
#include "infinicore/ops/mha_kvcache.hpp"
#include "../../utils.hpp"
namespace
infinicore
::
op
{
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL
(
MhaKVCache
);
MhaKVCache
::
MhaKVCache
(
Tensor
out
,
const
Tensor
&
q
,
const
Tensor
&
k_cache
,
const
Tensor
&
v_cache
,
const
Tensor
&
seqlens_k
,
const
Tensor
&
block_table
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
INFINICORE_ASSERT_TENSORS_SAME_DEVICE
(
out
,
q
,
k_cache
,
v_cache
,
seqlens_k
,
block_table
);
INFINICORE_GRAPH_OP_DISPATCH
(
out
->
device
().
getType
(),
out
,
q
,
k_cache
,
v_cache
,
seqlens_k
,
block_table
,
alibi_slopes
,
scale
);
}
void
MhaKVCache
::
execute
(
Tensor
out
,
const
Tensor
&
q
,
const
Tensor
&
k_cache
,
const
Tensor
&
v_cache
,
const
Tensor
&
seqlens_k
,
const
Tensor
&
block_table
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
INFINICORE_GRAPH_OP_RECORD_OR_RUN
(
MhaKVCache
,
out
,
q
,
k_cache
,
v_cache
,
seqlens_k
,
block_table
,
alibi_slopes
,
scale
);
}
void
mha_kvcache_
(
Tensor
out
,
const
Tensor
&
q
,
const
Tensor
&
k_cache
,
const
Tensor
&
v_cache
,
const
Tensor
&
seqlens_k
,
const
Tensor
&
block_table
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
MhaKVCache
::
execute
(
out
,
q
,
k_cache
,
v_cache
,
seqlens_k
,
block_table
,
alibi_slopes
,
scale
);
}
Tensor
mha_kvcache
(
const
Tensor
&
q
,
const
Tensor
&
k_cache
,
const
Tensor
&
v_cache
,
const
Tensor
&
seqlens_k
,
const
Tensor
&
block_table
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
// Output shape matches q: [batch_size, seqlen_q, num_heads, head_size]
auto
out
=
Tensor
::
empty
(
q
->
shape
(),
q
->
dtype
(),
q
->
device
());
mha_kvcache_
(
out
,
q
,
k_cache
,
v_cache
,
seqlens_k
,
block_table
,
alibi_slopes
,
scale
);
return
out
;
}
}
// namespace infinicore::op
src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
0 → 100644
View file @
665f383b
#include "infinicore/ops/mha_kvcache.hpp"
#include "infinicore/adaptor/flash_attention_adaptor.hpp"
namespace
infinicore
::
op
::
mha_kvcache_impl
::
flashattn
{
struct
PlannedMeta
{
graph
::
GraphTensor
out
,
q
,
k_cache
,
v_cache
,
seqlens_k
,
block_table
;
std
::
optional
<
graph
::
GraphTensor
>
alibi_slopes
;
float
scale
;
};
void
*
plan
(
Tensor
out
,
const
Tensor
&
q
,
const
Tensor
&
k_cache
,
const
Tensor
&
v_cache
,
const
Tensor
&
seqlens_k
,
const
Tensor
&
block_table
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
return
new
PlannedMeta
{
graph
::
GraphTensor
(
out
),
graph
::
GraphTensor
(
q
),
graph
::
GraphTensor
(
k_cache
),
graph
::
GraphTensor
(
v_cache
),
graph
::
GraphTensor
(
seqlens_k
),
graph
::
GraphTensor
(
block_table
),
alibi_slopes
?
std
::
optional
<
graph
::
GraphTensor
>
(
graph
::
GraphTensor
(
*
alibi_slopes
))
:
std
::
nullopt
,
scale
};
}
void
run
(
void
*
planned_meta
)
{
c10
::
cuda
::
CUDAStreamGuard
guard
(
infinicore
::
adaptor
::
get_cuda_stream
());
auto
*
p
=
reinterpret_cast
<
PlannedMeta
*>
(
planned_meta
);
auto
out
=
std
::
optional
<
at
::
Tensor
>
(
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
out
));
auto
q
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
q
);
auto
k_cache
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
k_cache
);
auto
v_cache
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
v_cache
);
auto
seqlens_k
=
std
::
optional
<
const
at
::
Tensor
>
(
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
seqlens_k
));
auto
block_table
=
std
::
optional
<
at
::
Tensor
>
(
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
block_table
));
auto
alibi_slopes
=
p
->
alibi_slopes
?
std
::
optional
<
at
::
Tensor
>
(
infinicore
::
adaptor
::
to_aten_tensor
(
*
p
->
alibi_slopes
))
:
std
::
nullopt
;
// No new KV tokens to append (pure decode, KV already written to cache).
std
::
optional
<
const
at
::
Tensor
>
k_new
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
v_new
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
rotary_cos
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
rotary_sin
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
cache_batch_idx
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
leftpad_k
=
std
::
nullopt
;
flash
::
mha_fwd_kvcache
(
q
,
k_cache
,
v_cache
,
k_new
,
v_new
,
seqlens_k
,
rotary_cos
,
rotary_sin
,
cache_batch_idx
,
leftpad_k
,
block_table
,
alibi_slopes
,
out
,
p
->
scale
,
true
,
// is_causal
-
1
,
// window_size_left (-1 = no sliding window)
-
1
,
// window_size_right
0.0
f
,
// softcap
false
,
// is_rotary_interleaved
0
// num_splits (0 = auto)
);
}
void
cleanup
(
void
**
planned_meta_ptr
)
{
delete
*
reinterpret_cast
<
PlannedMeta
**>
(
planned_meta_ptr
);
*
planned_meta_ptr
=
nullptr
;
}
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE
(
MhaKVCache
,
&
plan
,
&
run
,
&
cleanup
);
}
// namespace infinicore::op::mha_kvcache_impl::flashattn
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment