Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
656944ac
Commit
656944ac
authored
Mar 21, 2026
by
yangql
Browse files
增加triton的indexer的kcahche读写操作
parent
12b5bcb1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
273 additions
and
7 deletions
+273
-7
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+23
-7
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
+250
-0
No files found.
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
656944ac
...
...
@@ -9,12 +9,14 @@ from vllm.forward_context import get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.platforms
import
current_platform
from
vllm.platforms.rocm
import
get_gcn_arch_name
from
vllm.utils.deep_gemm
import
fp8_mqa_logits
,
fp8_paged_mqa_logits
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.v1.attention.backends.mla.indexer
import
(
DeepseekV32IndexerMetadata
,
)
from
vllm.v1.attention.ops.common
import
pack_seq_triton
,
unpack_seq_triton
from
vllm.v1.attention.ops.rocm_aiter_mla_sparse
import
indexer_k_bf16_cache_triton
,
cp_gather_indexer_k_bf16_cache_triton
from
vllm.v1.worker.workspace
import
current_workspace_manager
from
lightop
import
op
,
gemmopt
...
...
@@ -73,7 +75,8 @@ def sparse_attn_indexer(
has_decode
=
attn_metadata
.
num_decodes
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
num_tokens
=
slot_mapping
.
shape
[
0
]
k
=
k
[:
num_tokens
]
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
ops
.
indexer_k_quant_and_cache
(
k
,
...
...
@@ -82,7 +85,12 @@ def sparse_attn_indexer(
quant_block_size
,
scale_fmt
,
)
else
:
indexer_k_bf16_cache_triton
(
k
,
kv_cache
,
slot_mapping
,
)
topk_indices_buffer
[:
hidden_states
.
shape
[
0
]]
=
-
1
if
has_prefill
:
prefill_metadata
=
attn_metadata
.
prefill
...
...
@@ -90,7 +98,7 @@ def sparse_attn_indexer(
# Get the full shared workspace buffers once (will allocate on first use)
workspace_manager
=
current_workspace_manager
()
k_fp8_full
,
k_scale_full
=
workspace_manager
.
get_simultaneous
(
((
total_seq_lens
,
head_dim
),
fp8_dtype
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnA
rch
N
ame
.
split
(
':'
)[
0
]
==
"gfx938"
else
k
.
dtype
,),
((
total_seq_lens
,
head_dim
),
fp8_dtype
if
not
current_platform
.
is_rocm
()
or
get_gcn_a
rch
_n
ame
==
"gfx938"
else
k
.
dtype
,),
((
total_seq_lens
,
4
),
torch
.
uint8
),
)
for
chunk
in
prefill_metadata
.
chunks
:
...
...
@@ -112,7 +120,7 @@ def sparse_attn_indexer(
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
)
elif
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnA
rch
N
ame
.
split
(
':'
)[
0
]
==
"gfx938"
:
elif
get_gcn_a
rch
_n
ame
==
"gfx938"
:
k_fp8
=
k_fp8_full
[:
chunk
.
total_seq_lens
]
k_scale
=
k_scale_full
[:
chunk
.
total_seq_lens
]
ops
.
cp_gather_indexer_k_quant_cache
(
...
...
@@ -136,14 +144,22 @@ def sparse_attn_indexer(
True
)
else
:
k_fp8
=
k_fp8_full
[:
chunk
.
total_seq_lens
]
k_scale
=
k_scale_full
[:
chunk
.
total_seq_lens
]
cp_gather_indexer_k_bf16_cache_triton
(
kv_cache
,
k_fp8
,
chunk
.
block_table
,
chunk
.
cu_seq_lens
,
)
logits
=
op
.
mqa_logits
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
k
,
k
_fp8
,
weights
[
chunk
.
token_start
:
chunk
.
token_end
].
to
(
torch
.
float32
),
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
].
shape
[
0
],
k
.
shape
[
0
],
k
_fp8
.
shape
[
0
],
q_fp8
.
shape
[
1
],
q_fp8
.
shape
[
2
],
None
,
...
...
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
View file @
656944ac
...
...
@@ -216,6 +216,256 @@ def cp_gather_indexer_k_quant_cache_triton(
head_tile_size
,
)
@
triton
.
jit
def
_indexer_k_bf16_cache_kernel
(
k_ptr
,
# [num_tokens, head_dim] (bf16)
kv_cache_ptr
,
# [n_blks, block_size, head_dim] (bf16)
slot_mapping_ptr
,
# [num_tokens]
kv_cache_stride
,
# KV Cache 第一维的stride
block_size
:
tl
.
constexpr
,
num_tokens
:
tl
.
constexpr
,
head_dim
:
tl
.
constexpr
,
LAYOUT
:
tl
.
constexpr
,
BLOCK_TILE_SIZE
:
tl
.
constexpr
,
HEAD_TILE_SIZE
:
tl
.
constexpr
,
):
"""
Triton 核函数:将 BF16 类型的 K 张量写入 KV Cache(
"""
tid
=
tl
.
program_id
(
0
)
# 边界检查:超出 token 范围直接返回
if
tid
>=
num_tokens
:
return
# 定义头维度索引偏移(覆盖整个 head_dim)
offset
=
tl
.
arange
(
0
,
head_dim
)
# 计算输入 K 张量的源指针偏移
src_ptr
=
k_ptr
+
tid
*
head_dim
# 加载当前 token 对应的 cache slot ID
slot_id
=
tl
.
load
(
slot_mapping_ptr
+
tid
)
# 无效 slot(-1)直接返回
if
slot_id
<
0
:
return
# 计算 block ID 和块内偏移
block_id
=
slot_id
//
block_size
block_offset
=
slot_id
%
block_size
# 分块相关的偏移计算(兼容 SHUFFLE 布局)
tile_block_id
=
block_offset
//
BLOCK_TILE_SIZE
tile_block_offset
=
block_offset
%
BLOCK_TILE_SIZE
# 根据布局计算 KV Cache 的目标指针偏移
if
LAYOUT
==
"SHUFFLE"
:
# SHUFFLE 布局的偏移计算
tile_offset
=
(
offset
//
HEAD_TILE_SIZE
*
BLOCK_TILE_SIZE
*
HEAD_TILE_SIZE
+
offset
%
HEAD_TILE_SIZE
)
dst_ptr
=
(
kv_cache_ptr
+
block_id
*
kv_cache_stride
+
tile_block_id
*
BLOCK_TILE_SIZE
*
head_dim
+
tile_block_offset
*
HEAD_TILE_SIZE
)
else
:
# NHD 标准布局
tile_offset
=
offset
dst_ptr
=
(
kv_cache_ptr
+
block_id
*
kv_cache_stride
+
block_offset
*
head_dim
)
val
=
tl
.
load
(
src_ptr
+
offset
)
tl
.
store
(
dst_ptr
+
tile_offset
,
val
)
def
indexer_k_bf16_cache_triton
(
k
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
# [num_blocks, block_size, head_dim] (bf16)
slot_mapping
:
torch
.
Tensor
,
block_tile_size
=
16
,
head_tile_size
=
16
,
):
"""
将 BF16 类型的 K 张量写入 BF16 类型的 KV Cache
Args:
k: 输入 K 张量 [num_tokens, head_dim] (bf16)
kv_cache: KV Cache 张量 [num_blocks, block_size, head_dim] (bf16)
slot_mapping: token 到 cache slot 的映射 [num_tokens]
block_tile_size: 块分块大小
head_tile_size: 头维度分块大小
"""
# 输入类型校验
assert
k
.
dtype
==
torch
.
bfloat16
,
"k 必须是 bf16 类型"
assert
kv_cache
.
dtype
==
torch
.
bfloat16
,
"kv_cache 必须是 bf16 类型"
# 解析张量维度
num_blocks
=
kv_cache
.
shape
[
0
]
block_size
=
kv_cache
.
shape
[
1
]
head_dim
=
k
.
shape
[
-
1
]
num_tokens
=
slot_mapping
.
shape
[
0
]
# 验证维度合法性
assert
kv_cache
.
shape
[
2
]
==
head_dim
,
"kv_cache 的 head_dim 必须与 k 一致"
# 重塑 KV Cache 为二维(便于指针计算)
kv_cache_2d
=
kv_cache
.
view
(
num_blocks
,
-
1
)
# [num_blocks, block_size * head_dim]
# 调整 head_tile_size(兼容原逻辑,按字节数归一化)
head_tile_size
=
head_tile_size
//
kv_cache
.
element_size
()
# 配置 Triton 核函数的 grid(每个 token 一个 program)
grid
=
(
num_tokens
,)
_indexer_k_bf16_cache_kernel
[
grid
](
k
,
kv_cache_2d
,
slot_mapping
,
kv_cache_2d
.
stride
(
0
),
block_size
,
num_tokens
,
head_dim
,
"NHD"
,
# 布局类型
block_tile_size
,
head_tile_size
,
)
@
triton
.
jit
def
_cp_gather_indexer_k_bf16_cache_kernel
(
kv_cache_ptr
,
# [num_blocks, block_size * head_dim] (bf16)
k_bf16_ptr
,
# [num_tokens, head_dim] (bf16)
block_table_ptr
,
cu_seq_lens_ptr
,
block_size
:
tl
.
constexpr
,
batch_size
:
tl
.
constexpr
,
num_blocks_per_seq
:
tl
.
constexpr
,
kv_cache_stride
:
tl
.
constexpr
,
head_dim
:
tl
.
constexpr
,
num_tokens
:
tl
.
constexpr
,
BLOCK_TILE_SIZE
:
tl
.
constexpr
,
HEAD_TILE_SIZE
:
tl
.
constexpr
,
):
"""
Triton 核函数 BF16 K Cache 收集
"""
token_idx
=
tl
.
program_id
(
0
)
# 边界检查:超出 token 范围直接返回
if
token_idx
>=
num_tokens
:
return
# 定义头维度索引偏移(覆盖整个 head_dim)
head_offset
=
tl
.
arange
(
0
,
head_dim
)
batch_idx
=
tl
.
full
((),
-
1
,
dtype
=
tl
.
int32
)
# 遍历所有 batch(Triton 支持有限循环,需固定循环次数)
for
b
in
tl
.
static_range
(
batch_size
):
# 加载当前 batch 的序列起始/结束位置
seq_start
=
tl
.
load
(
cu_seq_lens_ptr
+
b
)
seq_end
=
tl
.
load
(
cu_seq_lens_ptr
+
b
+
1
)
# 条件判断:当前 token 是否属于该 batch
is_in_batch
=
(
token_idx
>=
seq_start
)
&
(
token_idx
<
seq_end
)
# 条件赋值:如果属于该 batch,更新 batch_idx(替代 break)
batch_idx
=
tl
.
where
(
is_in_batch
,
b
,
batch_idx
)
# 无效的 batch ID(token 不在任何序列中),直接返回
if
batch_idx
==
-
1
:
return
# --------------------------
# 计算序列内偏移和 block 索引
# --------------------------
# token 在所属序列内的相对偏移
seq_start
=
tl
.
load
(
cu_seq_lens_ptr
+
batch_idx
)
inbatch_seq_idx
=
token_idx
-
seq_start
# 计算该 token 对应的 block 索引(block_table 中的位置)
block_table_id
=
inbatch_seq_idx
//
block_size
# 边界检查:block 索引超出范围则返回
if
block_table_id
>=
num_blocks_per_seq
:
return
# 计算 block_table 中的内存偏移并加载 block ID
block_table_offset
=
batch_idx
*
num_blocks_per_seq
+
block_table_id
block_id
=
tl
.
load
(
block_table_ptr
+
block_table_offset
)
# 计算 token 在 block 内的偏移
block_offset
=
inbatch_seq_idx
%
block_size
# --------------------------
# 计算内存偏移
# --------------------------
# KV Cache 源偏移:block_id * 块步长 + 块内偏移 * head_dim
src_block_offset
=
block_id
*
kv_cache_stride
src_inblock_offset
=
src_block_offset
+
block_offset
*
head_dim
# 输出张量目标偏移
dst_inblock_offset
=
token_idx
*
head_dim
src_ptr
=
kv_cache_ptr
+
src_inblock_offset
+
head_offset
val
=
tl
.
load
(
src_ptr
)
dst_ptr
=
k_bf16_ptr
+
dst_inblock_offset
+
head_offset
tl
.
store
(
dst_ptr
,
val
)
def
cp_gather_indexer_k_bf16_cache_triton
(
k_cache
:
torch
.
Tensor
,
# [num_blocks, block_size, head_dim] (bf16)
k_bf16
:
torch
.
Tensor
,
# [num_tokens, head_dim] (bf16)
block_table
:
torch
.
Tensor
,
# [batch_size, num_blocks_per_seq]
cu_seq_lens
:
torch
.
Tensor
,
# [batch_size + 1]
block_tile_size
:
int
=
16
,
head_tile_size
:
int
=
16
,
):
"""
BF16 K Cache 收集算子
Args:
k_cache: K缓存张量 [num_blocks, block_size, head_dim] (bf16)
k_bf16: 输出张量 [num_tokens, head_dim] (bf16)
block_table: 块表 [batch_size, num_blocks_per_seq]
cu_seq_lens: 序列长度累积数组 [batch_size + 1]
block_tile_size: 块分块大小
head_tile_size: 头维度分块大小
"""
# 输入类型校验
assert
k_cache
.
dtype
==
torch
.
bfloat16
,
"k_cache 必须是 bf16 类型"
assert
k_bf16
.
dtype
==
torch
.
bfloat16
,
"k_bf16 必须是 bf16 类型"
# 解析维度参数
num_tokens
=
k_bf16
.
size
(
0
)
block_size
=
k_cache
.
size
(
1
)
head_dim
=
k_bf16
.
shape
[
-
1
]
num_blocks
=
k_cache
.
shape
[
0
]
batch_size
=
block_table
.
size
(
0
)
num_blocks_per_seq
=
block_table
.
size
(
1
)
# 重塑缓存张量(便于指针计算)
k_cache_2d
=
k_cache
.
view
(
num_blocks
,
-
1
)
# [num_blocks, block_size * head_dim]
# 配置 Triton 核函数的 grid(每个 token 一个 program)
grid
=
(
num_tokens
,)
_cp_gather_indexer_k_bf16_cache_kernel
[
grid
](
k_cache_2d
,
k_bf16
,
block_table
,
cu_seq_lens
,
block_size
,
batch_size
,
num_blocks_per_seq
,
k_cache_2d
.
stride
(
0
),
# kv_cache stride (block维度)
head_dim
,
num_tokens
,
block_tile_size
,
head_tile_size
,
)
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156
def
fp8_paged_mqa_logits_torch
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment