Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
17777b08
Unverified
Commit
17777b08
authored
Mar 11, 2026
by
pengcheng888
Committed by
GitHub
Mar 11, 2026
Browse files
Merge pull request #1066 from InfiniTensor/issue/1065
issue/1065 - introduce mha kvcache
parents
21c6af2d
0f90515c
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
413 additions
and
6 deletions
+413
-6
include/infinicore/ops/mha_kvcache.hpp
include/infinicore/ops/mha_kvcache.hpp
+51
-0
python/infinicore/__init__.py
python/infinicore/__init__.py
+8
-6
python/infinicore/ops/mha_kvcache.py
python/infinicore/ops/mha_kvcache.py
+67
-0
src/infinicore/ops/mha_kvcache/mha_kvcache.cc
src/infinicore/ops/mha_kvcache/mha_kvcache.cc
+58
-0
src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
+100
-0
src/infinicore/pybind11/ops.hpp
src/infinicore/pybind11/ops.hpp
+2
-0
src/infinicore/pybind11/ops/mha_kvcache.hpp
src/infinicore/pybind11/ops/mha_kvcache.hpp
+127
-0
No files found.
include/infinicore/ops/mha_kvcache.hpp
0 → 100644
View file @
17777b08
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
#include <optional>
namespace
infinicore
::
op
{
// Flash Attention KV-cache decode op.
//
// Wraps flash::mha_fwd_kvcache for single-step (decode) attention over a
// paged KV cache.
//
// Tensor shapes:
// out : [batch_size, seqlen_q, num_heads, head_size]
// q : [batch_size, seqlen_q, num_heads, head_size]
// k_cache : [num_blocks, block_size, num_heads_k, head_size] (paged layout)
// v_cache : [num_blocks, block_size, num_heads_k, head_size] (paged layout)
// seqlens_k : [batch_size] int32 — total KV length per request
// block_table : [batch_size, max_num_blocks_per_seq] int32
INFINICORE_GRAPH_OP_CLASS
(
MhaKVCache
,
Tensor
,
// out
const
Tensor
&
,
// q
const
Tensor
&
,
// k_cache
const
Tensor
&
,
// v_cache
const
Tensor
&
,
// seqlens_k
const
Tensor
&
,
// block_table
std
::
optional
<
Tensor
>
,
// alibi_slopes
float
);
// scale
Tensor
mha_kvcache
(
const
Tensor
&
q
,
const
Tensor
&
k_cache
,
const
Tensor
&
v_cache
,
const
Tensor
&
seqlens_k
,
const
Tensor
&
block_table
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
);
void
mha_kvcache_
(
Tensor
out
,
const
Tensor
&
q
,
const
Tensor
&
k_cache
,
const
Tensor
&
v_cache
,
const
Tensor
&
seqlens_k
,
const
Tensor
&
block_table
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
);
}
// namespace infinicore::op
python/infinicore/__init__.py
View file @
17777b08
...
@@ -61,6 +61,7 @@ from infinicore.ops.cross_entropy import cross_entropy
...
@@ -61,6 +61,7 @@ from infinicore.ops.cross_entropy import cross_entropy
from
infinicore.ops.equal
import
equal
from
infinicore.ops.equal
import
equal
from
infinicore.ops.kv_caching
import
kv_caching
from
infinicore.ops.kv_caching
import
kv_caching
from
infinicore.ops.matmul
import
matmul
from
infinicore.ops.matmul
import
matmul
from
infinicore.ops.mha_kvcache
import
mha_kvcache
from
infinicore.ops.mha_varlen
import
mha_varlen
from
infinicore.ops.mha_varlen
import
mha_varlen
from
infinicore.ops.mul
import
mul
from
infinicore.ops.mul
import
mul
from
infinicore.ops.narrow
import
narrow
from
infinicore.ops.narrow
import
narrow
...
@@ -131,16 +132,15 @@ __all__ = [
...
@@ -131,16 +132,15 @@ __all__ = [
"long"
,
"long"
,
"short"
,
"short"
,
"uint8"
,
"uint8"
,
# Operations.
# Operators.
"addcmul"
,
"atanh"
,
"binary_cross_entropy_with_logits"
,
"cdist"
,
"reciprocal"
,
"add"
,
"add"
,
"addcmul"
,
"add_rms_norm"
,
"add_rms_norm"
,
"add_rms_norm_"
,
"add_rms_norm_"
,
"atanh"
,
"attention"
,
"attention"
,
"binary_cross_entropy_with_logits"
,
"cdist"
,
"kv_caching"
,
"kv_caching"
,
"matmul"
,
"matmul"
,
"equal"
,
"equal"
,
...
@@ -156,11 +156,13 @@ __all__ = [
...
@@ -156,11 +156,13 @@ __all__ = [
"from_list"
,
"from_list"
,
"from_numpy"
,
"from_numpy"
,
"from_torch"
,
"from_torch"
,
"mha_kvcache"
,
"mha_varlen"
,
"mha_varlen"
,
"paged_caching"
,
"paged_caching"
,
"paged_attention"
,
"paged_attention"
,
"paged_attention_prefill"
,
"paged_attention_prefill"
,
"ones"
,
"ones"
,
"reciprocal"
,
"strided_empty"
,
"strided_empty"
,
"strided_from_blob"
,
"strided_from_blob"
,
"zeros"
,
"zeros"
,
...
...
python/infinicore/ops/mha_kvcache.py
0 → 100644
View file @
17777b08
from
typing
import
Optional
from
infinicore.lib
import
_infinicore
from
infinicore.tensor
import
Tensor
def
mha_kvcache
(
q
:
Tensor
,
k_cache
:
Tensor
,
v_cache
:
Tensor
,
seqlens_k
:
Tensor
,
block_table
:
Tensor
,
alibi_slopes
:
Optional
[
Tensor
]
=
None
,
scale
:
float
=
1.0
,
*
,
out
:
Optional
[
Tensor
]
=
None
,
)
->
Tensor
:
"""
Flash attention KV-cache decode for single-step attention over a paged KV cache.
This function performs attention decoding using a paged KV cache layout,
which is efficient for inference with large sequence lengths.
Args:
q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
k_cache: Key cache tensor of shape [num_blocks, block_size, num_heads_k, head_size] (paged layout)
v_cache: Value cache tensor of shape [num_blocks, block_size, num_heads_k, head_size] (paged layout)
seqlens_k: Total KV length per request of shape [batch_size] (int32)
block_table: Block mapping table of shape [batch_size, max_num_blocks_per_seq] (int32)
alibi_slopes: Optional ALiBi slopes tensor, if None then ALiBi is disabled
scale: Scaling factor for attention scores (typically 1.0/sqrt(head_size))
out: Optional output tensor. If provided, the operation will be performed in-place.
Returns:
Output tensor of shape [batch_size, seqlen_q, num_heads, head_size]
Note:
The KV cache uses a paged layout where:
- k_cache and v_cache are organized into fixed-size blocks
- block_table maps logical positions to physical blocks for each sequence
- seqlens_k indicates the current total length of each sequence in the cache
"""
if
out
is
None
:
return
Tensor
(
_infinicore
.
mha_kvcache
(
q
.
_underlying
,
k_cache
.
_underlying
,
v_cache
.
_underlying
,
seqlens_k
.
_underlying
,
block_table
.
_underlying
,
alibi_slopes
.
_underlying
if
alibi_slopes
is
not
None
else
None
,
scale
,
)
)
_infinicore
.
mha_kvcache_
(
out
.
_underlying
,
q
.
_underlying
,
k_cache
.
_underlying
,
v_cache
.
_underlying
,
seqlens_k
.
_underlying
,
block_table
.
_underlying
,
alibi_slopes
.
_underlying
if
alibi_slopes
is
not
None
else
None
,
scale
,
)
return
out
src/infinicore/ops/mha_kvcache/mha_kvcache.cc
0 → 100644
View file @
17777b08
#include "infinicore/ops/mha_kvcache.hpp"
#include "../../utils.hpp"
namespace
infinicore
::
op
{
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL
(
MhaKVCache
);
MhaKVCache
::
MhaKVCache
(
Tensor
out
,
const
Tensor
&
q
,
const
Tensor
&
k_cache
,
const
Tensor
&
v_cache
,
const
Tensor
&
seqlens_k
,
const
Tensor
&
block_table
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
INFINICORE_ASSERT_TENSORS_SAME_DEVICE
(
out
,
q
,
k_cache
,
v_cache
,
seqlens_k
,
block_table
);
INFINICORE_GRAPH_OP_DISPATCH
(
out
->
device
().
getType
(),
out
,
q
,
k_cache
,
v_cache
,
seqlens_k
,
block_table
,
alibi_slopes
,
scale
);
}
void
MhaKVCache
::
execute
(
Tensor
out
,
const
Tensor
&
q
,
const
Tensor
&
k_cache
,
const
Tensor
&
v_cache
,
const
Tensor
&
seqlens_k
,
const
Tensor
&
block_table
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
INFINICORE_GRAPH_OP_RECORD_OR_RUN
(
MhaKVCache
,
out
,
q
,
k_cache
,
v_cache
,
seqlens_k
,
block_table
,
alibi_slopes
,
scale
);
}
void
mha_kvcache_
(
Tensor
out
,
const
Tensor
&
q
,
const
Tensor
&
k_cache
,
const
Tensor
&
v_cache
,
const
Tensor
&
seqlens_k
,
const
Tensor
&
block_table
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
MhaKVCache
::
execute
(
out
,
q
,
k_cache
,
v_cache
,
seqlens_k
,
block_table
,
alibi_slopes
,
scale
);
}
Tensor
mha_kvcache
(
const
Tensor
&
q
,
const
Tensor
&
k_cache
,
const
Tensor
&
v_cache
,
const
Tensor
&
seqlens_k
,
const
Tensor
&
block_table
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
// Output shape matches q: [batch_size, seqlen_q, num_heads, head_size]
auto
out
=
Tensor
::
empty
(
q
->
shape
(),
q
->
dtype
(),
q
->
device
());
mha_kvcache_
(
out
,
q
,
k_cache
,
v_cache
,
seqlens_k
,
block_table
,
alibi_slopes
,
scale
);
return
out
;
}
}
// namespace infinicore::op
src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
0 → 100644
View file @
17777b08
#include "infinicore/ops/mha_kvcache.hpp"
#include "infinicore/adaptor/flash_attention_adaptor.hpp"
#include <stdexcept>
namespace
infinicore
::
op
::
mha_kvcache_impl
::
flashattn
{
struct
PlannedMeta
{
graph
::
GraphTensor
out
,
q
,
k_cache
,
v_cache
,
seqlens_k
,
block_table
;
std
::
optional
<
graph
::
GraphTensor
>
alibi_slopes
;
float
scale
;
};
void
*
plan
(
Tensor
out
,
const
Tensor
&
q
,
const
Tensor
&
k_cache
,
const
Tensor
&
v_cache
,
const
Tensor
&
seqlens_k
,
const
Tensor
&
block_table
,
std
::
optional
<
Tensor
>
alibi_slopes
,
float
scale
)
{
return
new
PlannedMeta
{
graph
::
GraphTensor
(
out
),
graph
::
GraphTensor
(
q
),
graph
::
GraphTensor
(
k_cache
),
graph
::
GraphTensor
(
v_cache
),
graph
::
GraphTensor
(
seqlens_k
),
graph
::
GraphTensor
(
block_table
),
alibi_slopes
?
std
::
optional
<
graph
::
GraphTensor
>
(
graph
::
GraphTensor
(
*
alibi_slopes
))
:
std
::
nullopt
,
scale
};
}
void
run
(
void
*
planned_meta
)
{
#ifdef ENABLE_FLASH_ATTN
c10
::
cuda
::
CUDAStreamGuard
guard
(
infinicore
::
adaptor
::
get_cuda_stream
());
auto
*
p
=
reinterpret_cast
<
PlannedMeta
*>
(
planned_meta
);
auto
out_tensor
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
out
);
auto
q
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
q
);
auto
k_cache
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
k_cache
);
auto
v_cache
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
v_cache
);
auto
seqlens_k
=
std
::
optional
<
const
at
::
Tensor
>
(
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
seqlens_k
));
auto
block_table
=
std
::
optional
<
at
::
Tensor
>
(
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
block_table
));
auto
alibi_slopes
=
p
->
alibi_slopes
?
std
::
optional
<
at
::
Tensor
>
(
infinicore
::
adaptor
::
to_aten_tensor
(
*
p
->
alibi_slopes
))
:
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
k_new
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
v_new
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
rotary_cos
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
rotary_sin
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
cache_batch_idx
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
leftpad_k
=
std
::
nullopt
;
const
bool
use_dynamic_out
=
q
.
dim
()
==
4
&&
k_cache
.
dim
()
==
4
&&
q
.
size
(
1
)
==
1
&&
q
.
size
(
2
)
>
k_cache
.
size
(
2
)
&&
q
.
size
(
3
)
%
8
==
0
&&
!
alibi_slopes
.
has_value
();
auto
out
=
use_dynamic_out
?
std
::
optional
<
at
::
Tensor
>
(
std
::
nullopt
)
:
std
::
optional
<
at
::
Tensor
>
(
out_tensor
);
auto
result
=
flash
::
mha_fwd_kvcache
(
q
,
k_cache
,
v_cache
,
k_new
,
v_new
,
seqlens_k
,
rotary_cos
,
rotary_sin
,
cache_batch_idx
,
leftpad_k
,
block_table
,
alibi_slopes
,
out
,
p
->
scale
,
true
,
-
1
,
-
1
,
0.0
f
,
false
,
0
);
if
(
use_dynamic_out
)
{
out_tensor
.
copy_
(
result
[
0
]);
}
#else
throw
std
::
runtime_error
(
"FlashAttention is not enabled in this build"
);
#endif
}
void
cleanup
(
void
**
planned_meta_ptr
)
{
delete
*
reinterpret_cast
<
PlannedMeta
**>
(
planned_meta_ptr
);
*
planned_meta_ptr
=
nullptr
;
}
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE
(
MhaKVCache
,
&
plan
,
&
run
,
&
cleanup
);
}
// namespace infinicore::op::mha_kvcache_impl::flashattn
src/infinicore/pybind11/ops.hpp
View file @
17777b08
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include "ops/linear.hpp"
#include "ops/linear.hpp"
#include "ops/linear_w8a8i8.hpp"
#include "ops/linear_w8a8i8.hpp"
#include "ops/matmul.hpp"
#include "ops/matmul.hpp"
#include "ops/mha_kvcache.hpp"
#include "ops/mha_varlen.hpp"
#include "ops/mha_varlen.hpp"
#include "ops/mul.hpp"
#include "ops/mul.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_attention.hpp"
...
@@ -54,6 +55,7 @@ inline void bind(py::module &m) {
...
@@ -54,6 +55,7 @@ inline void bind(py::module &m) {
bind_linear
(
m
);
bind_linear
(
m
);
bind_matmul
(
m
);
bind_matmul
(
m
);
bind_mul
(
m
);
bind_mul
(
m
);
bind_mha_kvcache
(
m
);
bind_mha_varlen
(
m
);
bind_mha_varlen
(
m
);
bind_hardswish
(
m
);
bind_hardswish
(
m
);
bind_hardtanh
(
m
);
bind_hardtanh
(
m
);
...
...
src/infinicore/pybind11/ops/mha_kvcache.hpp
0 → 100644
View file @
17777b08
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/mha_kvcache.hpp"
namespace
py
=
pybind11
;
namespace
infinicore
::
ops
{
Tensor
py_mha_kvcache
(
Tensor
q
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
seqlens_k
,
Tensor
block_table
,
pybind11
::
object
alibi_slopes
,
float
scale
)
{
std
::
optional
<
Tensor
>
alibi_slopes_tensor
=
std
::
nullopt
;
if
(
!
alibi_slopes
.
is_none
())
{
alibi_slopes_tensor
=
alibi_slopes
.
cast
<
Tensor
>
();
}
return
op
::
mha_kvcache
(
q
,
k_cache
,
v_cache
,
seqlens_k
,
block_table
,
alibi_slopes_tensor
,
scale
);
}
void
py_mha_kvcache_
(
Tensor
out
,
Tensor
q
,
Tensor
k_cache
,
Tensor
v_cache
,
Tensor
seqlens_k
,
Tensor
block_table
,
pybind11
::
object
alibi_slopes
,
float
scale
)
{
std
::
optional
<
Tensor
>
alibi_slopes_tensor
=
std
::
nullopt
;
if
(
!
alibi_slopes
.
is_none
())
{
alibi_slopes_tensor
=
alibi_slopes
.
cast
<
Tensor
>
();
}
op
::
mha_kvcache_
(
out
,
q
,
k_cache
,
v_cache
,
seqlens_k
,
block_table
,
alibi_slopes_tensor
,
scale
);
}
inline
void
bind_mha_kvcache
(
py
::
module
&
m
)
{
m
.
def
(
"mha_kvcache"
,
&
ops
::
py_mha_kvcache
,
py
::
arg
(
"q"
),
py
::
arg
(
"k_cache"
),
py
::
arg
(
"v_cache"
),
py
::
arg
(
"seqlens_k"
),
py
::
arg
(
"block_table"
),
py
::
arg
(
"alibi_slopes"
),
py
::
arg
(
"scale"
),
R"doc(Flash attention KV-cache decode for single-step attention over a paged KV cache.
Parameters
----------
q : Tensor
Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
k_cache : Tensor
Key cache tensor of shape [num_blocks, block_size, num_heads_k, head_size] (paged layout)
v_cache : Tensor
Value cache tensor of shape [num_blocks, block_size, num_heads_k, head_size] (paged layout)
seqlens_k : Tensor
Total KV length per request of shape [batch_size] (int32)
block_table : Tensor
Block mapping table of shape [batch_size, max_num_blocks_per_seq] (int32)
alibi_slopes : Optional[Tensor]
ALiBi slopes tensor, if None then ALiBi is disabled
scale : float
Scaling factor for attention scores (typically 1.0/sqrt(head_size))
Returns
-------
Tensor
Output tensor of shape [batch_size, seqlen_q, num_heads, head_size]
)doc"
);
m
.
def
(
"mha_kvcache_"
,
&
ops
::
py_mha_kvcache_
,
py
::
arg
(
"out"
),
py
::
arg
(
"q"
),
py
::
arg
(
"k_cache"
),
py
::
arg
(
"v_cache"
),
py
::
arg
(
"seqlens_k"
),
py
::
arg
(
"block_table"
),
py
::
arg
(
"alibi_slopes"
),
py
::
arg
(
"scale"
),
R"doc(In-place flash attention KV-cache decode.
Parameters
----------
out : Tensor
Output tensor of shape [batch_size, seqlen_q, num_heads, head_size]
q : Tensor
Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
k_cache : Tensor
Key cache tensor of shape [num_blocks, block_size, num_heads_k, head_size] (paged layout)
v_cache : Tensor
Value cache tensor of shape [num_blocks, block_size, num_heads_k, head_size] (paged layout)
seqlens_k : Tensor
Total KV length per request of shape [batch_size] (int32)
block_table : Tensor
Block mapping table of shape [batch_size, max_num_blocks_per_seq] (int32)
alibi_slopes : Optional[Tensor]
ALiBi slopes tensor, if None then ALiBi is disabled
scale : float
Scaling factor for attention scores (typically 1.0/sqrt(head_size))
)doc"
);
}
}
// namespace infinicore::ops
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment