Unverified Commit d9417096 authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Feature] Batch invariant: Enable `TRITON_MLA` without prefix-caching (#29125)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent 9d6235ca
......@@ -185,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
enable_prefix_caching=False,
# enable_prefix_caching=False,
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16", # not everything is supported
......@@ -393,7 +393,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
gpu_memory_utilization=0.9,
max_model_len=2048,
dtype="bfloat16",
enable_prefix_caching=False,
)
prompt = "the capital of france is"
......@@ -457,7 +456,6 @@ def test_logprobs_without_batch_invariance_should_fail(
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
enable_prefix_caching=False,
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16",
......@@ -681,7 +679,6 @@ def test_decode_logprobs_match_prefill_logprobs(
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
enable_prefix_caching=False,
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16",
......@@ -928,7 +925,6 @@ def LLM_with_max_seqs(
max_model_len=max_model_len,
dtype="bfloat16",
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
enable_prefix_caching=False,
# Enable for MOE models
# enable_expert_parallel=True,
)
......@@ -153,7 +153,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
}
tp_size = os.getenv("VLLM_TP_SIZE", "1")
server_args: list[str] = []
server_args: list[str] = [
"--max-model-len=8192",
"--max-num-seqs=32",
]
if tp_size:
server_args += ["-tp", tp_size]
......
......@@ -17,6 +17,7 @@ skip_unsupported = pytest.mark.skipif(
BACKENDS: list[str] = [
"FLASH_ATTN",
"TRITON_MLA",
]
if has_flashinfer():
......
......@@ -25,6 +25,7 @@ from vllm.config.vllm import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
UnquantizedLinearMethod,
......@@ -251,6 +252,24 @@ class Attention(nn.Module, AttentionLayerBase):
else:
self.attn_backend = attn_backend
# prefix caching + batch invariance is currently not supported for
# FLASHINFER and TRITON_MLA.
if (
cache_config is not None
and cache_config.enable_prefix_caching
and vllm_is_batch_invariant()
and (
self.attn_backend.get_name() == "FLASHINFER"
or self.attn_backend.get_name() == "TRITON_MLA"
)
):
logger.warning_once(
"Disabling prefix caching for FLASHINFER/TRITON_MLA "
"with batch invariance, as it is not yet supported.",
scope="local",
)
cache_config.enable_prefix_caching = False
impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(
num_heads,
......@@ -628,6 +647,23 @@ class MLAAttention(nn.Module, AttentionLayerBase):
use_mla=True,
use_sparse=use_sparse,
)
if (
cache_config is not None
and cache_config.enable_prefix_caching
and vllm_is_batch_invariant()
and (
self.attn_backend.get_name() == "TRITON_MLA"
or self.attn_backend.get_name() == "FLASHINFER"
)
):
logger.warning_once(
"Disabling prefix caching for TRITON_MLA / FLASHINFER "
"with batch invariance, as it is not yet supported.",
scope="local",
)
cache_config.enable_prefix_caching = False
impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
self.impl = impl_cls(
num_heads=self.num_heads,
......
......@@ -1006,11 +1006,11 @@ def override_envs_for_invariance():
"FLASH_ATTN", # best supported backend
"FLASHINFER",
"FLASH_ATTN_MLA",
"TRITON_MLA",
# Not yet supported MLA backends
# "FLASHMLA",
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
# "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
# "TRITON_MLA",
]
if curr_attn_backend not in supported_backends:
error = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment