Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a3416fe1
Commit
a3416fe1
authored
Oct 03, 2025
by
zhuwenwen
Browse files
add flashmla support
parent
b79e20fe
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
78 additions
and
39 deletions
+78
-39
vllm/attention/ops/flashmla.py
vllm/attention/ops/flashmla.py
+23
-7
vllm/envs.py
vllm/envs.py
+6
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+49
-32
No files found.
vllm/attention/ops/flashmla.py
View file @
a3416fe1
...
...
@@ -27,14 +27,18 @@ if current_platform.is_cuda():
_flashmla_extension_C_AVAILABLE
=
False
else
:
_flashmla_extension_C_AVAILABLE
=
False
if
current_platform
.
is_rocm
():
import
flash_mla_cuda
_flashmla_C_AVAILABLE
=
True
def
is_flashmla_supported
()
->
Tuple
[
bool
,
Optional
[
str
]]:
"""
Return: is_supported_flag, unsupported_reason (optional).
"""
if
not
current_platform
.
is_cuda
():
return
False
,
"FlashMLA is
only
supported on CUDA devices."
if
not
(
current_platform
.
is_cuda
()
or
current_platform
.
is_rocm
())
:
return
False
,
"FlashMLA is supported on CUDA
and ROCM
devices."
if
current_platform
.
get_device_capability
()[
0
]
!=
9
:
return
False
,
"FlashMLA is only supported on Hopper devices."
if
not
_flashmla_C_AVAILABLE
:
...
...
@@ -71,11 +75,18 @@ def get_mla_metadata(
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
- num_splits: (batch_size + 1), dtype torch.int32.
"""
return
torch
.
ops
.
_flashmla_C
.
get_mla_decoding_metadata
(
if
current_platform
.
is_rocm
():
return
flash_mla_cuda
.
get_mla_metadata
(
cache_seqlens
,
num_q_tokens_per_head_k
,
num_heads_k
)
else
:
return
torch
.
ops
.
_flashmla_C
.
get_mla_decoding_metadata
(
cache_seqlens
,
num_q_tokens_per_head_k
,
num_heads_k
,
num_heads_q
,
is_fp8_kvcache
,
topk
)
def
flash_mla_with_kvcache
(
q
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
...
...
@@ -141,10 +152,15 @@ def flash_mla_with_kvcache(
q
,
k_cache
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
causal
,
tile_scheduler_metadata
,
num_splits
,
descale_q
,
descale_k
)
else
:
out
,
softmax_lse
=
torch
.
ops
.
_flashmla_C
.
fwd_kvcache_mla
(
q
,
k_cache
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
causal
,
tile_scheduler_metadata
,
num_splits
,
is_fp8_kvcache
,
indices
)
if
current_platform
.
is_rocm
():
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla
(
q
,
k_cache
,
block_table
,
cache_seqlens
,
head_dim_v
,
tile_scheduler_metadata
,
num_splits
,
softmax_scale
,
causal
)
else
:
out
,
softmax_lse
=
torch
.
ops
.
_flashmla_C
.
fwd_kvcache_mla
(
q
,
k_cache
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
causal
,
tile_scheduler_metadata
,
num_splits
,
is_fp8_kvcache
,
indices
)
return
out
,
softmax_lse
...
...
vllm/envs.py
View file @
a3416fe1
...
...
@@ -204,6 +204,7 @@ if TYPE_CHECKING:
VLLM_USE_NCCL_SYMM_MEM
:
bool
=
False
VLLM_NCCL_INCLUDE_PATH
:
Optional
[
str
]
=
None
VLLM_USE_FBGEMM
:
bool
=
False
VLLM_USE_FLASH_MLA
:
bool
=
False
def
get_default_cache_root
():
...
...
@@ -1469,6 +1470,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
os
.
environ
.
get
(
"VLLM_NCCL_INCLUDE_PATH"
,
None
),
# Flag to enable FBGemm kernels on model execution
"VLLM_USE_FBGEMM"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FBGEMM"
,
"0"
))),
# If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASH_MLA"
,
"0"
))),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/platforms/rocm.py
View file @
a3416fe1
...
...
@@ -136,33 +136,34 @@ def use_rocm_custom_paged_attention(
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
sinks
:
Optional
[
torch
.
Tensor
]
=
None
)
->
bool
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
ON_GFX9
=
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx90a"
,
"gfx942"
,
"gfx950"
])
ON_GFX11_GFX12
=
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx11"
,
"gfx12"
])
#
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
#
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
#
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
# custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy.
if
ON_GFX9
:
return
((
not
envs
.
VLLM_USE_V1
or
sliding_window
==
0
or
sliding_window
==
(
-
1
,
-
1
))
and
(
qtype
==
torch
.
half
or
qtype
==
torch
.
bfloat16
)
and
(
head_size
==
64
or
head_size
==
128
)
and
(
block_size
==
16
or
block_size
==
32
)
and
(
gqa_ratio
>=
1
and
gqa_ratio
<=
16
)
and
max_seq_len
<=
128
*
1024
and
(
envs
.
VLLM_ROCM_CUSTOM_PAGED_ATTN
)
and
not
(
envs
.
VLLM_ROCM_USE_AITER_PAGED_ATTN
and
envs
.
VLLM_ROCM_USE_AITER
)
and
sinks
is
None
)
else
:
return
(
ON_GFX11_GFX12
and
(
not
envs
.
VLLM_USE_V1
or
sliding_window
==
0
or
sliding_window
==
(
-
1
,
-
1
))
and
(
qtype
==
torch
.
half
or
qtype
==
torch
.
bfloat16
)
and
head_size
==
128
and
block_size
==
16
and
(
gqa_ratio
>=
3
and
gqa_ratio
<=
16
)
and
max_seq_len
<=
128
*
1024
and
alibi_slopes
is
None
and
kv_cache_dtype
==
"auto"
and
envs
.
VLLM_ROCM_CUSTOM_PAGED_ATTN
and
sinks
is
None
)
# if ON_GFX9:
# return ((not envs.VLLM_USE_V1 or sliding_window == 0
# or sliding_window == (-1, -1))
# and (qtype == torch.half or qtype == torch.bfloat16)
# and (head_size == 64 or head_size == 128)
# and (block_size == 16 or block_size == 32)
# and (gqa_ratio >= 1 and gqa_ratio <= 16)
# and max_seq_len <= 128 * 1024
# and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
# and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
# and envs.VLLM_ROCM_USE_AITER) and sinks is None)
# else:
# return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
# or sliding_window == (-1, -1))
# and (qtype == torch.half or qtype == torch.bfloat16)
# and head_size == 128 and block_size == 16
# and (gqa_ratio >= 3 and gqa_ratio <= 16)
# and max_seq_len <= 128 * 1024 and alibi_slopes is None
# and kv_cache_dtype == "auto"
# and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None)
return
False
class
RocmPlatform
(
Platform
):
...
...
@@ -222,14 +223,15 @@ class RocmPlatform(Platform):
raise
ValueError
(
f
" The selected backend,
{
selected_backend
.
name
}
,"
f
"does not support block size
{
block_size
}
."
)
if
selected_backend
==
_Backend
.
ROCM_AITER_MLA
:
if
block_size
==
1
:
logger
.
info
(
"Using AITER MLA backend on V1 engine."
)
return
"vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
# noqa: E501
raise
ValueError
(
f
" The selected backend,
{
selected_backend
.
name
}
,"
f
"does not support block size
{
block_size
}
."
"(currently only supports block size 1)"
)
# if selected_backend == _Backend.ROCM_AITER_MLA:
# if block_size == 1:
# logger.info("Using AITER MLA backend on V1 engine.")
# return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
# raise ValueError(
# f" The selected backend, {selected_backend.name},"
# f"does not support block size {block_size}."
# "(currently only supports block size 1)")
raise
ValueError
(
f
" The selected backend,
{
selected_backend
.
name
}
,"
f
"is not MLA type while requested for MLA backend."
)
...
...
@@ -249,6 +251,21 @@ class RocmPlatform(Platform):
logger
.
info
(
"Using Rocm/Aiter Attention backend on V1 engine."
)
return
(
"vllm.v1.attention.backends."
"rocm_attn.RocmAttentionBackend"
)
if
envs
.
VLLM_USE_FLASH_MLA
:
from
vllm.attention.ops.flashmla
import
is_flashmla_supported
use_flashmla
=
selected_backend
==
_Backend
.
FLASHMLA
or
(
selected_backend
is
None
and
is_flashmla_supported
()[
0
])
if
use_flashmla
:
if
block_size
!=
64
:
logger
.
warning
(
"FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64)."
,
block_size
)
else
:
logger
.
info_once
(
"Using FlashMLA backend on V1 engine."
)
return
(
"vllm.v1.attention.backends.mla."
"flashmla.FlashMLABackend"
)
else
:
# default case, using triton unified attention
logger
.
info
(
"Using Triton Attention backend on V1 engine."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment