Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
06c20c99
Unverified
Commit
06c20c99
authored
Nov 20, 2025
by
Pleaplusone
Committed by
GitHub
Nov 20, 2025
Browse files
[ROCm] Add AMD GPU support on Deepseek v3.2 and SparseMLA (#26670)
Signed-off-by:
ganyi
<
ygan@amd.com
>
parent
6eb745d9
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
583 additions
and
15 deletions
+583
-15
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+4
-0
vllm/attention/ops/rocm_aiter_mla_sparse.py
vllm/attention/ops/rocm_aiter_mla_sparse.py
+210
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+18
-4
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+12
-1
vllm/utils/deep_gemm.py
vllm/utils/deep_gemm.py
+3
-2
vllm/v1/attention/backends/mla/flashmla_sparse.py
vllm/v1/attention/backends/mla/flashmla_sparse.py
+1
-1
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+9
-6
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
+325
-0
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+1
-1
No files found.
csrc/cache_kernels.cu
View file @
06c20c99
...
...
@@ -552,7 +552,11 @@ __global__ void indexer_k_quant_and_cache_kernel(
#ifndef USE_ROCM
__syncwarp
();
#endif
#if defined(__gfx942__)
float
scale
=
fmaxf
(
amax
,
1e-4
)
/
224.0
f
;
#else
float
scale
=
fmaxf
(
amax
,
1e-4
)
/
448.0
f
;
#endif
if
(
use_ue8m0
)
{
scale
=
exp2f
(
ceilf
(
log2f
(
scale
)));
}
...
...
vllm/attention/ops/rocm_aiter_mla_sparse.py
0 → 100644
View file @
06c20c99
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
importlib
from
functools
import
lru_cache
import
torch
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
# Take from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84
def
fp8_mqa_logits_torch
(
q
:
torch
.
Tensor
,
kv
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
weights
:
torch
.
Tensor
,
cu_seqlen_ks
:
torch
.
Tensor
,
cu_seqlen_ke
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Compute FP8 MQA logits for a single sequence without KV paging.
Args:
q: Query tensor of shape [M, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
[N, 1]) with dtype `torch.float32`.
weights: weights of shape [M, H], dtype `torch.float32`.
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
shape [M], dtype int32.
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
shape [M], dtype int32.
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
kv
,
scale
=
kv
seq_len_kv
=
kv
.
shape
[
0
]
k
=
kv
.
to
(
torch
.
bfloat16
)
q
=
q
.
to
(
torch
.
bfloat16
)
mask_lo
=
(
torch
.
arange
(
0
,
seq_len_kv
,
device
=
"cuda"
)[
None
,
:]
>=
cu_seqlen_ks
[:,
None
]
)
mask_hi
=
(
torch
.
arange
(
0
,
seq_len_kv
,
device
=
"cuda"
)[
None
,
:]
<
cu_seqlen_ke
[:,
None
]
)
mask
=
mask_lo
&
mask_hi
score
=
torch
.
einsum
(
"mhd,nd->hmn"
,
q
,
k
).
float
()
*
scale
logits
=
(
score
.
relu
()
*
weights
.
unsqueeze
(
-
1
).
transpose
(
0
,
1
)).
sum
(
dim
=
0
)
logits
=
logits
.
masked_fill
(
~
mask
,
float
(
"-inf"
))
return
logits
def
rocm_fp8_mqa_logits
(
q
:
torch
.
Tensor
,
kv
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
weights
:
torch
.
Tensor
,
cu_seqlen_ks
:
torch
.
Tensor
,
cu_seqlen_ke
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Compute FP8 MQA logits for a single sequence without KV paging.
Args:
q: Query tensor of shape [M, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
[N, 1]) with dtype `torch.float32`.
weights: weights of shape [M, H], dtype `torch.float32`.
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
shape [M], dtype int32.
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
shape [M], dtype int32.
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
# TODO(ganyi): Temporarily workaround, will remove the module check and reference
# path after aiter merge this kernel into main
@
lru_cache
def
has_mqa_logits_module
():
return
importlib
.
util
.
find_spec
(
"aiter.ops.triton.fp8_mqa_logits"
)
is
not
None
if
rocm_aiter_ops
.
is_enabled
()
and
has_mqa_logits_module
():
from
aiter.ops.triton.fp8_mqa_logits
import
fp8_mqa_logits
kv
,
scale
=
kv
return
fp8_mqa_logits
(
q
,
kv
,
scale
,
weights
,
cu_seqlen_ks
,
cu_seqlen_ke
)
else
:
return
fp8_mqa_logits_torch
(
q
,
kv
,
weights
,
cu_seqlen_ks
,
cu_seqlen_ke
)
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156
def
fp8_paged_mqa_logits_torch
(
q
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
max_model_len
:
int
,
):
from
vllm.utils.math_utils
import
cdiv
fp8_dtype
=
current_platform
.
fp8_dtype
()
batch_size
,
next_n
,
_
,
dim
=
q
.
size
()
kv_cache
,
scale
=
kv_cache
[...,
:
dim
],
kv_cache
[...,
dim
:]
scale
=
scale
.
contiguous
().
view
(
torch
.
float
)
q
=
q
.
float
()
kv_cache
=
kv_cache
.
view
(
fp8_dtype
).
float
()
*
scale
num_block
,
block_size
,
_
,
dim
=
kv_cache
.
size
()
logits
=
torch
.
full
(
[
batch_size
*
next_n
,
max_model_len
],
float
(
"-inf"
),
device
=
q
.
device
,
dtype
=
torch
.
float32
,
)
context_lens
=
context_lens
.
tolist
()
for
i
in
range
(
batch_size
):
context_len
=
context_lens
[
i
]
q_offsets
=
torch
.
arange
(
context_len
-
next_n
,
context_len
,
device
=
"cuda"
)
weight_slice
=
(
weights
[
i
*
next_n
:
(
i
+
1
)
*
next_n
,
:].
transpose
(
0
,
1
).
contiguous
()
)
for
block_rk
in
range
(
cdiv
(
context_len
,
block_size
)):
block_idx
=
block_tables
[
i
][
block_rk
]
qx
,
kx
=
q
[
i
],
kv_cache
[
block_idx
]
k_offsets
=
torch
.
arange
(
block_rk
*
block_size
,
(
block_rk
+
1
)
*
block_size
,
device
=
"cuda"
)
mask
=
(
k_offsets
[
None
,
:]
<
context_len
)
&
(
k_offsets
[
None
,
:]
<=
q_offsets
[:,
None
]
)
s
=
torch
.
where
(
mask
[
None
,
:,
:],
(
qx
.
transpose
(
0
,
1
)
@
kx
.
transpose
(
0
,
1
).
transpose
(
1
,
2
)).
to
(
logits
.
dtype
),
float
(
"-inf"
),
)
s
=
torch
.
relu
(
s
)
*
weight_slice
[...,
None
]
s
=
s
.
sum
(
dim
=
0
)
logits
[
i
*
next_n
:
(
i
+
1
)
*
next_n
,
block_rk
*
block_size
:
(
block_rk
+
1
)
*
block_size
,
]
=
torch
.
where
(
k_offsets
[
None
,
:]
<=
q_offsets
[:,
None
],
s
,
float
(
"-inf"
))
return
logits
def
rocm_fp8_paged_mqa_logits
(
q_fp8
:
torch
.
Tensor
,
kv_cache_fp8
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
schedule_metadata
:
torch
.
Tensor
,
max_model_len
:
int
,
)
->
torch
.
Tensor
:
"""Compute FP8 MQA logits using paged KV-cache.
Args:
q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
[num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
4 bytes per (block,pos) store the `float` dequant scale.
weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
context_lens: Tensor of shape [B], dtype int32; effective context length
for each batch element.
block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
block indices to physical blocks in the paged cache.
schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
used to distribute work across SMs.
max_model_len: Maximum sequence length used to size the logits output.
Returns:
Logits tensor of shape [B * next_n, max_model_len], dtype
`torch.float32`.
"""
if
rocm_aiter_ops
.
is_enabled
():
from
aiter.ops.triton.pa_mqa_logits
import
deepgemm_fp8_paged_mqa_logits_stage1
batch_size
,
next_n
,
heads
,
_
=
q_fp8
.
shape
out_qk
=
torch
.
full
(
(
heads
,
batch_size
*
next_n
,
max_model_len
),
float
(
"-inf"
),
device
=
"cuda"
,
dtype
=
torch
.
float32
,
)
deepgemm_fp8_paged_mqa_logits_stage1
(
q_fp8
,
kv_cache_fp8
,
weights
,
out_qk
,
context_lens
,
block_tables
,
max_model_len
,
)
return
out_qk
.
sum
(
dim
=
0
)
else
:
return
fp8_paged_mqa_logits_torch
(
q_fp8
,
kv_cache_fp8
,
weights
,
context_lens
,
block_tables
,
max_model_len
)
vllm/model_executor/models/deepseek_v2.py
View file @
06c20c99
...
...
@@ -594,6 +594,7 @@ def sparse_attn_indexer(
)
->
torch
.
Tensor
:
# careful! this will be None in dummy run
attn_metadata
=
get_forward_context
().
attn_metadata
fp8_dtype
=
current_platform
.
fp8_dtype
()
# assert isinstance(attn_metadata, dict)
if
not
isinstance
(
attn_metadata
,
dict
):
return
sparse_attn_indexer_fake
(
...
...
@@ -633,7 +634,7 @@ def sparse_attn_indexer(
k_fp8
=
torch
.
empty
(
[
chunk
.
total_seq_lens
,
head_dim
],
device
=
k
.
device
,
dtype
=
torch
.
float8_e4m3fn
,
dtype
=
fp8_dtype
,
)
k_scale
=
torch
.
empty
(
[
chunk
.
total_seq_lens
,
4
],
...
...
@@ -647,7 +648,12 @@ def sparse_attn_indexer(
chunk
.
block_table
,
chunk
.
cu_seq_lens
,
)
logits
=
fp8_mqa_logits
(
fp8_mqa_logits_func
=
fp8_mqa_logits
if
current_platform
.
is_rocm
():
from
vllm.attention.ops.rocm_aiter_mla_sparse
import
rocm_fp8_mqa_logits
fp8_mqa_logits_func
=
rocm_fp8_mqa_logits
logits
=
fp8_mqa_logits_func
(
q_fp8
[
chunk
.
token_start
:
chunk
.
token_end
],
(
k_fp8
,
k_scale
.
view
(
torch
.
float32
)),
weights
[
chunk
.
token_start
:
chunk
.
token_end
],
...
...
@@ -692,7 +698,14 @@ def sparse_attn_indexer(
next_n
=
padded_q_fp8_decode_tokens
.
shape
[
1
]
assert
batch_size
==
decode_metadata
.
seq_lens
.
shape
[
0
]
num_padded_tokens
=
batch_size
*
next_n
logits
=
fp8_paged_mqa_logits
(
fp8_paged_mqa_logits_func
=
fp8_paged_mqa_logits
if
current_platform
.
is_rocm
():
from
vllm.attention.ops.rocm_aiter_mla_sparse
import
(
rocm_fp8_paged_mqa_logits
,
)
fp8_paged_mqa_logits_func
=
rocm_fp8_paged_mqa_logits
logits
=
fp8_paged_mqa_logits_func
(
padded_q_fp8_decode_tokens
,
kv_cache
,
weights
[:
num_padded_tokens
],
...
...
@@ -749,7 +762,8 @@ def sparse_attn_indexer_fake(
_flattened_kv
=
torch
.
empty
(
[
total_seq_lens
,
head_dim
+
4
],
device
=
k
.
device
,
dtype
=
torch
.
uint8
)
_k_fp8
=
_flattened_kv
[...,
:
head_dim
].
view
(
torch
.
float8_e4m3fn
).
contiguous
()
fp8_dtype
=
current_platform
.
fp8_dtype
()
_k_fp8
=
_flattened_kv
[...,
:
head_dim
].
view
(
fp8_dtype
).
contiguous
()
_k_scale
=
_flattened_kv
[...,
head_dim
:].
view
(
torch
.
float32
).
contiguous
()
return
topk_indices_buffer
...
...
vllm/platforms/rocm.py
View file @
06c20c99
...
...
@@ -225,7 +225,18 @@ class RocmPlatform(Platform):
from
vllm.attention.backends.registry
import
AttentionBackendEnum
if
use_sparse
:
raise
NotImplementedError
(
"Sparse Attention is not supported on ROCm."
)
if
kv_cache_dtype
.
startswith
(
"fp8"
):
raise
ValueError
(
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
)
assert
block_size
==
1
,
(
"Sparse MLA backend on ROCm only supports block size 1 for now."
)
logger
.
info_once
(
"Using Sparse MLA backend on V1 engine."
)
return
(
"vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse."
"ROCMAiterMLASparseBackend"
)
if
use_mla
:
if
selected_backend
is
None
:
...
...
vllm/utils/deep_gemm.py
View file @
06c20c99
...
...
@@ -325,6 +325,7 @@ DEFAULT_BLOCK_SIZE = [128, 128]
def
per_block_cast_to_fp8
(
x
:
torch
.
Tensor
,
block_size
:
list
[
int
]
=
DEFAULT_BLOCK_SIZE
,
use_ue8m0
:
bool
=
False
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
fp8_dtype
=
current_platform
.
fp8_dtype
()
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
block_m
,
block_n
=
block_size
...
...
@@ -334,9 +335,9 @@ def per_block_cast_to_fp8(
x_padded
[:
m
,
:
n
]
=
x
x_view
=
x_padded
.
view
(
-
1
,
block_m
,
x_padded
.
size
(
1
)
//
block_n
,
block_n
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
sf
=
x_amax
/
448.0
sf
=
x_amax
/
224.0
if
current_platform
.
is_fp8_fnuz
()
else
x_amax
/
448.0
sf
=
_ceil_to_ue8m0
(
sf
)
if
use_ue8m0
else
sf
x_scaled
=
(
x_view
*
(
1.0
/
sf
)).
to
(
torch
.
float8_e4m3fn
)
x_scaled
=
(
x_view
*
(
1.0
/
sf
)).
to
(
fp8_dtype
)
return
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
(),
sf
.
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
)
)
...
...
vllm/v1/attention/backends/mla/flashmla_sparse.py
View file @
06c20c99
...
...
@@ -168,7 +168,7 @@ def _convert_req_index_to_global_index_kernel(
inblock_off
=
tok
%
BLOCK_SIZE
# Guard block_table access
valid_block
=
block_id
<
max_num_blocks_per_req
valid_block
=
(
block_id
<
max_num_blocks_per_req
)
&
(
block_id
>=
0
)
bt_ptr
=
block_table_ptr
+
req
*
bt_stride0
+
block_id
*
bt_stride1
base
=
tl
.
load
(
bt_ptr
,
mask
=
valid_block
,
other
=
0
)
...
...
vllm/v1/attention/backends/mla/indexer.py
View file @
06c20c99
...
...
@@ -11,7 +11,8 @@ from vllm.attention.backends.abstract import (
)
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils.deep_gemm
import
get_paged_mqa_logits_metadata
from
vllm.platforms
import
current_platform
from
vllm.utils.deep_gemm
import
get_paged_mqa_logits_metadata
,
is_deep_gemm_supported
from
vllm.v1.attention.backends.utils
import
(
AttentionCGSupport
,
AttentionMetadataBuilder
,
...
...
@@ -23,7 +24,9 @@ logger = init_logger(__name__)
class
DeepseekV32IndexerBackend
(
AttentionBackend
):
supported_kernel_block_sizes
:
ClassVar
[
list
[
int
|
MultipleOf
]]
=
[
64
]
supported_kernel_block_sizes
:
ClassVar
[
list
[
int
|
MultipleOf
]]
=
[
1
if
current_platform
.
is_rocm
()
else
64
]
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
...
...
@@ -328,10 +331,10 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
requires_padding
=
(
decode_lens_cpu
.
max
()
>
decode_lens_cpu
.
min
()).
item
()
seq_lens
=
common_attn_metadata
.
seq_lens
[:
num_decodes
]
self
.
scheduler_metadata_buffer
[:]
=
get_paged_mqa_logits_metadata
(
seq_lens
,
self
.
kv_cache_spec
.
block_size
,
self
.
num_sms
)
if
is_deep_gemm_supported
():
self
.
scheduler_metadata_buffer
[:]
=
get_paged_mqa_logits_metadata
(
seq_lens
,
self
.
kv_cache_spec
.
block_size
,
self
.
num_sms
)
decode_metadata
=
DeepSeekV32IndexerDecodeMetadata
(
block_table
=
common_attn_metadata
.
block_table_tensor
[:
num_decodes
,
...],
seq_lens
=
common_attn_metadata
.
seq_lens
[:
num_decodes
],
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
0 → 100644
View file @
06c20c99
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Optional
import
numpy
as
np
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionLayer
,
AttentionMetadata
,
)
from
vllm.attention.backends.utils
import
get_mla_dims
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
(
MLACommonBaseImpl
,
)
from
vllm.v1.attention.backends.mla.flashmla_sparse
import
(
triton_convert_req_index_to_global_index
,
)
from
vllm.v1.attention.backends.utils
import
(
AttentionCGSupport
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
if
TYPE_CHECKING
:
from
vllm.model_executor.models.deepseek_v2
import
Indexer
logger
=
init_logger
(
__name__
)
class
ROCMAiterMLASparseBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
@
staticmethod
def
get_name
()
->
str
:
return
"ROCM_AITER_MLA_SPARSE"
@
staticmethod
def
get_metadata_cls
()
->
type
[
AttentionMetadata
]:
return
ROCMAiterMLASparseMetadata
@
staticmethod
def
get_builder_cls
()
->
type
[
"ROCMAiterMLASparseMetadataBuilder"
]:
return
ROCMAiterMLASparseMetadataBuilder
@
staticmethod
def
get_impl_cls
()
->
type
[
"ROCMAiterMLASparseImpl"
]:
return
ROCMAiterMLASparseImpl
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
# assumed to be 1 for MLA
head_size
:
int
,
cache_dtype_str
:
str
=
"auto"
,
)
->
tuple
[
int
,
...]:
return
(
num_blocks
,
block_size
,
head_size
)
@
classmethod
def
get_supported_dtypes
(
cls
)
->
list
[
torch
.
dtype
]:
return
[
torch
.
bfloat16
]
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
return
[
576
]
@
dataclass
class
ROCMAiterMLASparseMetadata
:
num_reqs
:
int
max_query_len
:
int
max_seq_len
:
int
num_actual_tokens
:
int
# Number of tokens excluding padding.
query_start_loc
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
block_table
:
torch
.
Tensor
req_id_per_token
:
torch
.
Tensor
block_size
:
int
=
1
topk_tokens
:
int
=
2048
@
dataclass
class
ROCMAiterMLASparseMetadataBuilder
(
AttentionMetadataBuilder
[
ROCMAiterMLASparseMetadata
]
):
cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
AttentionCGSupport
.
NEVER
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
):
self
.
kv_cache_spec
=
kv_cache_spec
self
.
model_config
=
vllm_config
.
model_config
parallel_config
=
vllm_config
.
parallel_config
self
.
device
=
device
self
.
num_heads
=
self
.
model_config
.
get_num_attention_heads
(
parallel_config
)
self
.
mla_dims
=
get_mla_dims
(
self
.
model_config
)
self
.
topk_tokens
=
vllm_config
.
model_config
.
hf_config
.
index_topk
self
.
topk_tokens_tensor
=
torch
.
tensor
(
[
self
.
topk_tokens
],
device
=
device
,
dtype
=
torch
.
int32
)
self
.
max_model_len_tensor
=
torch
.
tensor
(
[
self
.
model_config
.
max_model_len
],
device
=
device
,
dtype
=
torch
.
int32
)
# this is ignored by `flash_mla_with_kvcache` if indices not None
self
.
dummy_block_table
=
torch
.
empty
(
(
1
,
1
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
req_id_per_token_buffer
=
torch
.
empty
(
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
,),
dtype
=
torch
.
int32
,
device
=
device
,
)
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
ROCMAiterMLASparseMetadata
:
num_tokens
=
common_attn_metadata
.
num_actual_tokens
starts
=
np
.
asarray
(
common_attn_metadata
.
query_start_loc_cpu
,
dtype
=
np
.
int32
)
seg_lengths
=
np
.
diff
(
starts
)
req_id_per_token
=
np
.
repeat
(
np
.
arange
(
seg_lengths
.
shape
[
0
],
dtype
=
np
.
int32
),
seg_lengths
)
# Zero-fill for cudagraphs
self
.
req_id_per_token_buffer
.
fill_
(
0
)
self
.
req_id_per_token_buffer
[:
req_id_per_token
.
shape
[
0
]].
copy_
(
torch
.
from_numpy
(
req_id_per_token
),
non_blocking
=
True
)
req_id_per_token
=
self
.
req_id_per_token_buffer
[:
num_tokens
]
metadata
=
ROCMAiterMLASparseMetadata
(
num_reqs
=
common_attn_metadata
.
num_reqs
,
max_query_len
=
common_attn_metadata
.
max_query_len
,
max_seq_len
=
common_attn_metadata
.
max_seq_len
,
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
,
query_start_loc
=
common_attn_metadata
.
query_start_loc
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
block_table
=
common_attn_metadata
.
block_table_tensor
,
req_id_per_token
=
req_id_per_token
,
block_size
=
self
.
kv_cache_spec
.
block_size
,
topk_tokens
=
self
.
topk_tokens
,
)
return
metadata
# Take from
# https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla_prefill.py#L72
def
reference_mla_sparse_prefill
(
q
:
torch
.
Tensor
,
kv
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
,
sm_scale
:
float
,
d_v
:
int
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
import
math
def
log2sumexp2
(
a
:
torch
.
Tensor
,
dim
:
int
)
->
torch
.
Tensor
:
return
torch
.
logsumexp
(
a
*
math
.
log
(
2
),
dim
=
dim
)
*
math
.
log2
(
math
.
e
)
skv
=
kv
.
shape
[
0
]
sq
=
q
.
shape
[
0
]
topk
=
indices
.
shape
[
-
1
]
dqk
=
q
.
shape
[
-
1
]
indices
=
indices
[:,
0
,
:]
# [s_q, topk]
invalid_indices_mask
=
(
indices
<
0
)
|
(
indices
>=
skv
)
indices
[
invalid_indices_mask
]
=
0
qs
=
q
# [s_q, h_q, d_qk]
kvs
=
kv
[:,
0
,
:][
indices
].
view
(
sq
,
topk
,
dqk
)
# [s_q, topk, d_qk]
attn_score
=
(
qs
@
kvs
.
transpose
(
1
,
2
)).
float
()
# [s_q, h_q, topk]
attn_score
.
masked_fill_
(
invalid_indices_mask
.
unsqueeze
(
1
),
float
(
"-inf"
))
attn_score
*=
sm_scale
*
math
.
log2
(
math
.
e
)
lse
=
log2sumexp2
(
attn_score
,
dim
=-
1
)
# [s_q, h_q]
attn_score
=
torch
.
exp2
(
attn_score
-
lse
.
unsqueeze
(
-
1
))
# [s_q, h_q, topk]
result
=
attn_score
.
to
(
q
.
dtype
)
@
kvs
[:,
:,
:
d_v
]
return
(
result
,
lse
)
class
ROCMAiterMLASparseImpl
(
MLACommonBaseImpl
[
ROCMAiterMLASparseMetadata
]):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
list
[
float
]
|
None
,
sliding_window
:
int
|
None
,
kv_cache_dtype
:
str
,
logits_soft_cap
:
float
|
None
,
attn_type
:
str
,
kv_sharing_target_layer_name
:
str
|
None
,
# MLA Specific Arguments
topk_indice_buffer
:
torch
.
Tensor
|
None
=
None
,
indexer
:
Optional
[
"Indexer"
]
=
None
,
**
mla_args
,
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
logits_soft_cap
,
attn_type
,
kv_sharing_target_layer_name
,
**
mla_args
,
)
self
.
softmax_scale
=
scale
assert
indexer
is
not
None
self
.
topk_indices_buffer
=
indexer
.
topk_indices_buffer
self
.
is_fp8bmm_enabled
=
rocm_aiter_ops
.
is_fp8bmm_enabled
()
def
_forward_bf16_kv
(
self
,
q
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
topk_indices
:
torch
.
Tensor
,
attn_metadata
:
ROCMAiterMLASparseMetadata
,
)
->
torch
.
Tensor
:
num_tokens
=
q
.
shape
[
0
]
kv_c_and_k_pe_cache
=
kv_c_and_k_pe_cache
.
view
(
-
1
,
1
,
kv_c_and_k_pe_cache
.
shape
[
-
1
]
)
topk_indices
=
topk_indices
.
view
(
num_tokens
,
1
,
-
1
)
output
=
reference_mla_sparse_prefill
(
q
,
kv_c_and_k_pe_cache
,
topk_indices
,
self
.
softmax_scale
,
512
)[
0
]
return
output
[:,
:
self
.
num_heads
,
:]
def
forward
(
self
,
layer
:
AttentionLayer
,
q
:
torch
.
Tensor
,
k_c_normed
:
torch
.
Tensor
,
# key in unified attn
k_pe
:
torch
.
Tensor
,
# value in unified attn
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
ROCMAiterMLASparseMetadata
,
output
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
# MQA 576/512 approach for both prefill and decode
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_scale
is
not
None
or
output_block_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported for ROCMAiterMLASparse"
)
if
attn_metadata
is
None
:
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
return
output
.
fill_
(
0
)
num_actual_toks
=
attn_metadata
.
num_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
q
=
q
[:
num_actual_toks
,
...]
k_c_normed
=
k_c_normed
[:
num_actual_toks
,
...]
k_pe
=
k_pe
[:
num_actual_toks
,
...]
q_nope
,
q_pe
=
q
.
split
([
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
# Convert from (B, N, P) to (N, B, P)
q_nope
=
q_nope
.
transpose
(
0
,
1
)
if
self
.
is_fp8bmm_enabled
:
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
ql_nope
=
rocm_aiter_ops
.
triton_fp8_bmm
(
q_nope
,
self
.
W_K
,
self
.
W_K_scale
,
group_size
=
128
,
transpose_bm
=
True
)
else
:
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope
=
torch
.
bmm
(
q_nope
,
self
.
W_UK_T
)
# Convert from (N, B, L) to (B, N, L)
ql_nope
=
ql_nope
.
transpose
(
0
,
1
)
topk_indices
=
self
.
topk_indices_buffer
[:
num_actual_toks
]
topk_indices_global
=
triton_convert_req_index_to_global_index
(
attn_metadata
.
req_id_per_token
,
attn_metadata
.
block_table
,
topk_indices
,
BLOCK_SIZE
=
attn_metadata
.
block_size
,
NUM_TOPK_TOKENS
=
attn_metadata
.
topk_tokens
,
)
q
=
torch
.
cat
([
ql_nope
,
q_pe
],
dim
=-
1
)
# write the latent and rope to kv cache
if
kv_cache
.
numel
()
>
0
:
ops
.
concat_and_cache_mla
(
k_c_normed
,
k_pe
.
squeeze
(
1
),
kv_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache_dtype
=
self
.
kv_cache_dtype
,
scale
=
layer
.
_k_scale
,
)
attn_out
=
self
.
_forward_bf16_kv
(
q
,
kv_cache
,
topk_indices_global
,
attn_metadata
)
self
.
_v_up_proj
(
attn_out
,
out
=
output
[:
num_actual_toks
])
return
output
vllm/v1/worker/utils.py
View file @
06c20c99
...
...
@@ -316,7 +316,7 @@ def bind_kv_cache(
# TODO - analyze where runner_kv_caches is used and the right
# way to ensure it properly reflects multiple attention layers
# in the same decoder block.
if
current_platform
.
is_cuda
()
or
current_platform
.
is_xpu
():
if
current_platform
.
is_cuda
_alike
()
or
current_platform
.
is_xpu
():
# We know that the GPU runner is not impacted by this
# case. Some test code depends on runner_kv_caches, but
# not in a way that's impacted by ignoring this.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment