Unverified commit eaf49786, authored by Wentao Ye and committed by GitHub
Browse files

[Test] Only Run MLA model when user explicitly set for batch invariance (#37719)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 77d24c4b
...@@ -8,10 +8,10 @@ import pytest ...@@ -8,10 +8,10 @@ import pytest
import torch import torch
from utils import ( from utils import (
BACKENDS, BACKENDS,
TEST_MODEL,
_extract_step_logprobs, _extract_step_logprobs,
_random_prompt, _random_prompt,
is_device_capability_below_90, is_device_capability_below_90,
resolve_model_name,
skip_unsupported, skip_unsupported,
) )
...@@ -57,7 +57,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( ...@@ -57,7 +57,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
attention_config = {"backend": backend} attention_config = {"backend": backend}
# Allow overrides from environment (useful for CI tuning) # Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism # "facebook/opt-125m" is too small, doesn't reliably test determinism
model = resolve_model_name(backend) model = TEST_MODEL
num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5")) num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128")) max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024")) min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
...@@ -169,7 +169,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( ...@@ -169,7 +169,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
): ):
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
model_name = resolve_model_name(backend)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
# For batch invariance, disable custom all-reduce to ensure deterministic # For batch invariance, disable custom all-reduce to ensure deterministic
...@@ -186,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( ...@@ -186,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
print(f"{'=' * 80}\n") print(f"{'=' * 80}\n")
llm = LLM( llm = LLM(
model=model_name, model=TEST_MODEL,
tensor_parallel_size=tp_size, tensor_parallel_size=tp_size,
max_num_seqs=128, max_num_seqs=128,
max_model_len=8192, max_model_len=8192,
...@@ -395,7 +394,7 @@ def test_simple_generation(backend): ...@@ -395,7 +394,7 @@ def test_simple_generation(backend):
Simple test that runs the model with a basic prompt and prints the output. Simple test that runs the model with a basic prompt and prints the output.
Useful for quick smoke testing and debugging. Useful for quick smoke testing and debugging.
""" """
model = resolve_model_name(backend) model = TEST_MODEL
llm = LLM( llm = LLM(
model=model, model=model,
...@@ -458,7 +457,6 @@ def test_logprobs_without_batch_invariance_should_fail( ...@@ -458,7 +457,6 @@ def test_logprobs_without_batch_invariance_should_fail(
monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False) monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
model_name = resolve_model_name(backend)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
print(f"\n{'=' * 80}") print(f"\n{'=' * 80}")
...@@ -466,7 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail( ...@@ -466,7 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
print(f"{'=' * 80}\n") print(f"{'=' * 80}\n")
llm = LLM( llm = LLM(
model=model_name, model=TEST_MODEL,
tensor_parallel_size=tp_size, tensor_parallel_size=tp_size,
max_num_seqs=32, max_num_seqs=32,
max_model_len=8192, max_model_len=8192,
...@@ -674,7 +672,6 @@ def test_decode_logprobs_match_prefill_logprobs( ...@@ -674,7 +672,6 @@ def test_decode_logprobs_match_prefill_logprobs(
""" """
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
model_name = resolve_model_name(backend)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
...@@ -689,7 +686,7 @@ def test_decode_logprobs_match_prefill_logprobs( ...@@ -689,7 +686,7 @@ def test_decode_logprobs_match_prefill_logprobs(
print(f"{'=' * 80}\n") print(f"{'=' * 80}\n")
llm = LLM( llm = LLM(
model=model_name, model=TEST_MODEL,
tensor_parallel_size=tp_size, tensor_parallel_size=tp_size,
max_num_seqs=32, max_num_seqs=32,
max_model_len=8192, max_model_len=8192,
......
...@@ -17,7 +17,7 @@ from typing import Any ...@@ -17,7 +17,7 @@ from typing import Any
import openai import openai
import pytest import pytest
from utils import BACKENDS, _random_prompt, resolve_model_name, skip_unsupported from utils import BACKENDS, TEST_MODEL, _random_prompt, skip_unsupported
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
...@@ -139,7 +139,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( ...@@ -139,7 +139,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
backend: str, backend: str,
) -> None: ) -> None:
random.seed(int(os.getenv("VLLM_TEST_SEED", "12345"))) random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
model_name = resolve_model_name(backend)
prompts_all = [_random_prompt(10, 50) for _ in range(32)] prompts_all = [_random_prompt(10, 50) for _ in range(32)]
sp_kwargs: dict[str, Any] = { sp_kwargs: dict[str, Any] = {
...@@ -159,11 +158,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( ...@@ -159,11 +158,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
if tp_size: if tp_size:
server_args += ["-tp", tp_size] server_args += ["-tp", tp_size]
with RemoteOpenAIServer(model_name, server_args) as server: with RemoteOpenAIServer(TEST_MODEL, server_args) as server:
client = server.get_client() client = server.get_client()
_compare_bs1_vs_bsn_single_process( _compare_bs1_vs_bsn_single_process(
prompts=prompts_all, prompts=prompts_all,
sp_kwargs=sp_kwargs, sp_kwargs=sp_kwargs,
client=client, client=client,
model_name=model_name, model_name=TEST_MODEL,
) )
...@@ -7,6 +7,10 @@ import pytest ...@@ -7,6 +7,10 @@ import pytest
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.transformers_utils.model_arch_config_convertor import (
ModelArchConfigConvertorBase,
)
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
skip_unsupported = pytest.mark.skipif( skip_unsupported = pytest.mark.skipif(
...@@ -16,10 +20,12 @@ skip_unsupported = pytest.mark.skipif( ...@@ -16,10 +20,12 @@ skip_unsupported = pytest.mark.skipif(
reason="Requires CUDA and >= Ampere (SM80)", reason="Requires CUDA and >= Ampere (SM80)",
) )
DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
TEST_MODEL = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
BACKENDS: list[str] = [ BACKENDS: list[str] = [
"FLASH_ATTN", "FLASH_ATTN",
"TRITON_ATTN", "TRITON_ATTN",
"TRITON_MLA",
] ]
# FlashInfer temporarily disabled due to invariant CTA sizes. # FlashInfer temporarily disabled due to invariant CTA sizes.
...@@ -27,20 +33,14 @@ BACKENDS: list[str] = [ ...@@ -27,20 +33,14 @@ BACKENDS: list[str] = [
# if has_flashinfer(): # if has_flashinfer():
# BACKENDS.append("FLASHINFER") # BACKENDS.append("FLASHINFER")
if flash_attn_supports_mla(): # only run MLA backends when the requested test model is itself an MLA model.
if os.getenv("VLLM_TEST_MODEL"):
config = get_config(TEST_MODEL, trust_remote_code=False)
if ModelArchConfigConvertorBase(config, config.get_text_config()).is_deepseek_mla():
BACKENDS = ["TRITON_MLA"]
if flash_attn_supports_mla():
BACKENDS.append("FLASH_ATTN_MLA") BACKENDS.append("FLASH_ATTN_MLA")
DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
def resolve_model_name(backend: str) -> str:
    """Pick the model to load for the given attention backend.

    The model comes from the ``VLLM_TEST_MODEL`` environment variable and
    falls back to ``DEFAULT_MODEL`` when the variable is unset. If the
    resulting name equals ``DEFAULT_MODEL`` and the backend name ends with
    ``"MLA"``, ``MLA_MODEL`` is returned instead.

    Args:
        backend: Name of the attention backend under test.

    Returns:
        The model identifier to use for this backend.
    """
    requested = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
    wants_mla = backend.endswith("MLA")
    # Any non-default model is used as-is, whatever the backend.
    if requested != DEFAULT_MODEL or not wants_mla:
        return requested
    return MLA_MODEL
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
# Generate more realistic prompts that will actually produce varied tokens # Generate more realistic prompts that will actually produce varied tokens
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment