Unverified Commit 8f5d5120 authored by Frank Wang's avatar Frank Wang Committed by GitHub
Browse files

Disable Cascade Attention for Batch Invariance (#32561)


Signed-off-by: default avatarfrankwang28 <frank.wbb@hotmail.com>
Signed-off-by: default avatarFrank Wang <41319051+frankwang28@users.noreply.github.com>
Co-authored-by: default avatarWentao Ye <44945378+yewentao256@users.noreply.github.com>
parent ae5b7aff
...@@ -188,7 +188,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( ...@@ -188,7 +188,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
llm = LLM( llm = LLM(
model=model_name, model=model_name,
tensor_parallel_size=tp_size, tensor_parallel_size=tp_size,
max_num_seqs=32, max_num_seqs=128,
max_model_len=8192, max_model_len=8192,
dtype="bfloat16", # not everything is supported dtype="bfloat16", # not everything is supported
gpu_memory_utilization=0.9, gpu_memory_utilization=0.9,
...@@ -197,12 +197,20 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( ...@@ -197,12 +197,20 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
) )
# Use more realistic prompts for better token generation # Use more realistic prompts for better token generation
prompts = [_random_prompt(10, 50) for i in range(32)] prompts = [_random_prompt(10, 50) for _ in range(32)]
# TODO: Update prompts to have ragged lengths in order to test chunked prefill
# The above tests are not currently long enough to exercise chunking.
# prompts = (
# [_random_prompt(10, 50) for _ in range(28)]
# + [_random_prompt(256, 512) for _ in range(50)]
# + [_random_prompt(2048, 4096) for _ in range(50)]
# )
sp = SamplingParams( sp = SamplingParams(
temperature=0.6, temperature=0.6,
top_p=1.0, top_p=1.0,
max_tokens=8, max_tokens=16,
seed=1234, seed=1234,
logprobs=5, logprobs=5,
) )
......
...@@ -7,7 +7,6 @@ import pytest ...@@ -7,7 +7,6 @@ import pytest
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
skip_unsupported = pytest.mark.skipif( skip_unsupported = pytest.mark.skipif(
...@@ -22,8 +21,10 @@ BACKENDS: list[str] = [ ...@@ -22,8 +21,10 @@ BACKENDS: list[str] = [
"TRITON_MLA", "TRITON_MLA",
] ]
if has_flashinfer(): # FlashInfer temporarily disabled due to invariant CTA sizes.
BACKENDS.append("FLASHINFER") # See FlashInfer issue #2424
# if has_flashinfer():
# BACKENDS.append("FLASHINFER")
if flash_attn_supports_mla(): if flash_attn_supports_mla():
BACKENDS.append("FLASH_ATTN_MLA") BACKENDS.append("FLASH_ATTN_MLA")
...@@ -78,9 +79,10 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: ...@@ -78,9 +79,10 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
# For longer prompts, repeat context # For longer prompts, repeat context
padding_text = ( padding_text = (
" This is an interesting topic that deserves more explanation. " " This is an interesting topic that deserves more explanation. "
# TODO: Update to * (target_words // 10) to better align with word ratio
* (target_words // 50) * (target_words // 50)
) )
base_prompt = base_prompt + padding_text base_prompt = padding_text + base_prompt
return base_prompt return base_prompt
......
...@@ -959,6 +959,18 @@ class VllmConfig: ...@@ -959,6 +959,18 @@ class VllmConfig:
"when cudagraph_mode piecewise cudagraphs is used, " "when cudagraph_mode piecewise cudagraphs is used, "
f"cudagraph_mode={self.compilation_config.cudagraph_mode}" f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
) )
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
if (
self.model_config
and vllm_is_batch_invariant()
and not self.model_config.disable_cascade_attn
):
self.model_config.disable_cascade_attn = True
logger.warning_once(
"Disabling cascade attention when VLLM_BATCH_INVARIANT is enabled.",
scope="local",
)
if self.parallel_config.use_ubatching: if self.parallel_config.use_ubatching:
a2a_backend = self.parallel_config.all2all_backend a2a_backend = self.parallel_config.all2all_backend
......
...@@ -1005,7 +1005,9 @@ def override_envs_for_invariance( ...@@ -1005,7 +1005,9 @@ def override_envs_for_invariance(
): ):
supported_backends = [ supported_backends = [
AttentionBackendEnum.FLASH_ATTN, # best supported backend AttentionBackendEnum.FLASH_ATTN, # best supported backend
AttentionBackendEnum.FLASHINFER, # FlashInfer temporarily disabled due to invariant CTA sizes.
# See FlashInfer issue #2424
# AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.FLASH_ATTN_MLA, AttentionBackendEnum.FLASH_ATTN_MLA,
AttentionBackendEnum.TRITON_MLA, AttentionBackendEnum.TRITON_MLA,
# Not yet supported MLA backends # Not yet supported MLA backends
......
...@@ -18,11 +18,18 @@ from vllm.distributed import ( ...@@ -18,11 +18,18 @@ from vllm.distributed import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.batch_invariant import (
linear_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm from vllm.model_executor.layers.utils import (
dispatch_unquantized_gemm,
is_layer_moe_router_gate,
)
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
BasevLLMParameter, BasevLLMParameter,
BlockQuantScaleParameter, BlockQuantScaleParameter,
...@@ -236,6 +243,12 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -236,6 +243,12 @@ class UnquantizedLinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
if (
vllm_is_batch_invariant()
and current_platform.is_cuda_alike()
and is_layer_moe_router_gate(getattr(layer, "prefix", ""))
):
return linear_batch_invariant(x, layer.weight, bias)
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
......
...@@ -16,6 +16,20 @@ from vllm.utils.torch_utils import direct_register_custom_op ...@@ -16,6 +16,20 @@ from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
MOE_LAYER_ROUTER_GATE_SUFFIXES = {
"gate",
"router",
"router_gate",
"shared_expert_gate",
"expert_gate",
}
def is_layer_moe_router_gate(prefix: str) -> bool:
if not prefix:
return False
return prefix.rsplit(".", 1)[-1] in MOE_LAYER_ROUTER_GATE_SUFFIXES
def shuffle_weight(w: torch.Tensor) -> torch.Tensor: def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
# Shuffle weight along the last dimension so that # Shuffle weight along the last dimension so that
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment