Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bb1d0df8
Commit
bb1d0df8
authored
Apr 10, 2025
by
zhuwenwen
Browse files
support flashmla backend
parent
23607ca0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
75 additions
and
21 deletions
+75
-21
vllm/attention/ops/flashmla.py
vllm/attention/ops/flashmla.py
+40
-18
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+35
-3
No files found.
vllm/attention/ops/flashmla.py
View file @
bb1d0df8
...
@@ -18,13 +18,16 @@ if current_platform.is_cuda():
...
@@ -18,13 +18,16 @@ if current_platform.is_cuda():
else
:
else
:
_flashmla_C_AVAILABLE
=
False
_flashmla_C_AVAILABLE
=
False
if
current_platform
.
is_rocm
():
import
flash_mla_cuda
_flashmla_C_AVAILABLE
=
True
def
is_flashmla_supported
()
->
Tuple
[
bool
,
Optional
[
str
]]:
def
is_flashmla_supported
()
->
Tuple
[
bool
,
Optional
[
str
]]:
"""
"""
Return: is_supported_flag, unsupported_reason (optional).
Return: is_supported_flag, unsupported_reason (optional).
"""
"""
if
not
current_platform
.
is_cuda
():
if
not
(
current_platform
.
is_cuda
()
or
current_platform
.
is_rocm
())
:
return
False
,
"FlashMLA is
only
supported on CUDA devices."
return
False
,
"FlashMLA is supported on CUDA
and ROCM
devices."
if
current_platform
.
get_device_capability
()[
0
]
!=
9
:
if
current_platform
.
get_device_capability
()[
0
]
!=
9
:
return
False
,
"FlashMLA is only supported on Hopper devices."
return
False
,
"FlashMLA is only supported on Hopper devices."
if
not
_flashmla_C_AVAILABLE
:
if
not
_flashmla_C_AVAILABLE
:
...
@@ -51,9 +54,14 @@ def get_mla_metadata(
...
@@ -51,9 +54,14 @@ def get_mla_metadata(
dtype torch.int32.
dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
"""
return
torch
.
ops
.
_flashmla_C
.
get_mla_metadata
(
cache_seqlens
,
if
current_platform
.
is_rocm
():
num_heads_per_head_k
,
return
flash_mla_cuda
.
get_mla_metadata
(
cache_seqlens
,
num_heads_k
)
num_heads_per_head_k
,
num_heads_k
)
else
:
return
torch
.
ops
.
_flashmla_C
.
get_mla_metadata
(
cache_seqlens
,
num_heads_per_head_k
,
num_heads_k
)
def
flash_mla_with_kvcache
(
def
flash_mla_with_kvcache
(
...
@@ -87,18 +95,32 @@ def flash_mla_with_kvcache(
...
@@ -87,18 +95,32 @@ def flash_mla_with_kvcache(
"""
"""
if
softmax_scale
is
None
:
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
out
,
softmax_lse
=
torch
.
ops
.
_flashmla_C
.
fwd_kvcache_mla
(
if
current_platform
.
is_rocm
():
q
,
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla
(
k_cache
,
q
,
None
,
k_cache
,
head_dim_v
,
None
,
cache_seqlens
,
head_dim_v
,
block_table
,
cache_seqlens
,
softmax_scale
,
block_table
,
causal
,
softmax_scale
,
tile_scheduler_metadata
,
causal
,
num_splits
,
tile_scheduler_metadata
,
)
num_splits
,
)
else
:
out
,
softmax_lse
=
torch
.
ops
.
_flashmla_C
.
fwd_kvcache_mla
(
q
,
k_cache
,
None
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
causal
,
tile_scheduler_metadata
,
num_splits
,
)
return
out
,
softmax_lse
return
out
,
softmax_lse
...
@@ -112,4 +134,4 @@ def flash_mla_with_kvcache(
...
@@ -112,4 +134,4 @@ def flash_mla_with_kvcache(
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
# return ....
# return ....
#
#
\ No newline at end of file
vllm/platforms/rocm.py
View file @
bb1d0df8
...
@@ -138,8 +138,40 @@ class RocmPlatform(Platform):
...
@@ -138,8 +138,40 @@ class RocmPlatform(Platform):
kv_cache_dtype
,
block_size
,
use_v1
,
kv_cache_dtype
,
block_size
,
use_v1
,
use_mla
)
->
str
:
use_mla
)
->
str
:
if
use_mla
:
if
use_mla
:
logger
.
info
(
"Using Triton MLA backend."
)
# logger.info("Using Triton MLA backend.")
return
"vllm.attention.backends.triton_mla.TritonMLABackend"
# return "vllm.attention.backends.triton_mla.TritonMLABackend"
if
selected_backend
==
_Backend
.
TRITON_MLA
or
block_size
!=
64
:
if
use_v1
:
logger
.
info_once
(
"Using Triton MLA backend on V1 engine."
)
return
(
"vllm.v1.attention.backends.mla."
"triton_mla.TritonMLABackend"
)
else
:
logger
.
info
(
"Using Triton MLA backend."
)
return
"vllm.attention.backends.triton_mla.TritonMLABackend"
else
:
from
vllm.attention.backends.flashmla
import
(
is_flashmla_supported
)
if
not
is_flashmla_supported
()[
0
]:
logger
.
warning
(
"FlashMLA backend is not supported due to %s"
,
is_flashmla_supported
()[
1
])
elif
block_size
!=
64
:
logger
.
warning
(
"FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64)."
,
block_size
)
else
:
if
use_v1
:
logger
.
info_once
(
"Using FlashMLA backend on V1 engine."
)
return
(
"vllm.v1.attention.backends.mla."
"flashmla.FlashMLABackend"
)
else
:
logger
.
info
(
"Using FlashMLA backend."
)
return
(
"vllm.attention.backends."
"flashmla.FlashMLABackend"
)
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
if
envs
.
VLLM_USE_V1
:
if
envs
.
VLLM_USE_V1
:
...
@@ -311,4 +343,4 @@ class RocmPlatform(Platform):
...
@@ -311,4 +343,4 @@ class RocmPlatform(Platform):
# We only enable custom allreduce for MI300 series
# We only enable custom allreduce for MI300 series
gcn_arch
=
torch
.
cuda
.
get_device_properties
(
0
).
gcnArchName
gcn_arch
=
torch
.
cuda
.
get_device_properties
(
0
).
gcnArchName
supported_archs
=
[
'gfx94'
]
supported_archs
=
[
'gfx94'
]
return
any
(
gfx
in
gcn_arch
for
gfx
in
supported_archs
)
return
any
(
gfx
in
gcn_arch
for
gfx
in
supported_archs
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment