Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6c20e89c
Unverified
Commit
6c20e89c
authored
Jan 21, 2026
by
Pleaplusone
Committed by
GitHub
Jan 21, 2026
Browse files
[ROCm][Deepseekv3.2] Refactor Sparse Indexer as CustomOp (#29287)
Signed-off-by:
ganyi
<
ygan@amd.com
>
parent
85f55c94
Changes
8
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
982 additions
and
323 deletions
+982
-323
vllm/_aiter_ops.py
vllm/_aiter_ops.py
+12
-0
vllm/config/compilation.py
vllm/config/compilation.py
+1
-0
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+318
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+14
-233
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+3
-0
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+6
-0
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
+110
-10
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
+518
-80
No files found.
vllm/_aiter_ops.py
View file @
6c20e89c
...
...
@@ -9,6 +9,10 @@ from torch._ops import OpOverload
import
vllm.envs
as
envs
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
direct_register_custom_op
,
is_torch_equal_or_newer
from
vllm.v1.attention.ops.rocm_aiter_mla_sparse
import
(
rocm_aiter_sparse_attn_indexer
,
rocm_aiter_sparse_attn_indexer_fake
,
)
_FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
...
@@ -1091,6 +1095,14 @@ class rocm_aiter_ops:
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_sparse_attn_indexer"
,
op_func
=
rocm_aiter_sparse_attn_indexer
,
mutates_args
=
[
"topk_indices_buffer"
],
fake_impl
=
rocm_aiter_sparse_attn_indexer_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
_OPS_REGISTERED
=
True
@
staticmethod
...
...
vllm/config/compilation.py
View file @
6c20e89c
...
...
@@ -611,6 +611,7 @@ class CompilationConfig:
"vllm::gdn_attention_core"
,
"vllm::kda_attention"
,
"vllm::sparse_attn_indexer"
,
"vllm::rocm_aiter_sparse_attn_indexer"
,
]
def
compute_hash
(
self
)
->
str
:
...
...
vllm/model_executor/layers/sparse_attn_indexer.py
0 → 100644
View file @
6c20e89c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom Sparse Attention Indexer layers."""
import
torch
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.platforms
import
current_platform
from
vllm.utils.deep_gemm
import
fp8_mqa_logits
,
fp8_paged_mqa_logits
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.v1.attention.backends.mla.indexer
import
(
DeepseekV32IndexerMetadata
,
)
from
vllm.v1.attention.ops.common
import
pack_seq_triton
,
unpack_seq_triton
from
vllm.v1.worker.workspace
import
current_workspace_manager
if
current_platform
.
is_cuda_alike
():
from
vllm
import
_custom_ops
as
ops
elif
current_platform
.
is_xpu
():
from
vllm._ipex_ops
import
ipex_ops
as
ops
logger
=
init_logger
(
__name__
)
def
sparse_attn_indexer
(
hidden_states
:
torch
.
Tensor
,
k_cache_prefix
:
str
,
kv_cache
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
quant_block_size
:
int
,
scale_fmt
:
str
|
None
,
topk_tokens
:
int
,
head_dim
:
int
,
max_model_len
:
int
,
total_seq_lens
:
int
,
topk_indices_buffer
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# careful! this will be None in dummy run
attn_metadata
=
get_forward_context
().
attn_metadata
fp8_dtype
=
current_platform
.
fp8_dtype
()
# assert isinstance(attn_metadata, dict)
if
not
isinstance
(
attn_metadata
,
dict
):
# Reserve workspace for indexer during profiling run
current_workspace_manager
().
get_simultaneous
(
((
total_seq_lens
,
head_dim
),
torch
.
float8_e4m3fn
),
((
total_seq_lens
,
4
),
torch
.
uint8
),
)
return
sparse_attn_indexer_fake
(
hidden_states
,
k_cache_prefix
,
kv_cache
,
q_fp8
,
k
,
weights
,
quant_block_size
,
scale_fmt
,
topk_tokens
,
head_dim
,
max_model_len
,
total_seq_lens
,
topk_indices_buffer
,
)
attn_metadata
=
attn_metadata
[
k_cache_prefix
]
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
slot_mapping
=
attn_metadata
.
slot_mapping
has_decode
=
attn_metadata
.
num_decodes
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
ops
.
indexer_k_quant_and_cache
(
k
,
kv_cache
,
slot_mapping
,
quant_block_size
,
scale_fmt
,
)
topk_indices_buffer
[:
hidden_states
.
shape
[
0
]]
=
-
1
if
has_prefill
:
prefill_metadata
=
attn_metadata
.
prefill
# Get the full shared workspace buffers once (will allocate on first use)
workspace_manager
=
current_workspace_manager
()
k_fp8_full
,
k_scale_full
=
workspace_manager
.
get_simultaneous
(
((
total_seq_lens
,
head_dim
),
fp8_dtype
),
((
total_seq_lens
,
4
),
torch
.
uint8
),
)
for
chunk
in
prefill_metadata
.
chunks
:
k_fp8
=
k_fp8_full
[:
chunk
.
total_seq_lens
]
k_scale
=
k_scale_full
[:
chunk
.
total_seq_lens
]
ops
.
cp_gather_indexer_k_quant_cache
(
kv_cache
,
k_fp8
,
k_scale
,
chunk
.
block_table
,
chunk
.
cu_seq_lens
,
)
logits
=
fp8_mqa_logits
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
(
k_fp8
,
k_scale
.
view
(
torch
.
float32
).
flatten
()),
weights
[
chunk
.
token_start
:
chunk
.
token_end
],
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
)
num_rows
=
logits
.
shape
[
0
]
topk_indices
=
topk_indices_buffer
[
chunk
.
token_start
:
chunk
.
token_end
,
:
topk_tokens
]
torch
.
ops
.
_C
.
top_k_per_row_prefill
(
logits
,
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
topk_indices
,
num_rows
,
logits
.
stride
(
0
),
logits
.
stride
(
1
),
topk_tokens
,
)
if
has_decode
:
decode_metadata
=
attn_metadata
.
decode
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
# we only have [num_block, block_size, head_dim],
kv_cache
=
kv_cache
.
unsqueeze
(
-
2
)
decode_lens
=
decode_metadata
.
decode_lens
if
decode_metadata
.
requires_padding
:
# pad in edge case where we have short chunked prefill length <
# decode_threshold since we unstrictly split
# prefill and decode by decode_threshold
# (currently set to 1 + speculative tokens)
padded_q_fp8_decode_tokens
=
pack_seq_triton
(
q_fp8
[:
num_decode_tokens
],
decode_lens
)
else
:
padded_q_fp8_decode_tokens
=
q_fp8
[:
num_decode_tokens
].
reshape
(
decode_lens
.
shape
[
0
],
-
1
,
*
q_fp8
.
shape
[
1
:]
)
# TODO: move and optimize below logic with triton kernels
batch_size
=
padded_q_fp8_decode_tokens
.
shape
[
0
]
next_n
=
padded_q_fp8_decode_tokens
.
shape
[
1
]
assert
batch_size
==
decode_metadata
.
seq_lens
.
shape
[
0
]
num_padded_tokens
=
batch_size
*
next_n
logits
=
fp8_paged_mqa_logits
(
padded_q_fp8_decode_tokens
,
kv_cache
,
weights
[:
num_padded_tokens
],
decode_metadata
.
seq_lens
,
decode_metadata
.
block_table
,
decode_metadata
.
schedule_metadata
,
max_model_len
=
max_model_len
,
)
num_rows
=
logits
.
shape
[
0
]
topk_indices
=
topk_indices_buffer
[:
num_padded_tokens
,
:
topk_tokens
]
torch
.
ops
.
_C
.
top_k_per_row_decode
(
logits
,
next_n
,
decode_metadata
.
seq_lens
,
topk_indices
,
num_rows
,
logits
.
stride
(
0
),
logits
.
stride
(
1
),
topk_tokens
,
)
if
decode_metadata
.
requires_padding
:
# if padded, we need to unpack
# the topk indices removing padded tokens
topk_indices
=
unpack_seq_triton
(
topk_indices
.
reshape
(
batch_size
,
-
1
,
topk_indices
.
shape
[
-
1
]),
decode_lens
,
)
topk_indices_buffer
[:
num_decode_tokens
,
:
topk_indices
.
shape
[
-
1
]]
=
(
topk_indices
)
return
topk_indices_buffer
def
sparse_attn_indexer_fake
(
hidden_states
:
torch
.
Tensor
,
k_cache_prefix
:
str
,
kv_cache
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
quant_block_size
:
int
,
scale_fmt
:
str
|
None
,
topk_tokens
:
int
,
head_dim
:
int
,
max_model_len
:
int
,
total_seq_lens
:
int
,
topk_indices_buffer
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
return
topk_indices_buffer
direct_register_custom_op
(
op_name
=
"sparse_attn_indexer"
,
op_func
=
sparse_attn_indexer
,
mutates_args
=
[
"topk_indices_buffer"
],
fake_impl
=
sparse_attn_indexer_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
@
CustomOp
.
register
(
"sparse_attn_indexer"
)
class
SparseAttnIndexer
(
CustomOp
):
"""Sparse Attention Indexer Custom Op Layer. This layer is extracted as a
separate custom op since it involves heavy custom kernels like `mqa_logits`,
`paged_mqa_logits` and `top_k_per_row`, etc. Those kernels maybe requires
specific memory layout or implementation for different hardware backends to
achieve optimal performance.
For now, the default native path will use CUDA backend path. Other platform
may requires add the corresponding Custom Op name `sparse_attn_indexer` to
`custom_ops` in `CompilationConfig` to enable the platform specific path.
"""
def
__init__
(
self
,
k_cache
,
quant_block_size
:
int
,
scale_fmt
:
str
,
topk_tokens
:
int
,
head_dim
:
int
,
max_model_len
:
int
,
max_total_seq_len
:
int
,
topk_indices_buffer
:
torch
.
Tensor
,
):
super
().
__init__
()
self
.
k_cache
=
k_cache
self
.
quant_block_size
=
quant_block_size
self
.
scale_fmt
=
scale_fmt
self
.
topk_tokens
=
topk_tokens
self
.
head_dim
=
head_dim
self
.
max_model_len
=
max_model_len
self
.
max_total_seq_len
=
max_total_seq_len
self
.
topk_indices_buffer
=
topk_indices_buffer
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
):
if
current_platform
.
is_cuda
():
return
self
.
forward_cuda
(
hidden_states
,
q_fp8
,
k
,
weights
)
elif
current_platform
.
is_rocm
():
return
self
.
forward_hip
(
hidden_states
,
q_fp8
,
k
,
weights
)
else
:
raise
NotImplementedError
(
"SparseAttnIndexer native forward is only implemented for "
"CUDA and ROCm platform."
)
def
forward_cuda
(
self
,
hidden_states
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
):
return
torch
.
ops
.
vllm
.
sparse_attn_indexer
(
hidden_states
,
self
.
k_cache
.
prefix
,
self
.
k_cache
.
kv_cache
[
0
],
q_fp8
,
k
,
weights
,
self
.
quant_block_size
,
self
.
scale_fmt
,
self
.
topk_tokens
,
self
.
head_dim
,
self
.
max_model_len
,
self
.
max_total_seq_len
,
self
.
topk_indices_buffer
,
)
def
forward_hip
(
self
,
hidden_states
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
):
if
rocm_aiter_ops
.
is_enabled
():
return
torch
.
ops
.
vllm
.
rocm_aiter_sparse_attn_indexer
(
hidden_states
,
self
.
k_cache
.
prefix
,
self
.
k_cache
.
kv_cache
[
0
],
q_fp8
,
k
,
weights
,
self
.
quant_block_size
,
self
.
scale_fmt
,
self
.
topk_tokens
,
self
.
head_dim
,
self
.
max_model_len
,
self
.
max_total_seq_len
,
self
.
topk_indices_buffer
,
)
else
:
raise
RuntimeError
(
"Sparse attention indexer ROCm custom op requires ROCm "
"Aiter ops to be enabled."
)
vllm/model_executor/models/deepseek_v2.py
View file @
6c20e89c
...
...
@@ -43,7 +43,6 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
,
)
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
...
...
@@ -63,6 +62,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8
,
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sparse_attn_indexer
import
SparseAttnIndexer
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
...
...
@@ -74,16 +74,11 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.models.utils
import
sequence_parallel_chunk
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.deep_gemm
import
fp8_mqa_logits
,
fp8_paged_mqa_logits
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.v1.attention.backend
import
AttentionBackend
from
vllm.v1.attention.backends.mla.indexer
import
(
DeepseekV32IndexerBackend
,
DeepseekV32IndexerMetadata
,
)
from
vllm.v1.attention.ops.common
import
pack_seq_triton
,
unpack_seq_triton
from
vllm.v1.kv_cache_interface
import
KVCacheSpec
,
MLAAttentionSpec
from
vllm.v1.worker.workspace
import
current_workspace_manager
from
.interfaces
import
MixtureOfExperts
,
SupportsEagle
,
SupportsLoRA
,
SupportsPP
from
.utils
import
(
...
...
@@ -94,11 +89,6 @@ from .utils import (
maybe_prefix
,
)
if
current_platform
.
is_cuda_alike
():
from
vllm
import
_custom_ops
as
ops
elif
current_platform
.
is_xpu
():
from
vllm._ipex_ops
import
ipex_ops
as
ops
logger
=
init_logger
(
__name__
)
...
...
@@ -599,213 +589,6 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
return
DeepseekV32IndexerBackend
def
sparse_attn_indexer
(
hidden_states
:
torch
.
Tensor
,
k_cache_prefix
:
str
,
kv_cache
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
quant_block_size
:
int
,
scale_fmt
:
str
|
None
,
topk_tokens
:
int
,
head_dim
:
int
,
max_model_len
:
int
,
total_seq_lens
:
int
,
topk_indices_buffer
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
# careful! this will be None in dummy run
attn_metadata
=
get_forward_context
().
attn_metadata
fp8_dtype
=
current_platform
.
fp8_dtype
()
# assert isinstance(attn_metadata, dict)
if
not
isinstance
(
attn_metadata
,
dict
):
# Reserve workspace for indexer during profiling run
current_workspace_manager
().
get_simultaneous
(
((
total_seq_lens
,
head_dim
),
torch
.
float8_e4m3fn
),
((
total_seq_lens
,
4
),
torch
.
uint8
),
)
return
sparse_attn_indexer_fake
(
hidden_states
,
k_cache_prefix
,
kv_cache
,
q_fp8
,
k
,
weights
,
quant_block_size
,
scale_fmt
,
topk_tokens
,
head_dim
,
max_model_len
,
total_seq_lens
,
topk_indices_buffer
,
)
attn_metadata
=
attn_metadata
[
k_cache_prefix
]
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
slot_mapping
=
attn_metadata
.
slot_mapping
has_decode
=
attn_metadata
.
num_decodes
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
ops
.
indexer_k_quant_and_cache
(
k
,
kv_cache
,
slot_mapping
,
quant_block_size
,
scale_fmt
,
)
topk_indices_buffer
[:
hidden_states
.
shape
[
0
]]
=
-
1
if
has_prefill
:
prefill_metadata
=
attn_metadata
.
prefill
# Get the full shared workspace buffers once (will allocate on first use)
workspace_manager
=
current_workspace_manager
()
k_fp8_full
,
k_scale_full
=
workspace_manager
.
get_simultaneous
(
((
total_seq_lens
,
head_dim
),
fp8_dtype
),
((
total_seq_lens
,
4
),
torch
.
uint8
),
)
for
chunk
in
prefill_metadata
.
chunks
:
k_fp8
=
k_fp8_full
[:
chunk
.
total_seq_lens
]
k_scale
=
k_scale_full
[:
chunk
.
total_seq_lens
]
ops
.
cp_gather_indexer_k_quant_cache
(
kv_cache
,
k_fp8
,
k_scale
,
chunk
.
block_table
,
chunk
.
cu_seq_lens
,
)
fp8_mqa_logits_func
=
fp8_mqa_logits
if
current_platform
.
is_rocm
():
from
vllm.v1.attention.ops.rocm_aiter_mla_sparse
import
(
rocm_fp8_mqa_logits
,
)
fp8_mqa_logits_func
=
rocm_fp8_mqa_logits
logits
=
fp8_mqa_logits_func
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
(
k_fp8
,
k_scale
.
view
(
torch
.
float32
).
flatten
()),
weights
[
chunk
.
token_start
:
chunk
.
token_end
],
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
)
num_rows
=
logits
.
shape
[
0
]
topk_indices
=
topk_indices_buffer
[
chunk
.
token_start
:
chunk
.
token_end
,
:
topk_tokens
]
torch
.
ops
.
_C
.
top_k_per_row_prefill
(
logits
,
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
topk_indices
,
num_rows
,
logits
.
stride
(
0
),
logits
.
stride
(
1
),
topk_tokens
,
)
if
has_decode
:
decode_metadata
=
attn_metadata
.
decode
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
# we only have [num_block, block_size, head_dim],
kv_cache
=
kv_cache
.
unsqueeze
(
-
2
)
decode_lens
=
decode_metadata
.
decode_lens
if
decode_metadata
.
requires_padding
:
# pad in edge case where we have short chunked prefill length <
# decode_threshold since we unstrictly split
# prefill and decode by decode_threshold
# (currently set to 1 + speculative tokens)
# [num_decode_tokens, n_head, head_dim] -> [bs, 1+next_n, n_head, head_dim]
padded_q_fp8_decode_tokens
=
pack_seq_triton
(
q_fp8
[:
num_decode_tokens
],
decode_lens
)
# [num_decode_tokens, n_head] -> [bs, 1+next_n, n_head]
padded_weights
=
pack_seq_triton
(
weights
[:
num_decode_tokens
],
decode_lens
)
# [bs, 1+next_n, n_head] -> [bs * next_n, n_head]
padded_weights
=
padded_weights
.
flatten
(
0
,
1
)
else
:
padded_q_fp8_decode_tokens
=
q_fp8
[:
num_decode_tokens
].
reshape
(
decode_lens
.
shape
[
0
],
-
1
,
*
q_fp8
.
shape
[
1
:]
)
padded_weights
=
weights
# TODO: move and optimize below logic with triton kernels
batch_size
=
padded_q_fp8_decode_tokens
.
shape
[
0
]
next_n
=
padded_q_fp8_decode_tokens
.
shape
[
1
]
assert
batch_size
==
decode_metadata
.
seq_lens
.
shape
[
0
]
num_padded_tokens
=
batch_size
*
next_n
fp8_paged_mqa_logits_func
=
fp8_paged_mqa_logits
if
current_platform
.
is_rocm
():
from
vllm.v1.attention.ops.rocm_aiter_mla_sparse
import
(
rocm_fp8_paged_mqa_logits
,
)
fp8_paged_mqa_logits_func
=
rocm_fp8_paged_mqa_logits
logits
=
fp8_paged_mqa_logits_func
(
padded_q_fp8_decode_tokens
,
kv_cache
,
padded_weights
[:
num_padded_tokens
],
decode_metadata
.
seq_lens
,
decode_metadata
.
block_table
,
decode_metadata
.
schedule_metadata
,
max_model_len
=
max_model_len
,
)
num_rows
=
logits
.
shape
[
0
]
topk_indices
=
topk_indices_buffer
[:
num_padded_tokens
,
:
topk_tokens
]
torch
.
ops
.
_C
.
top_k_per_row_decode
(
logits
,
next_n
,
decode_metadata
.
seq_lens
,
topk_indices
,
num_rows
,
logits
.
stride
(
0
),
logits
.
stride
(
1
),
topk_tokens
,
)
if
decode_metadata
.
requires_padding
:
# if padded, we need to unpack
# the topk indices removing padded tokens
topk_indices
=
unpack_seq_triton
(
topk_indices
.
reshape
(
batch_size
,
-
1
,
topk_indices
.
shape
[
-
1
]),
decode_lens
,
)
topk_indices_buffer
[:
num_decode_tokens
,
:
topk_indices
.
shape
[
-
1
]]
=
(
topk_indices
)
return
topk_indices_buffer
def
sparse_attn_indexer_fake
(
hidden_states
:
torch
.
Tensor
,
k_cache_prefix
:
str
,
kv_cache
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
quant_block_size
:
int
,
scale_fmt
:
str
|
None
,
topk_tokens
:
int
,
head_dim
:
int
,
max_model_len
:
int
,
total_seq_lens
:
int
,
topk_indices_buffer
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
return
topk_indices_buffer
direct_register_custom_op
(
op_name
=
"sparse_attn_indexer"
,
op_func
=
sparse_attn_indexer
,
mutates_args
=
[
"topk_indices_buffer"
],
fake_impl
=
sparse_attn_indexer_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
class
Indexer
(
nn
.
Module
):
def
__init__
(
self
,
...
...
@@ -870,6 +653,16 @@ class Indexer(nn.Module):
from
vllm.v1.attention.backends.mla.indexer
import
get_max_prefill_buffer_size
self
.
max_total_seq_len
=
get_max_prefill_buffer_size
(
vllm_config
)
self
.
indexer_op
=
SparseAttnIndexer
(
self
.
k_cache
,
self
.
quant_block_size
,
self
.
scale_fmt
,
self
.
topk_tokens
,
self
.
head_dim
,
self
.
max_model_len
,
self
.
max_total_seq_len
,
self
.
topk_indices_buffer
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
qr
:
torch
.
Tensor
,
positions
,
rotary_emb
...
...
@@ -892,6 +685,8 @@ class Indexer(nn.Module):
q_pe
=
q_pe
.
reshape
(
-
1
,
self
.
n_head
,
self
.
rope_dim
)
k_pe
=
k_pe
.
reshape
(
-
1
,
1
,
self
.
rope_dim
)
# `rotary_emb` is shape-preserving; `q_pe` is already
# [num_tokens, n_head, rope_dim].
q
=
torch
.
cat
([
q_pe
,
q_nope
],
dim
=-
1
)
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k
=
torch
.
cat
([
k_pe
.
squeeze
(
-
2
),
k_nope
],
dim
=-
1
)
...
...
@@ -913,21 +708,7 @@ class Indexer(nn.Module):
)
weights
=
weights
.
squeeze
(
-
1
)
return
torch
.
ops
.
vllm
.
sparse_attn_indexer
(
hidden_states
,
self
.
k_cache
.
prefix
,
self
.
k_cache
.
kv_cache
[
0
],
q_fp8
,
k
,
weights
,
self
.
quant_block_size
,
self
.
scale_fmt
,
self
.
topk_tokens
,
self
.
head_dim
,
self
.
max_model_len
,
self
.
max_total_seq_len
,
self
.
topk_indices_buffer
,
)
return
self
.
indexer_op
(
hidden_states
,
q_fp8
,
k
,
weights
)
class
DeepseekV2MLAAttention
(
nn
.
Module
):
...
...
vllm/platforms/rocm.py
View file @
6c20e89c
...
...
@@ -480,6 +480,9 @@ class RocmPlatform(Platform):
):
compilation_config
.
custom_ops
.
append
(
"+grouped_topk"
)
# Default dispatch to rocm's sparse_attn_indexer implementation
compilation_config
.
custom_ops
.
append
(
"+sparse_attn_indexer"
)
@
classmethod
def
verify_model_arch
(
cls
,
model_arch
:
str
)
->
None
:
if
model_arch
in
_ROCM_UNSUPPORTED_MODELS
:
...
...
vllm/v1/attention/backends/mla/indexer.py
View file @
6c20e89c
...
...
@@ -63,6 +63,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
cu_seqlen_ks
:
torch
.
Tensor
cu_seqlen_ke
:
torch
.
Tensor
cu_seq_lens
:
torch
.
Tensor
token_to_seq
:
torch
.
Tensor
total_seq_lens
:
int
token_start
:
int
token_end
:
int
...
...
@@ -234,6 +235,10 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
token_start
=
query_start_loc_cpu
[
reqs_start
].
item
()
token_end
=
query_start_loc_cpu
[
reqs_end
].
item
()
total_seq_lens
=
seq_lens_cpu
[
reqs_start
:
reqs_end
].
sum
()
seq_idx
=
torch
.
arange
(
0
,
reqs_end
-
reqs_start
,
dtype
=
torch
.
int32
)
token_to_seq
=
torch
.
repeat_interleave
(
seq_idx
,
seq_lens_cpu
[
reqs_start
:
reqs_end
]
).
to
(
self
.
device
)
assert
total_seq_lens
<=
self
.
max_prefill_buffer_size
cu_seq_lens
=
(
torch
.
cat
(
...
...
@@ -249,6 +254,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
cu_seqlen_ks
=
cu_seqlen_ks
,
cu_seqlen_ke
=
cu_seqlen_ke
,
cu_seq_lens
=
cu_seq_lens
,
token_to_seq
=
token_to_seq
,
total_seq_lens
=
total_seq_lens
,
block_table
=
block_table
[
reqs_start
:
reqs_end
],
token_start
=
token_start
,
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
View file @
6c20e89c
...
...
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBaseImpl
,
get_mla_dims
,
)
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionCGSupport
,
...
...
@@ -33,6 +34,48 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
@
triton
.
jit
def
fetch_id_to_ragged_kernel
(
in_tensor_ptr
,
# [num_seq, topk]
cumsum_ptr
,
# [num_seq + 1]
out_tensor_ptr
,
# [max_num_seq * topk]
in_tensor_ptr_stride
,
TOPK
:
tl
.
constexpr
,
TOKEN_NUM
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
seq_id
=
tl
.
program_id
(
0
)
block_id
=
tl
.
program_id
(
1
)
offset
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
token_start
=
tl
.
load
(
cumsum_ptr
+
seq_id
)
token_end
=
tl
.
load
(
cumsum_ptr
+
seq_id
+
1
)
token_num
=
token_end
-
token_start
row_offset
=
block_id
*
BLOCK_SIZE
if
row_offset
>=
token_num
:
return
in_tensor_offset
=
seq_id
*
in_tensor_ptr_stride
+
row_offset
+
offset
in_tensor_mask
=
(
row_offset
+
offset
)
<
TOPK
in_tensor_val
=
tl
.
load
(
in_tensor_ptr
+
in_tensor_offset
,
mask
=
in_tensor_mask
)
out_tensor_offset
=
token_start
+
row_offset
+
offset
out_tensor_mask
=
(
out_tensor_offset
<
token_end
)
&
in_tensor_mask
tl
.
store
(
out_tensor_ptr
+
out_tensor_offset
,
in_tensor_val
,
mask
=
out_tensor_mask
)
def
fetch_id_to_ragged_triton
(
in_tensor
:
torch
.
Tensor
,
cumsum
:
torch
.
Tensor
,
out_tensor
:
torch
.
Tensor
,
topk
):
num_tokens
=
in_tensor
.
size
(
0
)
block_size
=
64
num_block_per_row
=
triton
.
cdiv
(
topk
,
block_size
)
grid
=
(
num_tokens
,
num_block_per_row
,
)
fetch_id_to_ragged_kernel
[
grid
](
in_tensor
,
cumsum
,
out_tensor
,
in_tensor
.
stride
(
0
),
topk
,
num_tokens
,
block_size
)
class
ROCMAiterMLASparseBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
...
...
@@ -83,6 +126,13 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata):
block_table
:
torch
.
Tensor
req_id_per_token
:
torch
.
Tensor
qo_indptr
:
torch
.
Tensor
paged_kv_last_page_len
:
torch
.
Tensor
paged_kv_indices
:
torch
.
Tensor
paged_kv_indptr
:
torch
.
Tensor
paged_kv_indptr_rest
:
torch
.
Tensor
block_size
:
int
=
1
topk_tokens
:
int
=
2048
...
...
@@ -91,7 +141,7 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata):
class
ROCMAiterMLASparseMetadataBuilder
(
AttentionMetadataBuilder
[
ROCMAiterMLASparseMetadata
]
):
cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
AttentionCGSupport
.
NEVER
_
cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
AttentionCGSupport
.
NEVER
def
__init__
(
self
,
...
...
@@ -104,6 +154,7 @@ class ROCMAiterMLASparseMetadataBuilder(
self
.
model_config
=
vllm_config
.
model_config
parallel_config
=
vllm_config
.
parallel_config
self
.
device
=
device
max_num_batched_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
self
.
num_heads
=
self
.
model_config
.
get_num_attention_heads
(
parallel_config
)
self
.
mla_dims
=
get_mla_dims
(
self
.
model_config
)
...
...
@@ -124,6 +175,23 @@ class ROCMAiterMLASparseMetadataBuilder(
dtype
=
torch
.
int32
,
device
=
device
,
)
self
.
qo_indptr
=
torch
.
arange
(
0
,
max_num_batched_tokens
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
paged_kv_last_page_len
=
torch
.
ones
(
max_num_batched_tokens
,
dtype
=
torch
.
int32
,
device
=
device
)
# These two needs to be calculated in runtime,
# but we still needs to prepare the buffer
self
.
paged_kv_indices
=
torch
.
zeros
(
[
max_num_batched_tokens
*
self
.
topk_tokens
],
dtype
=
torch
.
int32
,
device
=
device
,
)
self
.
paged_kv_indptr
=
torch
.
zeros
(
[
max_num_batched_tokens
+
1
],
dtype
=
torch
.
int32
,
device
=
device
)
def
build
(
self
,
...
...
@@ -142,7 +210,15 @@ class ROCMAiterMLASparseMetadataBuilder(
self
.
req_id_per_token_buffer
[:
req_id_per_token
.
shape
[
0
]].
copy_
(
torch
.
from_numpy
(
req_id_per_token
),
non_blocking
=
True
)
self
.
paged_kv_indices
.
fill_
(
0
)
self
.
paged_kv_indptr
.
fill_
(
0
)
req_id_per_token
=
self
.
req_id_per_token_buffer
[:
num_tokens
]
qo_indptr
=
self
.
qo_indptr
[:
num_tokens
+
1
]
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
[:
num_tokens
]
paged_kv_indices
=
self
.
paged_kv_indices
[:
num_tokens
*
self
.
topk_tokens
]
paged_kv_indptr
=
self
.
paged_kv_indptr
[:
num_tokens
+
1
]
paged_kv_indptr_rest
=
self
.
paged_kv_indptr
[
num_tokens
+
1
:]
metadata
=
ROCMAiterMLASparseMetadata
(
num_reqs
=
common_attn_metadata
.
num_reqs
,
...
...
@@ -155,6 +231,11 @@ class ROCMAiterMLASparseMetadataBuilder(
req_id_per_token
=
req_id_per_token
,
block_size
=
self
.
kv_cache_spec
.
block_size
,
topk_tokens
=
self
.
topk_tokens
,
qo_indptr
=
qo_indptr
,
paged_kv_last_page_len
=
paged_kv_last_page_len
,
paged_kv_indices
=
paged_kv_indices
,
paged_kv_indptr
=
paged_kv_indptr
,
paged_kv_indptr_rest
=
paged_kv_indptr_rest
,
)
return
metadata
...
...
@@ -226,20 +307,39 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
def
_forward_bf16_kv
(
self
,
q
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
topk_indices
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
# [sq, heads, d_qk]
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
# [blocks, heads, d_qk]
topk_indices
:
torch
.
Tensor
,
# [sq, topk]
attn_metadata
:
ROCMAiterMLASparseMetadata
,
)
->
torch
.
Tensor
:
num_tokens
=
q
.
shape
[
0
]
kv_c_and_k_pe_cache
=
kv_c_and_k_pe_cache
.
view
(
-
1
,
1
,
kv_c_and_k_pe_cache
.
shape
[
-
1
]
output
=
torch
.
empty
(
[
num_tokens
,
self
.
num_heads
,
self
.
kv_lora_rank
],
dtype
=
q
.
dtype
,
device
=
q
.
device
,
)
seq_len
=
(
topk_indices
!=
-
1
).
sum
(
dim
=-
1
)
torch
.
cumsum
(
seq_len
,
dim
=
0
,
out
=
attn_metadata
.
paged_kv_indptr
[
1
:])
attn_metadata
.
paged_kv_indptr_rest
.
fill_
(
attn_metadata
.
paged_kv_indptr
[
-
1
])
fetch_id_to_ragged_triton
(
topk_indices
,
attn_metadata
.
paged_kv_indptr
,
attn_metadata
.
paged_kv_indices
,
attn_metadata
.
topk_tokens
,
)
rocm_aiter_ops
.
mla_decode_fwd
(
q
,
kv_c_and_k_pe_cache
,
output
,
self
.
scale
,
attn_metadata
.
qo_indptr
,
1
,
attn_metadata
.
paged_kv_indptr
,
attn_metadata
.
paged_kv_indices
,
attn_metadata
.
paged_kv_last_page_len
,
)
topk_indices
=
topk_indices
.
view
(
num_tokens
,
1
,
-
1
)
output
=
reference_mla_sparse_prefill
(
q
,
kv_c_and_k_pe_cache
,
topk_indices
,
self
.
softmax_scale
,
512
)[
0
]
return
output
[:,
:
self
.
num_heads
,
:]
def
forward
(
...
...
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
View file @
6c20e89c
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment