Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6c20e89c
Unverified
Commit
6c20e89c
authored
Jan 21, 2026
by
Pleaplusone
Committed by
GitHub
Jan 21, 2026
Browse files
[ROCm][Deepseekv3.2] Refactor Sparse Indexer as CustomOp (#29287)
Signed-off-by:
ganyi
<
ygan@amd.com
>
parent
85f55c94
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
982 additions
and
323 deletions
+982
-323
vllm/_aiter_ops.py
vllm/_aiter_ops.py
+12
-0
vllm/config/compilation.py
vllm/config/compilation.py
+1
-0
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+318
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+14
-233
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+3
-0
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+6
-0
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
+110
-10
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
+518
-80
No files found.
vllm/_aiter_ops.py
View file @
6c20e89c
...
@@ -9,6 +9,10 @@ from torch._ops import OpOverload
...
@@ -9,6 +9,10 @@ from torch._ops import OpOverload
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
direct_register_custom_op
,
is_torch_equal_or_newer
from
vllm.utils.torch_utils
import
direct_register_custom_op
,
is_torch_equal_or_newer
from
vllm.v1.attention.ops.rocm_aiter_mla_sparse
import
(
rocm_aiter_sparse_attn_indexer
,
rocm_aiter_sparse_attn_indexer_fake
,
)
_FP8_DTYPE
=
current_platform
.
fp8_dtype
()
_FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
@@ -1091,6 +1095,14 @@ class rocm_aiter_ops:
...
@@ -1091,6 +1095,14 @@ class rocm_aiter_ops:
dispatch_key
=
current_platform
.
dispatch_key
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_sparse_attn_indexer"
,
op_func
=
rocm_aiter_sparse_attn_indexer
,
mutates_args
=
[
"topk_indices_buffer"
],
fake_impl
=
rocm_aiter_sparse_attn_indexer_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
_OPS_REGISTERED
=
True
_OPS_REGISTERED
=
True
@
staticmethod
@
staticmethod
...
...
vllm/config/compilation.py
View file @
6c20e89c
...
@@ -611,6 +611,7 @@ class CompilationConfig:
...
@@ -611,6 +611,7 @@ class CompilationConfig:
"vllm::gdn_attention_core"
,
"vllm::gdn_attention_core"
,
"vllm::kda_attention"
,
"vllm::kda_attention"
,
"vllm::sparse_attn_indexer"
,
"vllm::sparse_attn_indexer"
,
"vllm::rocm_aiter_sparse_attn_indexer"
,
]
]
def
compute_hash
(
self
)
->
str
:
def
compute_hash
(
self
)
->
str
:
...
...
vllm/model_executor/layers/sparse_attn_indexer.py
0 → 100644
View file @
6c20e89c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom Sparse Attention Indexer layers."""
import
torch
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.platforms
import
current_platform
from
vllm.utils.deep_gemm
import
fp8_mqa_logits
,
fp8_paged_mqa_logits
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.v1.attention.backends.mla.indexer
import
(
DeepseekV32IndexerMetadata
,
)
from
vllm.v1.attention.ops.common
import
pack_seq_triton
,
unpack_seq_triton
from
vllm.v1.worker.workspace
import
current_workspace_manager
if
current_platform
.
is_cuda_alike
():
from
vllm
import
_custom_ops
as
ops
elif
current_platform
.
is_xpu
():
from
vllm._ipex_ops
import
ipex_ops
as
ops
logger
=
init_logger
(
__name__
)
def sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor,
) -> torch.Tensor:
    """Fill ``topk_indices_buffer`` with the top-k KV indices per query token.

    The new keys ``k`` are first quantized into ``kv_cache``. Prefill chunks
    then gather their K values from the cache, score them with
    ``fp8_mqa_logits`` and run ``top_k_per_row_prefill``. Decode tokens are
    scored directly against the paged cache with ``fp8_paged_mqa_logits``
    and reduced with ``top_k_per_row_decode``.

    Args:
        hidden_states: Only ``shape[0]`` is read; that many leading rows of
            ``topk_indices_buffer`` are reset to ``-1`` before being filled.
        k_cache_prefix: Key used to look up this layer's metadata in the
            forward context's ``attn_metadata`` dict.
        kv_cache: Indexer K cache; written by ``indexer_k_quant_and_cache``
            and read back by the gather / paged-logits kernels.
        q_fp8: Quantized queries, indexed by token along dim 0
            (presumably ``[num_tokens, n_head, head_dim]`` -- TODO confirm).
        k: Keys of the current step to quantize and store in ``kv_cache``.
        weights: Per-token weights passed to the logits kernels
            (presumably ``[num_tokens, n_head]`` -- TODO confirm).
        quant_block_size: Forwarded to ``indexer_k_quant_and_cache``.
        scale_fmt: Forwarded to ``indexer_k_quant_and_cache``.
        topk_tokens: Number of indices kept per query row.
        head_dim: Width of the gathered-K workspace buffer.
        max_model_len: Forwarded to ``fp8_paged_mqa_logits``.
        total_seq_lens: Row count of the gathered-K workspace buffers.
        topk_indices_buffer: Output buffer, mutated in place.

    Returns:
        ``topk_indices_buffer`` (the same tensor object that was passed in).
    """
    # careful! this will be None in dummy run
    attn_metadata = get_forward_context().attn_metadata
    fp8_dtype = current_platform.fp8_dtype()
    if not isinstance(attn_metadata, dict):
        # Reserve workspace for indexer during profiling run. Use the same
        # platform FP8 dtype as the real prefill path below so both requests
        # describe the same buffers.
        current_workspace_manager().get_simultaneous(
            ((total_seq_lens, head_dim), fp8_dtype),
            ((total_seq_lens, 4), torch.uint8),
        )
        return sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q_fp8,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )
    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    # Quantize the new keys and store them into the indexer cache.
    ops.indexer_k_quant_and_cache(
        k,
        kv_cache,
        slot_mapping,
        quant_block_size,
        scale_fmt,
    )

    # -1 marks "no index"; rows/columns not overwritten below keep it.
    topk_indices_buffer[: hidden_states.shape[0]] = -1
    if has_prefill:
        prefill_metadata = attn_metadata.prefill

        # Get the full shared workspace buffers once (will allocate on first use)
        workspace_manager = current_workspace_manager()
        k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
            ((total_seq_lens, head_dim), fp8_dtype),
            ((total_seq_lens, 4), torch.uint8),
        )
        for chunk in prefill_metadata.chunks:
            k_fp8 = k_fp8_full[: chunk.total_seq_lens]
            k_scale = k_scale_full[: chunk.total_seq_lens]
            ops.cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_fp8,
                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
            )

            logits = fp8_mqa_logits(
                q_fp8[chunk.token_start : chunk.token_end],
                # The 4 uint8 bytes per row are reinterpreted as one float32.
                (k_fp8, k_scale.view(torch.float32).flatten()),
                weights[chunk.token_start : chunk.token_end],
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )
            num_rows = logits.shape[0]

            # View into the output buffer; the kernel writes it in place.
            topk_indices = topk_indices_buffer[
                chunk.token_start : chunk.token_end, :topk_tokens
            ]
            torch.ops._C.top_k_per_row_prefill(
                logits,
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
                topk_indices,
                num_rows,
                logits.stride(0),
                logits.stride(1),
                topk_tokens,
            )

    if has_decode:
        decode_metadata = attn_metadata.decode
        # kv_cache size requirement [num_block, block_size, n_head, head_dim],
        # we only have [num_block, block_size, head_dim],
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # pad in edge case where we have short chunked prefill length <
            # decode_threshold since we unstrictly split
            # prefill and decode by decode_threshold
            # (currently set to 1 + speculative tokens)
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens
            )
            # The weights must be packed exactly like the queries, otherwise
            # row i of the padded queries is scored with the weights of a
            # different token. Pack per request, then flatten the
            # (request, position) dims back into a single row dim.
            padded_weights = pack_seq_triton(
                weights[:num_decode_tokens], decode_lens
            ).flatten(0, 1)
        else:
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:]
            )
            padded_weights = weights
        # TODO: move and optimize below logic with triton kernels
        batch_size = padded_q_fp8_decode_tokens.shape[0]
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n

        logits = fp8_paged_mqa_logits(
            padded_q_fp8_decode_tokens,
            kv_cache,
            padded_weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            decode_metadata.schedule_metadata,
            max_model_len=max_model_len,
        )

        num_rows = logits.shape[0]

        topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
        torch.ops._C.top_k_per_row_decode(
            logits,
            next_n,
            decode_metadata.seq_lens,
            topk_indices,
            num_rows,
            logits.stride(0),
            logits.stride(1),
            topk_tokens,
        )

        if decode_metadata.requires_padding:
            # if padded, we need to unpack
            # the topk indices removing padded tokens
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens,
            )
            topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
                topk_indices
            )

    return topk_indices_buffer
def
sparse_attn_indexer_fake
(
hidden_states
:
torch
.
Tensor
,
k_cache_prefix
:
str
,
kv_cache
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
quant_block_size
:
int
,
scale_fmt
:
str
|
None
,
topk_tokens
:
int
,
head_dim
:
int
,
max_model_len
:
int
,
total_seq_lens
:
int
,
topk_indices_buffer
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
return
topk_indices_buffer
# Register the indexer as a custom op so callers can reach it through
# `torch.ops.vllm.sparse_attn_indexer`. `topk_indices_buffer` is declared as
# mutated because the implementation writes its results into that tensor.
direct_register_custom_op(
    op_name="sparse_attn_indexer",
    op_func=sparse_attn_indexer,
    mutates_args=["topk_indices_buffer"],
    fake_impl=sparse_attn_indexer_fake,
    dispatch_key=current_platform.dispatch_key,
)
@CustomOp.register("sparse_attn_indexer")
class SparseAttnIndexer(CustomOp):
    """Sparse attention indexer wrapped as a ``CustomOp`` layer.

    The indexer is kept as its own custom op because it relies on heavy
    custom kernels (``mqa_logits``, ``paged_mqa_logits``, ``top_k_per_row``,
    ...). Those kernels may require a specific memory layout or a dedicated
    implementation on each hardware backend to reach optimal performance.

    ``forward_native`` dispatches on the current platform: CUDA goes to
    ``forward_cuda`` and ROCm goes to ``forward_hip``. Other platforms may
    need to add the custom op name ``sparse_attn_indexer`` to ``custom_ops``
    in ``CompilationConfig`` to enable their platform-specific path.
    """

    def __init__(
        self,
        k_cache,
        quant_block_size: int,
        scale_fmt: str,
        topk_tokens: int,
        head_dim: int,
        max_model_len: int,
        max_total_seq_len: int,
        topk_indices_buffer: torch.Tensor,
    ):
        """Store the static configuration forwarded to the indexer op.

        ``k_cache`` must expose ``prefix`` and ``kv_cache`` attributes; both
        are read on every forward call.
        """
        super().__init__()
        self.k_cache = k_cache
        self.quant_block_size = quant_block_size
        self.scale_fmt = scale_fmt
        self.topk_tokens = topk_tokens
        self.head_dim = head_dim
        self.max_model_len = max_model_len
        self.max_total_seq_len = max_total_seq_len
        self.topk_indices_buffer = topk_indices_buffer

    def _indexer_op_args(
        self,
        hidden_states: torch.Tensor,
        q_fp8: torch.Tensor,
        k: torch.Tensor,
        weights: torch.Tensor,
    ) -> tuple:
        """Build the positional argument tuple shared by both indexer ops."""
        cache_layer = self.k_cache
        return (
            hidden_states,
            cache_layer.prefix,
            cache_layer.kv_cache[0],
            q_fp8,
            k,
            weights,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
        )

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        q_fp8: torch.Tensor,
        k: torch.Tensor,
        weights: torch.Tensor,
    ):
        """Route to the CUDA or ROCm path based on the current platform.

        Raises:
            NotImplementedError: On any platform other than CUDA or ROCm.
        """
        if current_platform.is_cuda():
            return self.forward_cuda(hidden_states, q_fp8, k, weights)
        if current_platform.is_rocm():
            return self.forward_hip(hidden_states, q_fp8, k, weights)
        raise NotImplementedError(
            "SparseAttnIndexer native forward is only implemented for "
            "CUDA and ROCm platform."
        )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        q_fp8: torch.Tensor,
        k: torch.Tensor,
        weights: torch.Tensor,
    ):
        """Invoke the ``vllm::sparse_attn_indexer`` custom op."""
        return torch.ops.vllm.sparse_attn_indexer(
            *self._indexer_op_args(hidden_states, q_fp8, k, weights)
        )

    def forward_hip(
        self,
        hidden_states: torch.Tensor,
        q_fp8: torch.Tensor,
        k: torch.Tensor,
        weights: torch.Tensor,
    ):
        """Invoke the ``vllm::rocm_aiter_sparse_attn_indexer`` custom op.

        Raises:
            RuntimeError: If ROCm Aiter ops are not enabled.
        """
        if not rocm_aiter_ops.is_enabled():
            raise RuntimeError(
                "Sparse attention indexer ROCm custom op requires ROCm "
                "Aiter ops to be enabled."
            )
        return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
            *self._indexer_op_args(hidden_states, q_fp8, k, weights)
        )
vllm/model_executor/models/deepseek_v2.py
View file @
6c20e89c
...
@@ -43,7 +43,6 @@ from vllm.distributed import (
...
@@ -43,7 +43,6 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_gather
,
)
)
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
...
@@ -63,6 +62,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
...
@@ -63,6 +62,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8
,
per_token_group_quant_fp8
,
)
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sparse_attn_indexer
import
SparseAttnIndexer
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
ParallelLMHead
,
VocabParallelEmbedding
,
VocabParallelEmbedding
,
...
@@ -74,16 +74,11 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -74,16 +74,11 @@ from vllm.model_executor.model_loader.weight_utils import (
from
vllm.model_executor.models.utils
import
sequence_parallel_chunk
from
vllm.model_executor.models.utils
import
sequence_parallel_chunk
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.deep_gemm
import
fp8_mqa_logits
,
fp8_paged_mqa_logits
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.v1.attention.backend
import
AttentionBackend
from
vllm.v1.attention.backend
import
AttentionBackend
from
vllm.v1.attention.backends.mla.indexer
import
(
from
vllm.v1.attention.backends.mla.indexer
import
(
DeepseekV32IndexerBackend
,
DeepseekV32IndexerBackend
,
DeepseekV32IndexerMetadata
,
)
)
from
vllm.v1.attention.ops.common
import
pack_seq_triton
,
unpack_seq_triton
from
vllm.v1.kv_cache_interface
import
KVCacheSpec
,
MLAAttentionSpec
from
vllm.v1.kv_cache_interface
import
KVCacheSpec
,
MLAAttentionSpec
from
vllm.v1.worker.workspace
import
current_workspace_manager
from
.interfaces
import
MixtureOfExperts
,
SupportsEagle
,
SupportsLoRA
,
SupportsPP
from
.interfaces
import
MixtureOfExperts
,
SupportsEagle
,
SupportsLoRA
,
SupportsPP
from
.utils
import
(
from
.utils
import
(
...
@@ -94,11 +89,6 @@ from .utils import (
...
@@ -94,11 +89,6 @@ from .utils import (
maybe_prefix
,
maybe_prefix
,
)
)
if
current_platform
.
is_cuda_alike
():
from
vllm
import
_custom_ops
as
ops
elif
current_platform
.
is_xpu
():
from
vllm._ipex_ops
import
ipex_ops
as
ops
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -599,213 +589,6 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
...
@@ -599,213 +589,6 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
return
DeepseekV32IndexerBackend
return
DeepseekV32IndexerBackend
def
sparse_attn_indexer
(
hidden_states
:
torch
.
Tensor
,
k_cache_prefix
:
str
,
kv_cache
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
quant_block_size
:
int
,
scale_fmt
:
str
|
None
,
topk_tokens
:
int
,
head_dim
:
int
,
max_model_len
:
int
,
total_seq_lens
:
int
,
topk_indices_buffer
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
# careful! this will be None in dummy run
attn_metadata
=
get_forward_context
().
attn_metadata
fp8_dtype
=
current_platform
.
fp8_dtype
()
# assert isinstance(attn_metadata, dict)
if
not
isinstance
(
attn_metadata
,
dict
):
# Reserve workspace for indexer during profiling run
current_workspace_manager
().
get_simultaneous
(
((
total_seq_lens
,
head_dim
),
torch
.
float8_e4m3fn
),
((
total_seq_lens
,
4
),
torch
.
uint8
),
)
return
sparse_attn_indexer_fake
(
hidden_states
,
k_cache_prefix
,
kv_cache
,
q_fp8
,
k
,
weights
,
quant_block_size
,
scale_fmt
,
topk_tokens
,
head_dim
,
max_model_len
,
total_seq_lens
,
topk_indices_buffer
,
)
attn_metadata
=
attn_metadata
[
k_cache_prefix
]
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
slot_mapping
=
attn_metadata
.
slot_mapping
has_decode
=
attn_metadata
.
num_decodes
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
ops
.
indexer_k_quant_and_cache
(
k
,
kv_cache
,
slot_mapping
,
quant_block_size
,
scale_fmt
,
)
topk_indices_buffer
[:
hidden_states
.
shape
[
0
]]
=
-
1
if
has_prefill
:
prefill_metadata
=
attn_metadata
.
prefill
# Get the full shared workspace buffers once (will allocate on first use)
workspace_manager
=
current_workspace_manager
()
k_fp8_full
,
k_scale_full
=
workspace_manager
.
get_simultaneous
(
((
total_seq_lens
,
head_dim
),
fp8_dtype
),
((
total_seq_lens
,
4
),
torch
.
uint8
),
)
for
chunk
in
prefill_metadata
.
chunks
:
k_fp8
=
k_fp8_full
[:
chunk
.
total_seq_lens
]
k_scale
=
k_scale_full
[:
chunk
.
total_seq_lens
]
ops
.
cp_gather_indexer_k_quant_cache
(
kv_cache
,
k_fp8
,
k_scale
,
chunk
.
block_table
,
chunk
.
cu_seq_lens
,
)
fp8_mqa_logits_func
=
fp8_mqa_logits
if
current_platform
.
is_rocm
():
from
vllm.v1.attention.ops.rocm_aiter_mla_sparse
import
(
rocm_fp8_mqa_logits
,
)
fp8_mqa_logits_func
=
rocm_fp8_mqa_logits
logits
=
fp8_mqa_logits_func
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
(
k_fp8
,
k_scale
.
view
(
torch
.
float32
).
flatten
()),
weights
[
chunk
.
token_start
:
chunk
.
token_end
],
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
)
num_rows
=
logits
.
shape
[
0
]
topk_indices
=
topk_indices_buffer
[
chunk
.
token_start
:
chunk
.
token_end
,
:
topk_tokens
]
torch
.
ops
.
_C
.
top_k_per_row_prefill
(
logits
,
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
topk_indices
,
num_rows
,
logits
.
stride
(
0
),
logits
.
stride
(
1
),
topk_tokens
,
)
if
has_decode
:
decode_metadata
=
attn_metadata
.
decode
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
# we only have [num_block, block_size, head_dim],
kv_cache
=
kv_cache
.
unsqueeze
(
-
2
)
decode_lens
=
decode_metadata
.
decode_lens
if
decode_metadata
.
requires_padding
:
# pad in edge case where we have short chunked prefill length <
# decode_threshold since we unstrictly split
# prefill and decode by decode_threshold
# (currently set to 1 + speculative tokens)
# [num_decode_tokens, n_head, head_dim] -> [bs, 1+next_n, n_head, head_dim]
padded_q_fp8_decode_tokens
=
pack_seq_triton
(
q_fp8
[:
num_decode_tokens
],
decode_lens
)
# [num_decode_tokens, n_head] -> [bs, 1+next_n, n_head]
padded_weights
=
pack_seq_triton
(
weights
[:
num_decode_tokens
],
decode_lens
)
# [bs, 1+next_n, n_head] -> [bs * next_n, n_head]
padded_weights
=
padded_weights
.
flatten
(
0
,
1
)
else
:
padded_q_fp8_decode_tokens
=
q_fp8
[:
num_decode_tokens
].
reshape
(
decode_lens
.
shape
[
0
],
-
1
,
*
q_fp8
.
shape
[
1
:]
)
padded_weights
=
weights
# TODO: move and optimize below logic with triton kernels
batch_size
=
padded_q_fp8_decode_tokens
.
shape
[
0
]
next_n
=
padded_q_fp8_decode_tokens
.
shape
[
1
]
assert
batch_size
==
decode_metadata
.
seq_lens
.
shape
[
0
]
num_padded_tokens
=
batch_size
*
next_n
fp8_paged_mqa_logits_func
=
fp8_paged_mqa_logits
if
current_platform
.
is_rocm
():
from
vllm.v1.attention.ops.rocm_aiter_mla_sparse
import
(
rocm_fp8_paged_mqa_logits
,
)
fp8_paged_mqa_logits_func
=
rocm_fp8_paged_mqa_logits
logits
=
fp8_paged_mqa_logits_func
(
padded_q_fp8_decode_tokens
,
kv_cache
,
padded_weights
[:
num_padded_tokens
],
decode_metadata
.
seq_lens
,
decode_metadata
.
block_table
,
decode_metadata
.
schedule_metadata
,
max_model_len
=
max_model_len
,
)
num_rows
=
logits
.
shape
[
0
]
topk_indices
=
topk_indices_buffer
[:
num_padded_tokens
,
:
topk_tokens
]
torch
.
ops
.
_C
.
top_k_per_row_decode
(
logits
,
next_n
,
decode_metadata
.
seq_lens
,
topk_indices
,
num_rows
,
logits
.
stride
(
0
),
logits
.
stride
(
1
),
topk_tokens
,
)
if
decode_metadata
.
requires_padding
:
# if padded, we need to unpack
# the topk indices removing padded tokens
topk_indices
=
unpack_seq_triton
(
topk_indices
.
reshape
(
batch_size
,
-
1
,
topk_indices
.
shape
[
-
1
]),
decode_lens
,
)
topk_indices_buffer
[:
num_decode_tokens
,
:
topk_indices
.
shape
[
-
1
]]
=
(
topk_indices
)
return
topk_indices_buffer
def
sparse_attn_indexer_fake
(
hidden_states
:
torch
.
Tensor
,
k_cache_prefix
:
str
,
kv_cache
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
quant_block_size
:
int
,
scale_fmt
:
str
|
None
,
topk_tokens
:
int
,
head_dim
:
int
,
max_model_len
:
int
,
total_seq_lens
:
int
,
topk_indices_buffer
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
return
topk_indices_buffer
direct_register_custom_op
(
op_name
=
"sparse_attn_indexer"
,
op_func
=
sparse_attn_indexer
,
mutates_args
=
[
"topk_indices_buffer"
],
fake_impl
=
sparse_attn_indexer_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
class
Indexer
(
nn
.
Module
):
class
Indexer
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -870,6 +653,16 @@ class Indexer(nn.Module):
...
@@ -870,6 +653,16 @@ class Indexer(nn.Module):
from
vllm.v1.attention.backends.mla.indexer
import
get_max_prefill_buffer_size
from
vllm.v1.attention.backends.mla.indexer
import
get_max_prefill_buffer_size
self
.
max_total_seq_len
=
get_max_prefill_buffer_size
(
vllm_config
)
self
.
max_total_seq_len
=
get_max_prefill_buffer_size
(
vllm_config
)
self
.
indexer_op
=
SparseAttnIndexer
(
self
.
k_cache
,
self
.
quant_block_size
,
self
.
scale_fmt
,
self
.
topk_tokens
,
self
.
head_dim
,
self
.
max_model_len
,
self
.
max_total_seq_len
,
self
.
topk_indices_buffer
,
)
def
forward
(
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
qr
:
torch
.
Tensor
,
positions
,
rotary_emb
self
,
hidden_states
:
torch
.
Tensor
,
qr
:
torch
.
Tensor
,
positions
,
rotary_emb
...
@@ -892,6 +685,8 @@ class Indexer(nn.Module):
...
@@ -892,6 +685,8 @@ class Indexer(nn.Module):
q_pe
=
q_pe
.
reshape
(
-
1
,
self
.
n_head
,
self
.
rope_dim
)
q_pe
=
q_pe
.
reshape
(
-
1
,
self
.
n_head
,
self
.
rope_dim
)
k_pe
=
k_pe
.
reshape
(
-
1
,
1
,
self
.
rope_dim
)
k_pe
=
k_pe
.
reshape
(
-
1
,
1
,
self
.
rope_dim
)
# `rotary_emb` is shape-preserving; `q_pe` is already
# [num_tokens, n_head, rope_dim].
q
=
torch
.
cat
([
q_pe
,
q_nope
],
dim
=-
1
)
q
=
torch
.
cat
([
q_pe
,
q_nope
],
dim
=-
1
)
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k
=
torch
.
cat
([
k_pe
.
squeeze
(
-
2
),
k_nope
],
dim
=-
1
)
k
=
torch
.
cat
([
k_pe
.
squeeze
(
-
2
),
k_nope
],
dim
=-
1
)
...
@@ -913,21 +708,7 @@ class Indexer(nn.Module):
...
@@ -913,21 +708,7 @@ class Indexer(nn.Module):
)
)
weights
=
weights
.
squeeze
(
-
1
)
weights
=
weights
.
squeeze
(
-
1
)
return
torch
.
ops
.
vllm
.
sparse_attn_indexer
(
return
self
.
indexer_op
(
hidden_states
,
q_fp8
,
k
,
weights
)
hidden_states
,
self
.
k_cache
.
prefix
,
self
.
k_cache
.
kv_cache
[
0
],
q_fp8
,
k
,
weights
,
self
.
quant_block_size
,
self
.
scale_fmt
,
self
.
topk_tokens
,
self
.
head_dim
,
self
.
max_model_len
,
self
.
max_total_seq_len
,
self
.
topk_indices_buffer
,
)
class
DeepseekV2MLAAttention
(
nn
.
Module
):
class
DeepseekV2MLAAttention
(
nn
.
Module
):
...
...
vllm/platforms/rocm.py
View file @
6c20e89c
...
@@ -480,6 +480,9 @@ class RocmPlatform(Platform):
...
@@ -480,6 +480,9 @@ class RocmPlatform(Platform):
):
):
compilation_config
.
custom_ops
.
append
(
"+grouped_topk"
)
compilation_config
.
custom_ops
.
append
(
"+grouped_topk"
)
# Default dispatch to rocm's sparse_attn_indexer implementation
compilation_config
.
custom_ops
.
append
(
"+sparse_attn_indexer"
)
@
classmethod
@
classmethod
def
verify_model_arch
(
cls
,
model_arch
:
str
)
->
None
:
def
verify_model_arch
(
cls
,
model_arch
:
str
)
->
None
:
if
model_arch
in
_ROCM_UNSUPPORTED_MODELS
:
if
model_arch
in
_ROCM_UNSUPPORTED_MODELS
:
...
...
vllm/v1/attention/backends/mla/indexer.py
View file @
6c20e89c
...
@@ -63,6 +63,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
...
@@ -63,6 +63,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
cu_seqlen_ks
:
torch
.
Tensor
cu_seqlen_ks
:
torch
.
Tensor
cu_seqlen_ke
:
torch
.
Tensor
cu_seqlen_ke
:
torch
.
Tensor
cu_seq_lens
:
torch
.
Tensor
cu_seq_lens
:
torch
.
Tensor
token_to_seq
:
torch
.
Tensor
total_seq_lens
:
int
total_seq_lens
:
int
token_start
:
int
token_start
:
int
token_end
:
int
token_end
:
int
...
@@ -234,6 +235,10 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
...
@@ -234,6 +235,10 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
token_start
=
query_start_loc_cpu
[
reqs_start
].
item
()
token_start
=
query_start_loc_cpu
[
reqs_start
].
item
()
token_end
=
query_start_loc_cpu
[
reqs_end
].
item
()
token_end
=
query_start_loc_cpu
[
reqs_end
].
item
()
total_seq_lens
=
seq_lens_cpu
[
reqs_start
:
reqs_end
].
sum
()
total_seq_lens
=
seq_lens_cpu
[
reqs_start
:
reqs_end
].
sum
()
seq_idx
=
torch
.
arange
(
0
,
reqs_end
-
reqs_start
,
dtype
=
torch
.
int32
)
token_to_seq
=
torch
.
repeat_interleave
(
seq_idx
,
seq_lens_cpu
[
reqs_start
:
reqs_end
]
).
to
(
self
.
device
)
assert
total_seq_lens
<=
self
.
max_prefill_buffer_size
assert
total_seq_lens
<=
self
.
max_prefill_buffer_size
cu_seq_lens
=
(
cu_seq_lens
=
(
torch
.
cat
(
torch
.
cat
(
...
@@ -249,6 +254,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
...
@@ -249,6 +254,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
cu_seqlen_ks
=
cu_seqlen_ks
,
cu_seqlen_ks
=
cu_seqlen_ks
,
cu_seqlen_ke
=
cu_seqlen_ke
,
cu_seqlen_ke
=
cu_seqlen_ke
,
cu_seq_lens
=
cu_seq_lens
,
cu_seq_lens
=
cu_seq_lens
,
token_to_seq
=
token_to_seq
,
total_seq_lens
=
total_seq_lens
,
total_seq_lens
=
total_seq_lens
,
block_table
=
block_table
[
reqs_start
:
reqs_end
],
block_table
=
block_table
[
reqs_start
:
reqs_end
],
token_start
=
token_start
,
token_start
=
token_start
,
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
View file @
6c20e89c
...
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
...
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBaseImpl
,
MLACommonBaseImpl
,
get_mla_dims
,
get_mla_dims
,
)
)
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionBackend
,
AttentionCGSupport
,
AttentionCGSupport
,
...
@@ -33,6 +34,48 @@ if TYPE_CHECKING:
...
@@ -33,6 +34,48 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
@triton.jit
def fetch_id_to_ragged_kernel(
    in_tensor_ptr,  # [num_seq, topk]
    cumsum_ptr,  # [num_seq + 1]
    out_tensor_ptr,  # [max_num_seq * topk]
    in_tensor_ptr_stride,
    TOPK: tl.constexpr,
    TOKEN_NUM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Copy the leading entries of one input row into its slot of the ragged
    # output. Program axis 0 selects the row, axis 1 selects a BLOCK_SIZE-wide
    # column tile of that row. TOKEN_NUM is accepted but not read here.
    row = tl.program_id(0)
    tile = tl.program_id(1)
    lanes = tl.arange(0, BLOCK_SIZE)

    # Output range [dst_begin, dst_end) owned by this row.
    dst_begin = tl.load(cumsum_ptr + row)
    dst_end = tl.load(cumsum_ptr + row + 1)
    row_len = dst_end - dst_begin

    # Tiles that start past the row's length have nothing to copy.
    tile_base = tile * BLOCK_SIZE
    if tile_base >= row_len:
        return

    # Columns handled by this tile; mask off anything beyond the row width.
    cols = tile_base + lanes
    col_ok = cols < TOPK
    vals = tl.load(in_tensor_ptr + row * in_tensor_ptr_stride + cols, mask=col_ok)

    # Only the first `row_len` columns land in the output.
    # NOTE(review): this presumes the entries to keep sit at the front of
    # each input row -- confirm against the producer of `in_tensor`.
    dst = dst_begin + cols
    dst_ok = (dst < dst_end) & col_ok
    tl.store(out_tensor_ptr + dst, vals, mask=dst_ok)
def fetch_id_to_ragged_triton(
    in_tensor: torch.Tensor, cumsum: torch.Tensor, out_tensor: torch.Tensor, topk
):
    """Launch ``fetch_id_to_ragged_kernel`` over every row of ``in_tensor``.

    Args:
        in_tensor: 2-D source tensor; one grid row is launched per
            ``in_tensor.size(0)`` and its dim-0 stride is passed to the kernel.
        cumsum: Offsets into ``out_tensor``; the kernel reads entries ``i``
            and ``i + 1`` for row ``i``.
        out_tensor: Flat destination buffer, written in place.
        topk: Row width; determines how many 64-wide column tiles are launched.
    """
    rows = in_tensor.size(0)
    tile_width = 64
    tiles_per_row = triton.cdiv(topk, tile_width)
    launch_grid = (rows, tiles_per_row)
    fetch_id_to_ragged_kernel[launch_grid](
        in_tensor,
        cumsum,
        out_tensor,
        in_tensor.stride(0),
        topk,
        rows,
        tile_width,
    )
class
ROCMAiterMLASparseBackend
(
AttentionBackend
):
class
ROCMAiterMLASparseBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
accept_output_buffer
:
bool
=
True
...
@@ -83,6 +126,13 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata):
...
@@ -83,6 +126,13 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata):
block_table
:
torch
.
Tensor
block_table
:
torch
.
Tensor
req_id_per_token
:
torch
.
Tensor
req_id_per_token
:
torch
.
Tensor
qo_indptr
:
torch
.
Tensor
paged_kv_last_page_len
:
torch
.
Tensor
paged_kv_indices
:
torch
.
Tensor
paged_kv_indptr
:
torch
.
Tensor
paged_kv_indptr_rest
:
torch
.
Tensor
block_size
:
int
=
1
block_size
:
int
=
1
topk_tokens
:
int
=
2048
topk_tokens
:
int
=
2048
...
@@ -91,7 +141,7 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata):
...
@@ -91,7 +141,7 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata):
class
ROCMAiterMLASparseMetadataBuilder
(
class
ROCMAiterMLASparseMetadataBuilder
(
AttentionMetadataBuilder
[
ROCMAiterMLASparseMetadata
]
AttentionMetadataBuilder
[
ROCMAiterMLASparseMetadata
]
):
):
cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
AttentionCGSupport
.
NEVER
_
cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
AttentionCGSupport
.
NEVER
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -104,6 +154,7 @@ class ROCMAiterMLASparseMetadataBuilder(
...
@@ -104,6 +154,7 @@ class ROCMAiterMLASparseMetadataBuilder(
self
.
model_config
=
vllm_config
.
model_config
self
.
model_config
=
vllm_config
.
model_config
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
self
.
device
=
device
self
.
device
=
device
max_num_batched_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
self
.
num_heads
=
self
.
model_config
.
get_num_attention_heads
(
parallel_config
)
self
.
num_heads
=
self
.
model_config
.
get_num_attention_heads
(
parallel_config
)
self
.
mla_dims
=
get_mla_dims
(
self
.
model_config
)
self
.
mla_dims
=
get_mla_dims
(
self
.
model_config
)
...
@@ -124,6 +175,23 @@ class ROCMAiterMLASparseMetadataBuilder(
...
@@ -124,6 +175,23 @@ class ROCMAiterMLASparseMetadataBuilder(
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
,
device
=
device
,
)
)
self
.
qo_indptr
=
torch
.
arange
(
0
,
max_num_batched_tokens
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
paged_kv_last_page_len
=
torch
.
ones
(
max_num_batched_tokens
,
dtype
=
torch
.
int32
,
device
=
device
)
# These two need to be calculated at runtime,
# but we still need to prepare the buffers
self
.
paged_kv_indices
=
torch
.
zeros
(
[
max_num_batched_tokens
*
self
.
topk_tokens
],
dtype
=
torch
.
int32
,
device
=
device
,
)
self
.
paged_kv_indptr
=
torch
.
zeros
(
[
max_num_batched_tokens
+
1
],
dtype
=
torch
.
int32
,
device
=
device
)
def
build
(
def
build
(
self
,
self
,
...
@@ -142,7 +210,15 @@ class ROCMAiterMLASparseMetadataBuilder(
...
@@ -142,7 +210,15 @@ class ROCMAiterMLASparseMetadataBuilder(
self
.
req_id_per_token_buffer
[:
req_id_per_token
.
shape
[
0
]].
copy_
(
self
.
req_id_per_token_buffer
[:
req_id_per_token
.
shape
[
0
]].
copy_
(
torch
.
from_numpy
(
req_id_per_token
),
non_blocking
=
True
torch
.
from_numpy
(
req_id_per_token
),
non_blocking
=
True
)
)
self
.
paged_kv_indices
.
fill_
(
0
)
self
.
paged_kv_indptr
.
fill_
(
0
)
req_id_per_token
=
self
.
req_id_per_token_buffer
[:
num_tokens
]
req_id_per_token
=
self
.
req_id_per_token_buffer
[:
num_tokens
]
qo_indptr
=
self
.
qo_indptr
[:
num_tokens
+
1
]
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
[:
num_tokens
]
paged_kv_indices
=
self
.
paged_kv_indices
[:
num_tokens
*
self
.
topk_tokens
]
paged_kv_indptr
=
self
.
paged_kv_indptr
[:
num_tokens
+
1
]
paged_kv_indptr_rest
=
self
.
paged_kv_indptr
[
num_tokens
+
1
:]
metadata
=
ROCMAiterMLASparseMetadata
(
metadata
=
ROCMAiterMLASparseMetadata
(
num_reqs
=
common_attn_metadata
.
num_reqs
,
num_reqs
=
common_attn_metadata
.
num_reqs
,
...
@@ -155,6 +231,11 @@ class ROCMAiterMLASparseMetadataBuilder(
...
@@ -155,6 +231,11 @@ class ROCMAiterMLASparseMetadataBuilder(
req_id_per_token
=
req_id_per_token
,
req_id_per_token
=
req_id_per_token
,
block_size
=
self
.
kv_cache_spec
.
block_size
,
block_size
=
self
.
kv_cache_spec
.
block_size
,
topk_tokens
=
self
.
topk_tokens
,
topk_tokens
=
self
.
topk_tokens
,
qo_indptr
=
qo_indptr
,
paged_kv_last_page_len
=
paged_kv_last_page_len
,
paged_kv_indices
=
paged_kv_indices
,
paged_kv_indptr
=
paged_kv_indptr
,
paged_kv_indptr_rest
=
paged_kv_indptr_rest
,
)
)
return
metadata
return
metadata
...
@@ -226,20 +307,39 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
...
@@ -226,20 +307,39 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
def
_forward_bf16_kv
(
def
_forward_bf16_kv
(
self
,
self
,
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
# [sq, heads, d_qk]
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
# [blocks, heads, d_qk]
topk_indices
:
torch
.
Tensor
,
topk_indices
:
torch
.
Tensor
,
# [sq, topk]
attn_metadata
:
ROCMAiterMLASparseMetadata
,
attn_metadata
:
ROCMAiterMLASparseMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_tokens
=
q
.
shape
[
0
]
num_tokens
=
q
.
shape
[
0
]
kv_c_and_k_pe_cache
=
kv_c_and_k_pe_cache
.
view
(
output
=
torch
.
empty
(
-
1
,
1
,
kv_c_and_k_pe_cache
.
shape
[
-
1
]
[
num_tokens
,
self
.
num_heads
,
self
.
kv_lora_rank
],
dtype
=
q
.
dtype
,
device
=
q
.
device
,
)
seq_len
=
(
topk_indices
!=
-
1
).
sum
(
dim
=-
1
)
torch
.
cumsum
(
seq_len
,
dim
=
0
,
out
=
attn_metadata
.
paged_kv_indptr
[
1
:])
attn_metadata
.
paged_kv_indptr_rest
.
fill_
(
attn_metadata
.
paged_kv_indptr
[
-
1
])
fetch_id_to_ragged_triton
(
topk_indices
,
attn_metadata
.
paged_kv_indptr
,
attn_metadata
.
paged_kv_indices
,
attn_metadata
.
topk_tokens
,
)
rocm_aiter_ops
.
mla_decode_fwd
(
q
,
kv_c_and_k_pe_cache
,
output
,
self
.
scale
,
attn_metadata
.
qo_indptr
,
1
,
attn_metadata
.
paged_kv_indptr
,
attn_metadata
.
paged_kv_indices
,
attn_metadata
.
paged_kv_last_page_len
,
)
)
topk_indices
=
topk_indices
.
view
(
num_tokens
,
1
,
-
1
)
output
=
reference_mla_sparse_prefill
(
q
,
kv_c_and_k_pe_cache
,
topk_indices
,
self
.
softmax_scale
,
512
)[
0
]
return
output
[:,
:
self
.
num_heads
,
:]
return
output
[:,
:
self
.
num_heads
,
:]
def
forward
(
def
forward
(
...
...
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
View file @
6c20e89c
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
import
importlib
import
importlib
from
functools
import
lru_cache
import
torch
import
torch
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.backends.mla.indexer
import
DeepseekV32IndexerMetadata
from
vllm.v1.attention.ops.common
import
pack_seq_triton
,
unpack_seq_triton
logger
=
init_logger
(
__name__
)
if
current_platform
.
is_cuda_alike
():
from
vllm
import
_custom_ops
as
ops
# Take from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84
@
triton
.
jit
def
fp8_mqa_logits_torch
(
def
_indexer_k_quant_and_cache_kernel
(
q
:
torch
.
Tensor
,
k_ptr
,
# [num_tokens, head_dim]
kv
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
kv_cache_ptr
,
# [n_blks, blk_size//tile_block, head_dim // 16B, tile_block, 16B]
weights
:
torch
.
Tensor
,
# [n_blocks, blk_size, head_dim]
cu_seqlen_ks
:
torch
.
Tensor
,
kv_cache_scale_ptr
,
# [n_blks, blk_size]
cu_seqlen_ke
:
torch
.
Tensor
,
slot_mapping_ptr
,
# [num_tokens]
)
->
torch
.
Tensor
:
kv_cache_scale_stride
,
"""Compute FP8 MQA logits for a single sequence without KV paging.
kv_cache_value_stride
,
block_size
,
Args:
num_tokens
,
q: Query tensor of shape [M, H, D]. Casted to
head_dim
:
tl
.
constexpr
,
`torch.float8_e4m3fn` by caller.
LAYOUT
:
tl
.
constexpr
,
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
BLOCK_TILE_SIZE
:
tl
.
constexpr
,
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
HEAD_TILE_SIZE
:
tl
.
constexpr
,
[N, 1]) with dtype `torch.float32`.
IS_FNUZ
:
tl
.
constexpr
,
weights: weights of shape [M, H], dtype `torch.float32`.
USE_UE8M0
:
tl
.
constexpr
,
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
):
shape [M], dtype int32.
tid
=
tl
.
program_id
(
0
)
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
offset
=
tl
.
arange
(
0
,
head_dim
)
shape [M], dtype int32.
if
LAYOUT
==
"SHUFFLE"
:
tile_offset
=
(
offset
//
HEAD_TILE_SIZE
*
BLOCK_TILE_SIZE
*
HEAD_TILE_SIZE
+
offset
%
HEAD_TILE_SIZE
)
else
:
tile_offset
=
offset
tile_store_offset
=
tile_offset
# for idx in tl.range(tid, num_tokens, n_program):
src_ptr
=
k_ptr
+
tid
*
head_dim
slot_id
=
tl
.
load
(
slot_mapping_ptr
+
tid
)
if
slot_id
<
0
:
return
block_id
=
slot_id
//
block_size
block_offset
=
slot_id
%
block_size
tile_block_id
=
block_offset
//
BLOCK_TILE_SIZE
tile_block_offset
=
block_offset
%
BLOCK_TILE_SIZE
val
=
tl
.
load
(
src_ptr
+
offset
)
amax
=
tl
.
max
(
val
.
abs
(),
axis
=-
1
).
to
(
tl
.
float32
)
if
IS_FNUZ
:
scale
=
tl
.
maximum
(
1e-4
,
amax
)
/
224.0
else
:
scale
=
tl
.
maximum
(
1e-4
,
amax
)
/
448.0
Returns:
if
USE_UE8M0
:
Logits tensor of shape [M, N], dtype `torch.float32`.
scale
=
tl
.
exp2
(
tl
.
ceil
(
tl
.
log2
(
scale
)))
"""
k_fp8
,
scale
=
kv
seq_len_kv
=
k_fp8
.
shape
[
0
]
k
=
k_fp8
.
to
(
torch
.
bfloat16
)
q
=
q
.
to
(
torch
.
bfloat16
)
mask_lo
=
(
fp8_val
=
(
val
.
to
(
tl
.
float32
)
/
scale
).
to
(
kv_cache_ptr
.
type
.
element_ty
)
torch
.
arange
(
0
,
seq_len_kv
,
device
=
"cuda"
)[
None
,
:]
>=
cu_seqlen_ks
[:,
None
]
if
LAYOUT
==
"SHUFFLE"
:
dst_ptr
=
(
kv_cache_ptr
+
block_id
*
kv_cache_value_stride
+
tile_block_id
*
BLOCK_TILE_SIZE
*
head_dim
+
tile_block_offset
*
HEAD_TILE_SIZE
)
)
mask_hi
=
(
else
:
torch
.
arange
(
0
,
seq_len_kv
,
device
=
"cuda"
)[
None
,
:]
<
cu_seqlen_ke
[:,
None
]
dst_ptr
=
(
kv_cache_ptr
+
block_id
*
kv_cache_value_stride
+
block_offset
*
head_dim
)
)
mask
=
mask_lo
&
mask_hi
tl
.
store
(
dst_ptr
+
tile_store_offset
,
fp8_val
)
dst_scale_ptr
=
kv_cache_scale_ptr
+
block_id
*
kv_cache_scale_stride
+
block_offset
score
=
torch
.
einsum
(
"mhd,nd->hmn"
,
q
,
k
).
float
()
*
scale
tl
.
store
(
dst_scale_ptr
,
scale
)
logits
=
(
score
.
relu
()
*
weights
.
unsqueeze
(
-
1
).
transpose
(
0
,
1
)).
sum
(
dim
=
0
)
logits
=
logits
.
masked_fill
(
~
mask
,
float
(
"-inf"
))
return
logits
def
rocm_fp8_mqa_logits
(
def
indexer_k_quant_and_cache_triton
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
kv
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
kv_cache
:
torch
.
Tensor
,
# [num_blocks, block_size, head_dim + 4]
weights
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
cu_seqlen_ks
:
torch
.
Tensor
,
quant_block_size
,
cu_seqlen_ke
:
torch
.
Tensor
,
scale_fmt
,
)
->
torch
.
Tensor
:
block_tile_size
=
16
,
"""Compute FP8 MQA logits for a single sequence without KV paging.
head_tile_size
=
16
,
):
Args:
num_blocks
=
kv_cache
.
shape
[
0
]
q: Query tensor of shape [M, H, D]. Casted to
head_dim
=
k
.
shape
[
-
1
]
`torch.float8_e4m3fn` by caller.
num_tokens
=
slot_mapping
.
shape
[
0
]
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
block_size
=
kv_cache
.
shape
[
1
]
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
# In real layout, we store the first portion as kv cache value
[N, 1]) with dtype `torch.float32`.
# and second portion as kv cache scale
weights: weights of shape [M, H], dtype `torch.float32`.
kv_cache
=
kv_cache
.
view
(
num_blocks
,
-
1
)
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
kv_cache_value
=
kv_cache
[:,
:
block_size
*
head_dim
]
shape [M], dtype int32.
kv_cache_scale
=
kv_cache
[:,
block_size
*
head_dim
:].
view
(
torch
.
float32
)
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
head_tile_size
=
head_tile_size
//
kv_cache
.
element_size
()
shape [M], dtype int32.
grid
=
(
num_tokens
,)
_indexer_k_quant_and_cache_kernel
[
grid
](
k
,
kv_cache_value
,
kv_cache_scale
,
slot_mapping
,
kv_cache_scale
.
stride
(
0
),
kv_cache_value
.
stride
(
0
),
block_size
,
num_tokens
,
head_dim
,
"NHD"
,
block_tile_size
,
head_tile_size
,
IS_FNUZ
=
current_platform
.
fp8_dtype
()
==
torch
.
float8_e4m3fnuz
,
USE_UE8M0
=
scale_fmt
==
"ue8m0"
,
)
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
# TODO(ganyi): Temporarily workaround, will remove the module check and reference
@
triton
.
jit
# path after aiter merge this kernel into main
def
_cp_gather_indexer_quant_cache_kernel
(
@
lru_cache
kv_cache_ptr
,
# [n_blks,blk_size//tile_blk,head_dim//16B,tile_blk,16B]
def
has_mqa_logits_module
():
# [n_blks, blk_size, head_dim]
return
importlib
.
util
.
find_spec
(
"aiter.ops.triton.fp8_mqa_logits"
)
is
not
None
kv_cache_scale_ptr
,
# [n_blks, blk_size]
k_fp8_ptr
,
# [num_tokens, head_dim]
k_scale_ptr
,
# [num_tokens]
block_table_ptr
,
# [batch_size, block_table_stride]
cu_seqlen_ptr
,
# [batch_size + 1]
token_to_seq_ptr
,
# [num_tokens]
block_size
,
block_table_stride
,
kv_cache_stride
,
kv_cache_scale_stride
,
LAYOUT
:
tl
.
constexpr
,
HEAD_DIM
:
tl
.
constexpr
,
BLOCK_TILE_SIZE
:
tl
.
constexpr
,
HEAD_TILE_SIZE
:
tl
.
constexpr
,
):
tid
=
tl
.
program_id
(
0
)
offset
=
tl
.
arange
(
0
,
HEAD_DIM
)
batch_id
=
tl
.
load
(
token_to_seq_ptr
+
tid
)
batch_start
=
tl
.
load
(
cu_seqlen_ptr
+
batch_id
)
batch_end
=
tl
.
load
(
cu_seqlen_ptr
+
batch_id
+
1
)
batch_offset
=
tid
-
batch_start
if
tid
>=
batch_end
:
return
block_table_id
=
batch_offset
//
block_size
block_offset
=
batch_offset
%
block_size
block_table_offset
=
batch_id
*
block_table_stride
+
block_table_id
block_id
=
tl
.
load
(
block_table_ptr
+
block_table_offset
)
tiled_block_id
=
block_offset
//
BLOCK_TILE_SIZE
tiled_block_offset
=
block_offset
%
BLOCK_TILE_SIZE
if
LAYOUT
==
"SHUFFLE"
:
src_cache_offset
=
(
block_id
*
kv_cache_stride
+
tiled_block_id
*
HEAD_DIM
*
BLOCK_TILE_SIZE
+
tiled_block_offset
*
HEAD_TILE_SIZE
)
else
:
src_cache_offset
=
block_id
*
kv_cache_stride
+
block_offset
*
HEAD_DIM
src_scale_offset
=
block_id
*
kv_cache_scale_stride
+
block_offset
dst_offset
=
tid
*
HEAD_DIM
src_scale_ptr
=
kv_cache_scale_ptr
+
src_scale_offset
src_cache_ptr
=
kv_cache_ptr
+
src_cache_offset
dst_k_ptr
=
k_fp8_ptr
+
dst_offset
scale_val
=
tl
.
load
(
src_scale_ptr
)
tl
.
store
(
k_scale_ptr
+
tid
,
scale_val
)
if
LAYOUT
==
"SHUFFLE"
:
tiled_src_offset
=
(
offset
//
HEAD_TILE_SIZE
*
HEAD_TILE_SIZE
*
BLOCK_TILE_SIZE
+
offset
%
HEAD_TILE_SIZE
)
else
:
tiled_src_offset
=
offset
val
=
tl
.
load
(
src_cache_ptr
+
tiled_src_offset
)
tl
.
store
(
dst_k_ptr
+
offset
,
val
)
if
rocm_aiter_ops
.
is_enabled
()
and
has_mqa_logits_module
():
from
aiter.ops.triton.fp8_mqa_logits
import
fp8_mqa_logits
kv
,
scale
=
kv
def
cp_gather_indexer_k_quant_cache_triton
(
return
fp8_mqa_logits
(
q
,
kv
,
scale
,
weights
,
cu_seqlen_ks
,
cu_seqlen_ke
)
k_cache
:
torch
.
Tensor
,
# [num_blocks, block_size, head_dim + 4]
else
:
k_fp8
:
torch
.
Tensor
,
return
fp8_mqa_logits_torch
(
q
,
kv
,
weights
,
cu_seqlen_ks
,
cu_seqlen_ke
)
k_fp8_scale
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
cu_seqlen
:
torch
.
Tensor
,
token_to_seq
:
torch
.
Tensor
,
block_tile_size
:
int
=
16
,
head_tile_size
:
int
=
16
,
):
num_tokens
=
k_fp8
.
size
(
0
)
block_size
=
k_cache
.
size
(
1
)
block_table_stride
=
block_table
.
stride
(
0
)
head_dim
=
k_fp8
.
shape
[
-
1
]
num_blocks
=
k_cache
.
shape
[
0
]
# we assume the kv cache already been split to 2 portion
k_cache
=
k_cache
.
view
(
num_blocks
,
-
1
)
fp8_dtype
=
current_platform
.
fp8_dtype
()
k_cache_value
=
k_cache
[:,
:
block_size
*
head_dim
].
view
(
fp8_dtype
)
k_cache_scale
=
k_cache
[:,
block_size
*
head_dim
:].
view
(
torch
.
float32
)
grid
=
(
num_tokens
,)
k_fp8_scale
=
k_fp8_scale
.
view
(
torch
.
float32
)
_cp_gather_indexer_quant_cache_kernel
[
grid
](
k_cache_value
,
k_cache_scale
,
k_fp8
,
k_fp8_scale
,
block_table
,
cu_seqlen
,
token_to_seq
,
block_size
,
block_table_stride
,
k_cache_value
.
stride
(
0
),
k_cache_scale
.
stride
(
0
),
"NHD"
,
head_dim
,
block_tile_size
,
head_tile_size
,
)
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156
...
@@ -183,10 +303,38 @@ def rocm_fp8_paged_mqa_logits(
...
@@ -183,10 +303,38 @@ def rocm_fp8_paged_mqa_logits(
Logits tensor of shape [B * next_n, max_model_len], dtype
Logits tensor of shape [B * next_n, max_model_len], dtype
`torch.float32`.
`torch.float32`.
"""
"""
from
vllm._aiter_ops
import
rocm_aiter_ops
@
functools
.
lru_cache
def
paged_mqa_logits_module
():
paged_mqa_logits_module_path
=
None
if
importlib
.
util
.
find_spec
(
"aiter.ops.triton.pa_mqa_logits"
)
is
not
None
:
paged_mqa_logits_module_path
=
"aiter.ops.triton.pa_mqa_logits"
elif
(
importlib
.
util
.
find_spec
(
"aiter.ops.triton.attention.pa_mqa_logits"
)
is
not
None
):
paged_mqa_logits_module_path
=
"aiter.ops.triton.attention.pa_mqa_logits"
if
paged_mqa_logits_module_path
is
not
None
:
try
:
module
=
importlib
.
import_module
(
paged_mqa_logits_module_path
)
return
module
except
ImportError
:
return
None
return
None
aiter_paged_mqa_logits_module
=
None
if
rocm_aiter_ops
.
is_enabled
():
if
rocm_aiter_ops
.
is_enabled
():
from
aiter.ops.triton.pa_mqa_logits
import
deepgemm_fp8_paged_mqa_logits_stage1
aiter_paged_mqa_logits_module
=
paged_mqa_logits_module
()
# FIXME(ganyi): Temporarily disable the aiter path until nightly docker
# update aiter to the fix PR.
aiter_paged_mqa_logits_module
=
None
if
aiter_paged_mqa_logits_module
is
not
None
:
deepgemm_fp8_paged_mqa_logits_stage1
=
(
aiter_paged_mqa_logits_module
.
deepgemm_fp8_paged_mqa_logits_stage1
)
batch_size
,
next_n
,
heads
,
_
=
q_fp8
.
shape
batch_size
,
next_n
,
heads
,
_
=
q_fp8
.
shape
out_qk
=
torch
.
full
(
out_qk
=
torch
.
full
(
(
heads
,
batch_size
*
next_n
,
max_model_len
),
(
heads
,
batch_size
*
next_n
,
max_model_len
),
...
@@ -208,3 +356,293 @@ def rocm_fp8_paged_mqa_logits(
...
@@ -208,3 +356,293 @@ def rocm_fp8_paged_mqa_logits(
return
fp8_paged_mqa_logits_torch
(
return
fp8_paged_mqa_logits_torch
(
q_fp8
,
kv_cache_fp8
,
weights
,
context_lens
,
block_tables
,
max_model_len
q_fp8
,
kv_cache_fp8
,
weights
,
context_lens
,
block_tables
,
max_model_len
)
)
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84
def fp8_mqa_logits_torch(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Pure-PyTorch reference implementation: both operands are upcast to
    bfloat16, multiplied, rescaled per key, passed through ReLU, weighted per
    head and summed over heads. Key positions outside
    ``[cu_seqlen_ks[m], cu_seqlen_ke[m])`` are set to ``-inf`` for query ``m``.

    Args:
        q: Query tensor of shape [M, H, D]. Casted to
            `torch.float8_e4m3fn` by caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    k_fp8, k_scales = kv
    seq_len_kv = k_fp8.shape[0]
    # Flatten so that both documented layouts ([N] and [N, 1]) broadcast along
    # the key axis of the [H, M, N] score tensor. Without this an [N, 1] scale
    # would be aligned against the query axis instead.
    k_scales = k_scales.reshape(-1)

    k = k_fp8.to(torch.bfloat16)
    q = q.to(torch.bfloat16)

    # Build the key-position index once, on the same device as the inputs
    # rather than on a hard-coded "cuda" device.
    kv_positions = torch.arange(0, seq_len_kv, device=q.device)[None, :]
    mask_lo = kv_positions >= cu_seqlen_ks[:, None]
    mask_hi = kv_positions < cu_seqlen_ke[:, None]
    mask = mask_lo & mask_hi

    # score: [H, M, N]; the per-key scale dequantizes the FP8 keys.
    score = torch.einsum("mhd,nd->hmn", q, k).float() * k_scales
    # weights [M, H] -> [H, M, 1] so each head's scores get its own weight
    # before the reduction over heads.
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    return logits
@functools.lru_cache(maxsize=1)
def _find_aiter_mqa_logits_module():
    """Locate and import aiter's ``fp8_mqa_logits`` module, if available.

    Defined at module level so the lookup result is actually cached across
    calls (a cached function re-created inside every call never hits).

    Returns:
        The imported module, or ``None`` when neither known module path can be
        found or the import fails.
    """
    import importlib.util

    candidate_paths = (
        "aiter.ops.triton.fp8_mqa_logits",
        "aiter.ops.triton.attention.fp8_mqa_logits",
    )
    for module_path in candidate_paths:
        try:
            spec = importlib.util.find_spec(module_path)
        except ImportError:
            # find_spec imports the parent package of a dotted name and raises
            # ModuleNotFoundError when that parent does not exist; treat it
            # the same as "module not found" and try the next candidate.
            spec = None
        if spec is None:
            continue
        try:
            return importlib.import_module(module_path)
        except ImportError:
            return None
    return None


def rocm_fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Dispatches to aiter's ``fp8_mqa_logits`` kernel when aiter ops are enabled
    and the kernel module can be imported; otherwise falls back to the
    PyTorch reference implementation ``fp8_mqa_logits_torch``.

    Args:
        q: Query tensor of shape [M, H, D]. Casted to
            `torch.float8_e4m3fn` by caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    # TODO(ganyi): Temporary workaround; remove the module check and the
    # reference path once aiter merges this kernel into main.
    from vllm._aiter_ops import rocm_aiter_ops

    aiter_mqa_logits_module = None
    if rocm_aiter_ops.is_enabled():
        aiter_mqa_logits_module = _find_aiter_mqa_logits_module()

    if aiter_mqa_logits_module is None:
        return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)

    # The aiter kernel takes the quantized keys and their scales separately.
    k_fp8, k_scales = kv
    return aiter_mqa_logits_module.fp8_mqa_logits(
        q, k_fp8, k_scales, weights, cu_seqlen_ks, cu_seqlen_ke
    )
def rocm_aiter_sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    """Allocation-only counterpart of ``rocm_aiter_sparse_attn_indexer``.

    Performs no indexing work. It only allocates scratch buffers sized by
    ``total_seq_lens`` and ``head_dim`` on ``k``'s device and then returns
    ``topk_indices_buffer`` untouched. All other arguments are accepted only
    to mirror the real indexer's signature and are not read.

    Returns:
        ``topk_indices_buffer``, exactly as passed in.
    """
    # Profile run.
    # NOTE(Chen): allocate the largest possible flattened KV buffer so that
    # profile_run observes the correct memory usage.
    quant_dtype = current_platform.fp8_dtype()
    scratch = torch.empty(
        (total_seq_lens, head_dim + 4),
        dtype=torch.uint8,
        device=k.device,
    )
    # Each row holds head_dim value bytes followed by 4 bytes that are
    # reinterpreted as one float32.
    value_bytes = scratch[..., :head_dim]
    scale_bytes = scratch[..., head_dim:]
    _k_fp8 = value_bytes.view(quant_dtype).contiguous()
    _k_scale = scale_bytes.view(torch.float32).contiguous()
    return topk_indices_buffer
def rocm_aiter_sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    """Select the top-k KV positions per token and write them to the buffer.

    Steps, in order:
      1. Quantize ``k`` into ``kv_cache`` at the slots given by the layer's
         indexer metadata.
      2. Reset the first ``hidden_states.shape[0]`` rows of
         ``topk_indices_buffer`` to ``-1``.
      3. Prefill: for every chunk, gather the cached keys, compute MQA logits
         and run ``top_k_per_row_prefill`` into the chunk's buffer rows.
      4. Decode: compute paged MQA logits for the first ``num_decode_tokens``
         query rows and run ``top_k_per_row_decode``.

    When the forward context does not carry a metadata dict (dummy run), the
    call is forwarded to ``rocm_aiter_sparse_attn_indexer_fake``.

    Args:
        hidden_states: Only its leading dimension is read (rows to reset).
        k_cache_prefix: Key into the forward context's attention-metadata dict.
        kv_cache: Indexer key cache that is written to and gathered from.
        q_fp8: Quantized queries, sliced per prefill chunk / decode tokens.
        k: Keys to quantize and store; its device is used for scratch buffers.
        weights: Per-token weights passed to the logits kernels.
        quant_block_size: Forwarded to ``ops.indexer_k_quant_and_cache``.
        scale_fmt: Forwarded to ``ops.indexer_k_quant_and_cache``.
        topk_tokens: Number of indices kept per row; asserted to be 2048.
        head_dim: Width of the gathered key buffer.
        max_model_len: Forwarded to ``rocm_fp8_paged_mqa_logits``.
        total_seq_lens: Only used by the fake path.
        topk_indices_buffer: Output buffer, updated in place.

    Returns:
        ``topk_indices_buffer``.
    """
    # Careful: outside of a real forward pass (dummy run) this is not a dict.
    forward_metadata = get_forward_context().attn_metadata
    quant_dtype = current_platform.fp8_dtype()
    if not isinstance(forward_metadata, dict):
        return rocm_aiter_sparse_attn_indexer_fake(
            hidden_states=hidden_states,
            k_cache_prefix=k_cache_prefix,
            kv_cache=kv_cache,
            q_fp8=q_fp8,
            k=k,
            weights=weights,
            quant_block_size=quant_block_size,
            scale_fmt=scale_fmt,
            topk_tokens=topk_tokens,
            head_dim=head_dim,
            max_model_len=max_model_len,
            total_seq_lens=total_seq_lens,
            topk_indices_buffer=topk_indices_buffer,
        )

    layer_metadata = forward_metadata[k_cache_prefix]
    assert isinstance(layer_metadata, DeepseekV32IndexerMetadata)
    num_decode_tokens = layer_metadata.num_decode_tokens
    run_decode = layer_metadata.num_decodes > 0
    run_prefill = layer_metadata.num_prefills > 0

    # Store the new keys before any logits are computed from the cache.
    ops.indexer_k_quant_and_cache(
        k,
        kv_cache,
        layer_metadata.slot_mapping,
        quant_block_size,
        scale_fmt,
    )

    # -1 marks "no index selected" for every token of this step.
    topk_indices_buffer[: hidden_states.shape[0]] = -1

    if run_prefill:
        for chunk in layer_metadata.prefill.chunks:
            chunk_tokens = slice(chunk.token_start, chunk.token_end)
            gathered_k = torch.empty(
                [chunk.total_seq_lens, head_dim],
                device=k.device,
                dtype=quant_dtype,
            )
            # 4 raw bytes per key, reinterpreted as one float32 scale below.
            gathered_scale = torch.empty(
                [chunk.total_seq_lens, 4],
                device=k.device,
                dtype=torch.uint8,
            )
            ops.cp_gather_indexer_k_quant_cache(
                kv_cache,
                gathered_k,
                gathered_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
            )
            prefill_logits = rocm_fp8_mqa_logits(
                q_fp8[chunk_tokens],
                (gathered_k, gathered_scale.view(torch.float32)),
                weights[chunk_tokens],
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )
            assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
            chunk_topk = topk_indices_buffer[chunk_tokens, :topk_tokens]
            torch.ops._C.top_k_per_row_prefill(
                prefill_logits,
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
                chunk_topk,
                prefill_logits.shape[0],
                prefill_logits.stride(0),
                prefill_logits.stride(1),
                topk_tokens,
            )

    if run_decode:
        decode_meta = layer_metadata.decode
        # The paged kernel expects [num_block, block_size, n_head, head_dim]
        # but the cache is [num_block, block_size, head_dim]; add a head axis.
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_meta.decode_lens
        needs_padding = decode_meta.requires_padding
        decode_q = q_fp8[:num_decode_tokens]
        if needs_padding:
            # Pad in the edge case where a short chunked prefill is shorter
            # than decode_threshold: prefill and decode are split loosely by
            # decode_threshold (currently 1 + speculative tokens).
            padded_q = pack_seq_triton(decode_q, decode_lens)
        else:
            padded_q = decode_q.reshape(decode_lens.shape[0], -1, *q_fp8.shape[1:])

        # TODO: move and optimize the logic below with triton kernels
        batch_size = padded_q.shape[0]
        next_n = padded_q.shape[1]
        assert batch_size == decode_meta.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n

        decode_logits = rocm_fp8_paged_mqa_logits(
            padded_q,
            kv_cache,
            weights[:num_padded_tokens],
            decode_meta.seq_lens,
            decode_meta.block_table,
            decode_meta.schedule_metadata,
            max_model_len=max_model_len,
        )
        assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
        # NOTE(review): the logits have batch_size * next_n rows while this
        # view has num_decode_tokens rows; presumably these only differ when
        # padding is required - confirm the kernel cannot write past the view.
        decode_topk = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
        torch.ops._C.top_k_per_row_decode(
            decode_logits,
            next_n,
            decode_meta.seq_lens,
            decode_topk,
            decode_logits.shape[0],
            decode_logits.stride(0),
            decode_logits.stride(1),
            topk_tokens,
        )
        if needs_padding:
            # Drop the rows that belong to padding tokens and write the
            # compacted indices back into the buffer.
            decode_topk = unpack_seq_triton(
                decode_topk.reshape(batch_size, -1, decode_topk.shape[-1]),
                decode_lens,
            )
            topk_indices_buffer[:num_decode_tokens, : decode_topk.shape[-1]] = (
                decode_topk
            )

    return topk_indices_buffer
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment