Unverified commit d9417096, authored by Wentao Ye, committed by GitHub
Browse files

[Feature] Batch invariant: Enable `TRITON_MLA` without prefix-caching (#29125)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 9d6235ca
...@@ -185,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( ...@@ -185,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
llm = LLM( llm = LLM(
model=model_name, model=model_name,
tensor_parallel_size=tp_size, tensor_parallel_size=tp_size,
enable_prefix_caching=False, # enable_prefix_caching=False,
max_num_seqs=32, max_num_seqs=32,
max_model_len=8192, max_model_len=8192,
dtype="bfloat16", # not everything is supported dtype="bfloat16", # not everything is supported
...@@ -393,7 +393,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch): ...@@ -393,7 +393,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
gpu_memory_utilization=0.9, gpu_memory_utilization=0.9,
max_model_len=2048, max_model_len=2048,
dtype="bfloat16", dtype="bfloat16",
enable_prefix_caching=False,
) )
prompt = "the capital of france is" prompt = "the capital of france is"
...@@ -457,7 +456,6 @@ def test_logprobs_without_batch_invariance_should_fail( ...@@ -457,7 +456,6 @@ def test_logprobs_without_batch_invariance_should_fail(
llm = LLM( llm = LLM(
model=model_name, model=model_name,
tensor_parallel_size=tp_size, tensor_parallel_size=tp_size,
enable_prefix_caching=False,
max_num_seqs=32, max_num_seqs=32,
max_model_len=8192, max_model_len=8192,
dtype="bfloat16", dtype="bfloat16",
...@@ -681,7 +679,6 @@ def test_decode_logprobs_match_prefill_logprobs( ...@@ -681,7 +679,6 @@ def test_decode_logprobs_match_prefill_logprobs(
llm = LLM( llm = LLM(
model=model_name, model=model_name,
tensor_parallel_size=tp_size, tensor_parallel_size=tp_size,
enable_prefix_caching=False,
max_num_seqs=32, max_num_seqs=32,
max_model_len=8192, max_model_len=8192,
dtype="bfloat16", dtype="bfloat16",
...@@ -928,7 +925,6 @@ def LLM_with_max_seqs( ...@@ -928,7 +925,6 @@ def LLM_with_max_seqs(
max_model_len=max_model_len, max_model_len=max_model_len,
dtype="bfloat16", dtype="bfloat16",
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
enable_prefix_caching=False,
# Enable for MOE models # Enable for MOE models
# enable_expert_parallel=True, # enable_expert_parallel=True,
) )
...@@ -153,7 +153,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( ...@@ -153,7 +153,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
} }
tp_size = os.getenv("VLLM_TP_SIZE", "1") tp_size = os.getenv("VLLM_TP_SIZE", "1")
server_args: list[str] = [] server_args: list[str] = [
"--max-model-len=8192",
"--max-num-seqs=32",
]
if tp_size: if tp_size:
server_args += ["-tp", tp_size] server_args += ["-tp", tp_size]
......
...@@ -17,6 +17,7 @@ skip_unsupported = pytest.mark.skipif( ...@@ -17,6 +17,7 @@ skip_unsupported = pytest.mark.skipif(
BACKENDS: list[str] = [ BACKENDS: list[str] = [
"FLASH_ATTN", "FLASH_ATTN",
"TRITON_MLA",
] ]
if has_flashinfer(): if has_flashinfer():
......
...@@ -25,6 +25,7 @@ from vllm.config.vllm import VllmConfig ...@@ -25,6 +25,7 @@ from vllm.config.vllm import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
UnquantizedLinearMethod, UnquantizedLinearMethod,
...@@ -251,6 +252,24 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -251,6 +252,24 @@ class Attention(nn.Module, AttentionLayerBase):
else: else:
self.attn_backend = attn_backend self.attn_backend = attn_backend
# prefix caching + batch invariance is currently not supported for
# FLASHINFER and TRITON_MLA.
if (
cache_config is not None
and cache_config.enable_prefix_caching
and vllm_is_batch_invariant()
and (
self.attn_backend.get_name() == "FLASHINFER"
or self.attn_backend.get_name() == "TRITON_MLA"
)
):
logger.warning_once(
"Disabling prefix caching for FLASHINFER/TRITON_MLA "
"with batch invariance, as it is not yet supported.",
scope="local",
)
cache_config.enable_prefix_caching = False
impl_cls = self.attn_backend.get_impl_cls() impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls( self.impl = impl_cls(
num_heads, num_heads,
...@@ -628,6 +647,23 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -628,6 +647,23 @@ class MLAAttention(nn.Module, AttentionLayerBase):
use_mla=True, use_mla=True,
use_sparse=use_sparse, use_sparse=use_sparse,
) )
if (
cache_config is not None
and cache_config.enable_prefix_caching
and vllm_is_batch_invariant()
and (
self.attn_backend.get_name() == "TRITON_MLA"
or self.attn_backend.get_name() == "FLASHINFER"
)
):
logger.warning_once(
"Disabling prefix caching for TRITON_MLA / FLASHINFER "
"with batch invariance, as it is not yet supported.",
scope="local",
)
cache_config.enable_prefix_caching = False
impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls()) impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
self.impl = impl_cls( self.impl = impl_cls(
num_heads=self.num_heads, num_heads=self.num_heads,
......
...@@ -1006,11 +1006,11 @@ def override_envs_for_invariance(): ...@@ -1006,11 +1006,11 @@ def override_envs_for_invariance():
"FLASH_ATTN", # best supported backend "FLASH_ATTN", # best supported backend
"FLASHINFER", "FLASHINFER",
"FLASH_ATTN_MLA", "FLASH_ATTN_MLA",
"TRITON_MLA",
# Not yet supported MLA backends # Not yet supported MLA backends
# "FLASHMLA", # "FLASHMLA",
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance # "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
# "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967 # "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
# "TRITON_MLA",
] ]
if curr_attn_backend not in supported_backends: if curr_attn_backend not in supported_backends:
error = ( error = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment