Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d9417096
Unverified
Commit
d9417096
Authored Dec 08, 2025 by Wentao Ye
Committed by GitHub on Dec 08, 2025
Browse files
[Feature] Batch invariant: Enable `TRITON_MLA` without prefix-caching (#29125)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Parent: 9d6235ca
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing 5 changed files with 43 additions and 7 deletions
+43
-7
tests/v1/determinism/test_batch_invariance.py
tests/v1/determinism/test_batch_invariance.py
+1
-5
tests/v1/determinism/test_online_batch_invariance.py
tests/v1/determinism/test_online_batch_invariance.py
+4
-1
tests/v1/determinism/utils.py
tests/v1/determinism/utils.py
+1
-0
vllm/attention/layer.py
vllm/attention/layer.py
+36
-0
vllm/model_executor/layers/batch_invariant.py
vllm/model_executor/layers/batch_invariant.py
+1
-1
No files found.
tests/v1/determinism/test_batch_invariance.py
View file @
d9417096
...
@@ -185,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
...
@@ -185,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
llm
=
LLM
(
llm
=
LLM
(
model
=
model_name
,
model
=
model_name
,
tensor_parallel_size
=
tp_size
,
tensor_parallel_size
=
tp_size
,
enable_prefix_caching
=
False
,
#
enable_prefix_caching=False,
max_num_seqs
=
32
,
max_num_seqs
=
32
,
max_model_len
=
8192
,
max_model_len
=
8192
,
dtype
=
"bfloat16"
,
# not everything is supported
dtype
=
"bfloat16"
,
# not everything is supported
...
@@ -393,7 +393,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
...
@@ -393,7 +393,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
gpu_memory_utilization
=
0.9
,
gpu_memory_utilization
=
0.9
,
max_model_len
=
2048
,
max_model_len
=
2048
,
dtype
=
"bfloat16"
,
dtype
=
"bfloat16"
,
enable_prefix_caching
=
False
,
)
)
prompt
=
"the capital of france is"
prompt
=
"the capital of france is"
...
@@ -457,7 +456,6 @@ def test_logprobs_without_batch_invariance_should_fail(
...
@@ -457,7 +456,6 @@ def test_logprobs_without_batch_invariance_should_fail(
llm
=
LLM
(
llm
=
LLM
(
model
=
model_name
,
model
=
model_name
,
tensor_parallel_size
=
tp_size
,
tensor_parallel_size
=
tp_size
,
enable_prefix_caching
=
False
,
max_num_seqs
=
32
,
max_num_seqs
=
32
,
max_model_len
=
8192
,
max_model_len
=
8192
,
dtype
=
"bfloat16"
,
dtype
=
"bfloat16"
,
...
@@ -681,7 +679,6 @@ def test_decode_logprobs_match_prefill_logprobs(
...
@@ -681,7 +679,6 @@ def test_decode_logprobs_match_prefill_logprobs(
llm
=
LLM
(
llm
=
LLM
(
model
=
model_name
,
model
=
model_name
,
tensor_parallel_size
=
tp_size
,
tensor_parallel_size
=
tp_size
,
enable_prefix_caching
=
False
,
max_num_seqs
=
32
,
max_num_seqs
=
32
,
max_model_len
=
8192
,
max_model_len
=
8192
,
dtype
=
"bfloat16"
,
dtype
=
"bfloat16"
,
...
@@ -928,7 +925,6 @@ def LLM_with_max_seqs(
...
@@ -928,7 +925,6 @@ def LLM_with_max_seqs(
max_model_len
=
max_model_len
,
max_model_len
=
max_model_len
,
dtype
=
"bfloat16"
,
dtype
=
"bfloat16"
,
tensor_parallel_size
=
int
(
os
.
getenv
(
"VLLM_TP_SIZE"
,
"1"
)),
tensor_parallel_size
=
int
(
os
.
getenv
(
"VLLM_TP_SIZE"
,
"1"
)),
enable_prefix_caching
=
False
,
# Enable for MOE models
# Enable for MOE models
# enable_expert_parallel=True,
# enable_expert_parallel=True,
)
)
tests/v1/determinism/test_online_batch_invariance.py
View file @
d9417096
...
@@ -153,7 +153,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
...
@@ -153,7 +153,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
}
}
tp_size
=
os
.
getenv
(
"VLLM_TP_SIZE"
,
"1"
)
tp_size
=
os
.
getenv
(
"VLLM_TP_SIZE"
,
"1"
)
server_args
:
list
[
str
]
=
[]
server_args
:
list
[
str
]
=
[
"--max-model-len=8192"
,
"--max-num-seqs=32"
,
]
if
tp_size
:
if
tp_size
:
server_args
+=
[
"-tp"
,
tp_size
]
server_args
+=
[
"-tp"
,
tp_size
]
...
...
tests/v1/determinism/utils.py
View file @
d9417096
...
@@ -17,6 +17,7 @@ skip_unsupported = pytest.mark.skipif(
...
@@ -17,6 +17,7 @@ skip_unsupported = pytest.mark.skipif(
BACKENDS
:
list
[
str
]
=
[
BACKENDS
:
list
[
str
]
=
[
"FLASH_ATTN"
,
"FLASH_ATTN"
,
"TRITON_MLA"
,
]
]
if
has_flashinfer
():
if
has_flashinfer
():
...
...
vllm/attention/layer.py
View file @
d9417096
...
@@ -25,6 +25,7 @@ from vllm.config.vllm import VllmConfig
...
@@ -25,6 +25,7 @@ from vllm.config.vllm import VllmConfig
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.batch_invariant
import
vllm_is_batch_invariant
from
vllm.model_executor.layers.linear
import
(
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
ColumnParallelLinear
,
UnquantizedLinearMethod
,
UnquantizedLinearMethod
,
...
@@ -251,6 +252,24 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -251,6 +252,24 @@ class Attention(nn.Module, AttentionLayerBase):
else
:
else
:
self
.
attn_backend
=
attn_backend
self
.
attn_backend
=
attn_backend
# prefix caching + batch invariance is currently not supported for
# FLASHINFER and TRITON_MLA.
if
(
cache_config
is
not
None
and
cache_config
.
enable_prefix_caching
and
vllm_is_batch_invariant
()
and
(
self
.
attn_backend
.
get_name
()
==
"FLASHINFER"
or
self
.
attn_backend
.
get_name
()
==
"TRITON_MLA"
)
):
logger
.
warning_once
(
"Disabling prefix caching for FLASHINFER/TRITON_MLA "
"with batch invariance, as it is not yet supported."
,
scope
=
"local"
,
)
cache_config
.
enable_prefix_caching
=
False
impl_cls
=
self
.
attn_backend
.
get_impl_cls
()
impl_cls
=
self
.
attn_backend
.
get_impl_cls
()
self
.
impl
=
impl_cls
(
self
.
impl
=
impl_cls
(
num_heads
,
num_heads
,
...
@@ -628,6 +647,23 @@ class MLAAttention(nn.Module, AttentionLayerBase):
...
@@ -628,6 +647,23 @@ class MLAAttention(nn.Module, AttentionLayerBase):
use_mla
=
True
,
use_mla
=
True
,
use_sparse
=
use_sparse
,
use_sparse
=
use_sparse
,
)
)
if
(
cache_config
is
not
None
and
cache_config
.
enable_prefix_caching
and
vllm_is_batch_invariant
()
and
(
self
.
attn_backend
.
get_name
()
==
"TRITON_MLA"
or
self
.
attn_backend
.
get_name
()
==
"FLASHINFER"
)
):
logger
.
warning_once
(
"Disabling prefix caching for TRITON_MLA / FLASHINFER "
"with batch invariance, as it is not yet supported."
,
scope
=
"local"
,
)
cache_config
.
enable_prefix_caching
=
False
impl_cls
=
cast
(
type
[
MLAAttentionImpl
],
self
.
attn_backend
.
get_impl_cls
())
impl_cls
=
cast
(
type
[
MLAAttentionImpl
],
self
.
attn_backend
.
get_impl_cls
())
self
.
impl
=
impl_cls
(
self
.
impl
=
impl_cls
(
num_heads
=
self
.
num_heads
,
num_heads
=
self
.
num_heads
,
...
...
vllm/model_executor/layers/batch_invariant.py
View file @
d9417096
...
@@ -1006,11 +1006,11 @@ def override_envs_for_invariance():
...
@@ -1006,11 +1006,11 @@ def override_envs_for_invariance():
"FLASH_ATTN"
,
# best supported backend
"FLASH_ATTN"
,
# best supported backend
"FLASHINFER"
,
"FLASHINFER"
,
"FLASH_ATTN_MLA"
,
"FLASH_ATTN_MLA"
,
"TRITON_MLA"
,
# Not yet supported MLA backends
# Not yet supported MLA backends
# "FLASHMLA",
# "FLASHMLA",
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
# "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
# "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
# "TRITON_MLA",
]
]
if
curr_attn_backend
not
in
supported_backends
:
if
curr_attn_backend
not
in
supported_backends
:
error
=
(
error
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again or attach a new file.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment.