Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
96453cfa
Unverified
Commit
96453cfa
authored
Jul 01, 2025
by
TY-AMD
Committed by
GitHub
Jul 01, 2025
Browse files
[BugFix][V1][ROCm] Triton MLA uses V0 backend on V1 engine (#19067)
Signed-off-by:
Tianyuan Wu
<
Tianyuan.Wu@amd.com
>
parent
b1c1fe35
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
78 additions
and
10 deletions
+78
-10
tests/kernels/attention/test_attention_selector.py
tests/kernels/attention/test_attention_selector.py
+2
-4
tests/kernels/attention/test_rocm_attention_selector.py
tests/kernels/attention/test_rocm_attention_selector.py
+4
-2
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+8
-2
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+7
-2
vllm/v1/attention/backends/mla/triton_mla.py
vllm/v1/attention/backends/mla/triton_mla.py
+57
-0
No files found.
tests/kernels/attention/test_attention_selector.py
View file @
96453cfa
...
...
@@ -106,10 +106,8 @@ def test_env(
block_size
,
False
,
use_mla
=
use_mla
)
if
use_v1
and
name
!=
"TRITON_MLA"
:
assert
backend
.
get_name
()
==
f
"
{
name
}
_VLLM_V1"
else
:
assert
backend
.
get_name
()
==
name
expected
=
f
"
{
name
}
_VLLM_V1"
if
use_v1
else
name
assert
backend
.
get_name
()
==
expected
else
:
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
get_attn_backend
(
16
,
...
...
tests/kernels/attention/test_rocm_attention_selector.py
View file @
96453cfa
...
...
@@ -35,7 +35,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"TRITON_MLA"
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
False
,
True
)
assert
backend
.
get_name
()
==
"TRITON_MLA"
assert
(
backend
.
get_name
()
==
"TRITON_MLA"
or
backend
.
get_name
()
==
"TRITON_MLA_VLLM_V1"
)
# If attention backend is None
# If use_mla is true
...
...
@@ -43,7 +44,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
None
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
False
,
True
)
assert
backend
.
get_name
()
==
"TRITON_MLA"
assert
(
backend
.
get_name
()
==
"TRITON_MLA"
or
backend
.
get_name
()
==
"TRITON_MLA_VLLM_V1"
)
# change the attention backend to AITER MLA
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"ROCM_AITER_MLA"
)
...
...
vllm/platforms/rocm.py
View file @
96453cfa
...
...
@@ -186,8 +186,14 @@ class RocmPlatform(Platform):
if
selected_backend
==
_Backend
.
TRITON_MLA
:
if
block_size
!=
1
:
logger
.
info
(
"Using Triton MLA backend."
)
return
"vllm.attention.backends.triton_mla.TritonMLABackend"
# noqa: E501
if
use_v1
:
logger
.
info_once
(
"Using Triton MLA backend on V1 engine."
)
return
(
"vllm.v1.attention.backends.mla."
"triton_mla.TritonMLABackend"
)
else
:
logger
.
info
(
"Using Triton MLA backend."
)
return
"vllm.attention.backends.triton_mla.TritonMLABackend"
# noqa: E501
else
:
raise
ValueError
(
f
" The selected backend,
{
selected_backend
.
name
}
,"
...
...
vllm/v1/attention/backends/mla/common.py
View file @
96453cfa
...
...
@@ -640,7 +640,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self
.
qk_head_dim
=
qk_head_dim
self
.
v_head_dim
=
v_head_dim
self
.
kv_b_proj
=
kv_b_proj
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the
...
...
@@ -672,11 +671,17 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
maybe_padded_v
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
if
is_vllm_fa
:
kwargs
[
"return_softmax_lse"
]
=
return_softmax_lse
else
:
# ROCm leverages the upstream flash_attn, which takes a parameter
# called "return_attn_probs" instead of return_softmax_lse
kwargs
[
"return_attn_probs"
]
=
return_softmax_lse
attn_out
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
maybe_padded_v
,
return_softmax_lse
=
return_softmax_lse
,
softmax_scale
=
softmax_scale
,
**
kwargs
,
)
...
...
vllm/v1/attention/backends/mla/triton_mla.py
View file @
96453cfa
...
...
@@ -5,10 +5,14 @@ from typing import Any, Optional
import
torch
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
(
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.ops.triton_decode_attention
import
decode_attention_fwd
from
vllm.attention.ops.triton_flash_attention
import
triton_attention
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.v1.attention.backends.mla.common
import
(
MLACommonBackend
,
MLACommonImpl
,
MLACommonMetadata
)
...
...
@@ -68,6 +72,59 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
raise
NotImplementedError
(
"TritonMLA V1 with FP8 KV cache not yet supported"
)
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
self
.
triton_fa_func
=
triton_attention
if
HAS_TRITON
else
None
def _flash_attn_varlen_diff_headdims_rocm(self,
                                          q,
                                          k,
                                          v,
                                          softmax_scale=None,
                                          **kwargs):
    """Run variable-length attention through ``self.triton_fa_func``.

    ``v`` is zero-padded along its last dimension up to the size of the
    last dimension of ``q`` before being handed to the kernel.

    Args:
        q: Query tensor.
        k: Key tensor.
        v: Value tensor; padded on the right of its last dimension.
        softmax_scale: Forwarded to the kernel unchanged.
        **kwargs: Must contain ``cu_seqlens_q``, ``cu_seqlens_k``,
            ``max_seqlen_q``, ``max_seqlen_k`` and ``causal``; a missing
            key raises ``KeyError``.

    Returns:
        The first element of the pair returned by the kernel.
    """
    assert self.triton_fa_func is not None

    # Triton Attention requires a padded V
    head_dim_gap = q.shape[-1] - v.shape[-1]
    v_padded = torch.nn.functional.pad(v, [0, head_dim_gap], value=0)

    # Positional kernel arguments pulled from kwargs, in call order.
    varlen_keys = (
        "cu_seqlens_q",
        "cu_seqlens_k",
        "max_seqlen_q",
        "max_seqlen_k",
        "causal",
    )
    varlen_args = [kwargs[key] for key in varlen_keys]

    # The output of triton_attention is a tuple of
    # [output_tensor, encoded_softmax] where encoded_softmax is always None
    attn_output, _ = self.triton_fa_func(
        q,
        k,
        v_padded,
        None,  # output
        *varlen_args,
        softmax_scale,
        None,  # bias
    )
    return attn_output
def _flash_attn_varlen_diff_headdims(self,
                                     q,
                                     k,
                                     v,
                                     return_softmax_lse=False,
                                     softmax_scale=None,
                                     **kwargs):
    """Dispatch attention to the ROCm Triton helper or the parent class.

    The ROCm Triton helper is chosen only when all three hold: the
    current platform reports ROCm, ``self.use_triton_flash_attn`` is
    truthy, and the caller did not request the softmax LSE. In every
    other case the call is delegated unchanged to the parent
    implementation.

    Args:
        q: Query tensor.
        k: Key tensor.
        v: Value tensor.
        return_softmax_lse: Whether the caller wants the softmax LSE;
            when truthy the parent implementation is always used.
        softmax_scale: Forwarded to whichever implementation runs.
        **kwargs: Forwarded to whichever implementation runs.

    Returns:
        Whatever the selected implementation returns.
    """
    take_triton_path = (current_platform.is_rocm()
                        and self.use_triton_flash_attn
                        and not return_softmax_lse)

    if not take_triton_path:
        return super()._flash_attn_varlen_diff_headdims(
            q,
            k,
            v,
            return_softmax_lse=return_softmax_lse,
            softmax_scale=softmax_scale,
            **kwargs)

    return self._flash_attn_varlen_diff_headdims_rocm(
        q, k, v, softmax_scale=softmax_scale, **kwargs)
def
_forward_decode
(
self
,
q_nope
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment