Unverified Commit 7eb6cb6c authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Attention] Update tests to remove deprecated env vars (#30563)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 9ca8cb38
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import weakref import weakref
from contextlib import ExitStack from contextlib import ExitStack
...@@ -13,26 +11,6 @@ from vllm import LLM ...@@ -13,26 +11,6 @@ from vllm import LLM
from vllm.config import CompilationConfig, CompilationMode from vllm.config import CompilationConfig, CompilationMode
from vllm.platforms import current_platform from vllm.platforms import current_platform
@contextlib.contextmanager
def temporary_environ(env_vars):
"""
Temporarily set environment variables and restore them afterward.
We have to do this vs monkeypatch because monkeypatch doesn't work
with "module" scoped fixtures.
"""
original_env = {k: os.environ.get(k) for k in env_vars}
try:
os.environ.update(env_vars)
yield
finally:
for k, v in original_env.items():
if v is None:
os.environ.pop(k, None)
else:
os.environ[k] = v
# test attention backend and cudagraph_mode combo # test attention backend and cudagraph_mode combo
# (backend_name, cudagraph_mode, supported) # (backend_name, cudagraph_mode, supported)
if current_platform.is_rocm(): if current_platform.is_rocm():
...@@ -68,9 +46,9 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte ...@@ -68,9 +46,9 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
): ):
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA") pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
env_vars = backend_configs[backend_name].env_vars attention_config = backend_config.attention_config
with temporary_environ(env_vars), ExitStack() as stack: with ExitStack() as stack:
if not supported: if not supported:
stack.enter_context(pytest.raises(Exception)) stack.enter_context(pytest.raises(Exception))
...@@ -80,6 +58,7 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte ...@@ -80,6 +58,7 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
trust_remote_code=True, trust_remote_code=True,
gpu_memory_utilization=0.45, gpu_memory_utilization=0.45,
max_model_len=1024, max_model_len=1024,
attention_config=attention_config,
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=cudagraph_mode mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=cudagraph_mode
), ),
...@@ -122,9 +101,10 @@ combo_cases_2 = [ ...@@ -122,9 +101,10 @@ combo_cases_2 = [
def test_cudagraph_compilation_combo( def test_cudagraph_compilation_combo(
backend_name, cudagraph_mode, compilation_mode, supported backend_name, cudagraph_mode, compilation_mode, supported
): ):
env_vars = backend_configs[backend_name].env_vars backend_config = backend_configs[backend_name]
attention_config = backend_config.attention_config
with temporary_environ(env_vars), ExitStack() as stack: with ExitStack() as stack:
if not supported: if not supported:
stack.enter_context(pytest.raises(Exception)) stack.enter_context(pytest.raises(Exception))
...@@ -134,6 +114,7 @@ def test_cudagraph_compilation_combo( ...@@ -134,6 +114,7 @@ def test_cudagraph_compilation_combo(
trust_remote_code=True, trust_remote_code=True,
gpu_memory_utilization=0.45, gpu_memory_utilization=0.45,
max_model_len=1024, max_model_len=1024,
attention_config=attention_config,
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
mode=compilation_mode, cudagraph_mode=cudagraph_mode mode=compilation_mode, cudagraph_mode=cudagraph_mode
), ),
......
...@@ -28,7 +28,7 @@ IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90() ...@@ -28,7 +28,7 @@ IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
BACKENDS, BACKENDS,
) )
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
backend, monkeypatch: pytest.MonkeyPatch backend,
): ):
""" """
Ensures that the same request (the 'needle' prompt) yields identical output Ensures that the same request (the 'needle' prompt) yields identical output
...@@ -54,7 +54,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( ...@@ -54,7 +54,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend) attention_config = {"backend": backend}
# Allow overrides from environment (useful for CI tuning) # Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism # "facebook/opt-125m" is too small, doesn't reliably test determinism
model = resolve_model_name(backend) model = resolve_model_name(backend)
...@@ -92,6 +92,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( ...@@ -92,6 +92,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
max_num_seqs=max_batch_size, max_num_seqs=max_batch_size,
gpu_memory_utilization=gpu_mem_util, gpu_memory_utilization=gpu_mem_util,
max_model_len=max_model_len, max_model_len=max_model_len,
attention_config=attention_config,
) )
# Baseline generation for the needle prompt alone. # Baseline generation for the needle prompt alone.
...@@ -106,6 +107,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( ...@@ -106,6 +107,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
max_num_seqs=max_batch_size, max_num_seqs=max_batch_size,
gpu_memory_utilization=gpu_mem_util, gpu_memory_utilization=gpu_mem_util,
max_model_len=max_model_len, max_model_len=max_model_len,
attention_config=attention_config,
) )
mismatches = 0 mismatches = 0
...@@ -163,10 +165,8 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( ...@@ -163,10 +165,8 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
BACKENDS, BACKENDS,
) )
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
backend, monkeypatch: pytest.MonkeyPatch backend,
): ):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
model_name = resolve_model_name(backend) model_name = resolve_model_name(backend)
...@@ -193,6 +193,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( ...@@ -193,6 +193,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
dtype="bfloat16", # not everything is supported dtype="bfloat16", # not everything is supported
gpu_memory_utilization=0.9, gpu_memory_utilization=0.9,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config={"backend": backend},
) )
# Use more realistic prompts for better token generation # Use more realistic prompts for better token generation
...@@ -381,12 +382,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( ...@@ -381,12 +382,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
"backend", "backend",
BACKENDS, BACKENDS,
) )
def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch): def test_simple_generation(backend):
""" """
Simple test that runs the model with a basic prompt and prints the output. Simple test that runs the model with a basic prompt and prints the output.
Useful for quick smoke testing and debugging. Useful for quick smoke testing and debugging.
""" """
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
model = resolve_model_name(backend) model = resolve_model_name(backend)
llm = LLM( llm = LLM(
...@@ -398,6 +398,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch): ...@@ -398,6 +398,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
dtype="bfloat16", dtype="bfloat16",
enable_prefix_caching=False, enable_prefix_caching=False,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config={"backend": backend},
) )
prompt = "the capital of france is" prompt = "the capital of france is"
...@@ -444,8 +445,6 @@ def test_logprobs_without_batch_invariance_should_fail( ...@@ -444,8 +445,6 @@ def test_logprobs_without_batch_invariance_should_fail(
The test will PASS if we detect differences (proving batch invariance matters). The test will PASS if we detect differences (proving batch invariance matters).
The test will FAIL if everything matches (suggesting batch invariance isn't needed). The test will FAIL if everything matches (suggesting batch invariance isn't needed).
""" """
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
# CRITICAL: Disable batch invariance for this test # CRITICAL: Disable batch invariance for this test
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0") monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False) monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
...@@ -465,6 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail( ...@@ -465,6 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
max_model_len=8192, max_model_len=8192,
dtype="bfloat16", dtype="bfloat16",
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config={"backend": backend},
) )
# build ragged prompts to change shapes significantly across BS=1 vs BS=N # build ragged prompts to change shapes significantly across BS=1 vs BS=N
...@@ -649,7 +649,7 @@ def test_logprobs_without_batch_invariance_should_fail( ...@@ -649,7 +649,7 @@ def test_logprobs_without_batch_invariance_should_fail(
@skip_unsupported @skip_unsupported
@pytest.mark.parametrize("backend", ["FLASH_ATTN"]) @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
def test_decode_logprobs_match_prefill_logprobs( def test_decode_logprobs_match_prefill_logprobs(
backend, monkeypatch: pytest.MonkeyPatch backend,
): ):
""" """
Test that verifies decode logprobs match prefill logprobs. Test that verifies decode logprobs match prefill logprobs.
...@@ -664,8 +664,6 @@ def test_decode_logprobs_match_prefill_logprobs( ...@@ -664,8 +664,6 @@ def test_decode_logprobs_match_prefill_logprobs(
This ensures that the logprobs from decode are consistent with what This ensures that the logprobs from decode are consistent with what
we would get if we ran prefill on each prefix. we would get if we ran prefill on each prefix.
""" """
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
model_name = resolve_model_name(backend) model_name = resolve_model_name(backend)
...@@ -689,6 +687,7 @@ def test_decode_logprobs_match_prefill_logprobs( ...@@ -689,6 +687,7 @@ def test_decode_logprobs_match_prefill_logprobs(
max_model_len=8192, max_model_len=8192,
dtype="bfloat16", dtype="bfloat16",
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config={"backend": backend},
) )
# Use a few test prompts # Use a few test prompts
...@@ -920,6 +919,7 @@ def LLM_with_max_seqs( ...@@ -920,6 +919,7 @@ def LLM_with_max_seqs(
max_num_seqs: int, max_num_seqs: int,
gpu_memory_utilization: float, gpu_memory_utilization: float,
max_model_len: int, max_model_len: int,
attention_config: dict | None = None,
) -> LLM: ) -> LLM:
""" """
Helper to construct an LLM with a specific max_num_seqs (batch-size limit) Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
...@@ -934,6 +934,7 @@ def LLM_with_max_seqs( ...@@ -934,6 +934,7 @@ def LLM_with_max_seqs(
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
enable_prefix_caching=False, enable_prefix_caching=False,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config=attention_config,
# Enable for MOE models # Enable for MOE models
# enable_expert_parallel=True, # enable_expert_parallel=True,
) )
...@@ -136,11 +136,9 @@ def _compare_bs1_vs_bsn_single_process( ...@@ -136,11 +136,9 @@ def _compare_bs1_vs_bsn_single_process(
@skip_unsupported @skip_unsupported
@pytest.mark.parametrize("backend", BACKENDS) @pytest.mark.parametrize("backend", BACKENDS)
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
backend: str, monkeypatch: pytest.MonkeyPatch backend: str,
) -> None: ) -> None:
random.seed(int(os.getenv("VLLM_TEST_SEED", "12345"))) random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
# Override backend for this test (and the RemoteOpenAIServer child process).
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
model_name = resolve_model_name(backend) model_name = resolve_model_name(backend)
prompts_all = [_random_prompt(10, 50) for _ in range(32)] prompts_all = [_random_prompt(10, 50) for _ in range(32)]
...@@ -156,6 +154,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( ...@@ -156,6 +154,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
server_args: list[str] = [ server_args: list[str] = [
"--max-model-len=8192", "--max-model-len=8192",
"--max-num-seqs=32", "--max-num-seqs=32",
f"--attention-backend={backend}",
] ]
if tp_size: if tp_size:
server_args += ["-tp", tp_size] server_args += ["-tp", tp_size]
......
...@@ -142,16 +142,17 @@ def run_tests( ...@@ -142,16 +142,17 @@ def run_tests(
"""Test consistency of combos of async scheduling, preemption, """Test consistency of combos of async scheduling, preemption,
uni/multiproc executor with spec decoding.""" uni/multiproc executor with spec decoding."""
with monkeypatch.context() as m: # Determine attention config based on platform
# avoid precision errors
if current_platform.is_rocm(): if current_platform.is_rocm():
if is_testing_with_spec_decoding: if is_testing_with_spec_decoding:
# Use TRITON_ATTN for spec decoding test for consistency # Use TRITON_ATTN for spec decoding test for consistency
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") attention_config = {"backend": "TRITON_ATTN"}
else: else:
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA") attention_config = {"backend": "ROCM_AITER_FA"}
else: else:
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") attention_config = {"backend": "FLEX_ATTENTION"}
with monkeypatch.context() as m:
# lock matmul precision to full FP32 (IEEE) # lock matmul precision to full FP32 (IEEE)
m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee") m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee")
# m.setenv("VLLM_BATCH_INVARIANT", "1") # m.setenv("VLLM_BATCH_INVARIANT", "1")
...@@ -174,6 +175,7 @@ def run_tests( ...@@ -174,6 +175,7 @@ def run_tests(
spec_config, spec_config,
test_prefill_chunking=test_prefill_chunking, test_prefill_chunking=test_prefill_chunking,
is_testing_with_spec_decoding=is_testing_with_spec_decoding, is_testing_with_spec_decoding=is_testing_with_spec_decoding,
attention_config=attention_config,
) )
outputs.append(test_results) outputs.append(test_results)
...@@ -262,6 +264,7 @@ def run_test( ...@@ -262,6 +264,7 @@ def run_test(
spec_config: dict[str, Any] | None, spec_config: dict[str, Any] | None,
test_prefill_chunking: bool, test_prefill_chunking: bool,
is_testing_with_spec_decoding: bool = False, is_testing_with_spec_decoding: bool = False,
attention_config: dict[str, Any] | None = None,
): ):
spec_decoding = spec_config is not None spec_decoding = spec_config is not None
cache_arg: dict[str, Any] = ( cache_arg: dict[str, Any] = (
...@@ -301,6 +304,7 @@ def run_test( ...@@ -301,6 +304,7 @@ def run_test(
dtype=dtype, dtype=dtype,
speculative_config=spec_config, speculative_config=spec_config,
disable_log_stats=False, disable_log_stats=False,
attention_config=attention_config,
**cache_arg, **cache_arg,
) as vllm_model: ) as vllm_model:
results = [] results = []
......
...@@ -10,7 +10,7 @@ from ...utils import create_new_process_for_each_test ...@@ -10,7 +10,7 @@ from ...utils import create_new_process_for_each_test
@create_new_process_for_each_test() @create_new_process_for_each_test()
@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"]) @pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"])
def test_cascade_attention(example_system_message, monkeypatch, attn_backend): def test_cascade_attention(example_system_message, attn_backend):
prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:" prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
if attn_backend == "FLASHINFER": if attn_backend == "FLASHINFER":
...@@ -19,10 +19,9 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend): ...@@ -19,10 +19,9 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
"needs investigation. See issue #25679." "needs investigation. See issue #25679."
) )
with monkeypatch.context() as m: llm = LLM(
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) model="Qwen/Qwen2-1.5B-Instruct", attention_config={"backend": attn_backend}
)
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
sampling_params = SamplingParams(temperature=0.0, max_tokens=100) sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
# No cascade attention. # No cascade attention.
......
...@@ -438,19 +438,17 @@ def test_eagle_correctness( ...@@ -438,19 +438,17 @@ def test_eagle_correctness(
should be the same when using eagle speculative decoding. should be the same when using eagle speculative decoding.
model_setup: (method, model_name, eagle_model_name, tp_size) model_setup: (method, model_name, eagle_model_name, tp_size)
""" """
with monkeypatch.context() as m: # Determine attention config
# Scout requires default backend selection because vision encoder has
# head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
# to Flex Attn
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN": if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
# Scout requires default backend selection
# because vision encoder has head_dim 88 being incompatible
# with FLASH_ATTN and needs to fall back to Flex Attn
# pass if not ROCm
if current_platform.is_rocm(): if current_platform.is_rocm():
# TODO: Enable Flex Attn for spec_decode on ROCm # TODO: Enable Flex Attn for spec_decode on ROCm
pytest.skip("Flex Attn for spec_decode not supported on ROCm currently") pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
attention_config = None # Let it fall back to default
else: else:
m.setenv("VLLM_MLA_DISABLE", "1") attention_config = {"backend": attn_backend}
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip( pytest.skip(
...@@ -458,6 +456,9 @@ def test_eagle_correctness( ...@@ -458,6 +456,9 @@ def test_eagle_correctness(
"multi-token eagle spec decode on current platform" "multi-token eagle spec decode on current platform"
) )
with monkeypatch.context() as m:
m.setenv("VLLM_MLA_DISABLE", "1")
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm(): if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
if "deepseek" in model_setup[1].lower(): if "deepseek" in model_setup[1].lower():
pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm platform") pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm platform")
...@@ -471,7 +472,10 @@ def test_eagle_correctness( ...@@ -471,7 +472,10 @@ def test_eagle_correctness(
max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len
ref_llm = LLM( ref_llm = LLM(
model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size model=model_name,
max_model_len=max_model_len,
tensor_parallel_size=tp_size,
attention_config=attention_config,
) )
ref_outputs = ref_llm.chat(test_prompts, sampling_config) ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm del ref_llm
...@@ -492,6 +496,7 @@ def test_eagle_correctness( ...@@ -492,6 +496,7 @@ def test_eagle_correctness(
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill, enable_chunked_prefill=enable_chunked_prefill,
model_impl=model_impl, model_impl=model_impl,
attention_config=attention_config,
) )
spec_outputs = spec_llm.chat(test_prompts, sampling_config) spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0 matches = 0
......
...@@ -3,21 +3,29 @@ set -xe ...@@ -3,21 +3,29 @@ set -xe
# Parse command line arguments # Parse command line arguments
KV_BUFFER_DEVICE="cuda" # Default to cuda KV_BUFFER_DEVICE="cuda" # Default to cuda
ATTENTION_BACKEND="" # Default to empty (use vllm default)
while [[ $# -gt 0 ]]; do while [[ $# -gt 0 ]]; do
case $1 in case $1 in
--kv_buffer_device) --kv_buffer_device)
KV_BUFFER_DEVICE="$2" KV_BUFFER_DEVICE="$2"
shift 2 shift 2
;; ;;
--attention-backend)
ATTENTION_BACKEND="$2"
shift 2
;;
*) *)
echo "Unknown option $1" echo "Unknown option $1"
echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]" echo "Usage: $0 [--kv_buffer_device <cuda|cpu>] [--attention-backend <backend>]"
exit 1 exit 1
;; ;;
esac esac
done done
echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE" echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
if [[ -n "$ATTENTION_BACKEND" ]]; then
echo "Using attention backend: $ATTENTION_BACKEND"
fi
DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
...@@ -148,6 +156,11 @@ run_tests_for_model() { ...@@ -148,6 +156,11 @@ run_tests_for_model() {
--tensor-parallel-size $PREFILLER_TP_SIZE \ --tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '$KV_CONFIG'" --kv-transfer-config '$KV_CONFIG'"
# Add attention backend config if specified
if [[ -n "$ATTENTION_BACKEND" ]]; then
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
fi
if [ -n "$model_args" ]; then if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args" FULL_CMD="$BASE_CMD $model_args"
else else
...@@ -189,6 +202,11 @@ run_tests_for_model() { ...@@ -189,6 +202,11 @@ run_tests_for_model() {
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--kv-transfer-config '$KV_CONFIG'" --kv-transfer-config '$KV_CONFIG'"
# Add attention backend config if specified
if [[ -n "$ATTENTION_BACKEND" ]]; then
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
fi
# DP-EP attention mode # DP-EP attention mode
if [[ -z "$DP_EP" ]]; then if [[ -z "$DP_EP" ]]; then
BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE" BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE"
......
...@@ -15,14 +15,14 @@ configs=( ...@@ -15,14 +15,14 @@ configs=(
run_tests() { run_tests() {
local label=$1 local label=$1
local extra_env=$2 local extra_args=$2
echo "=== Running tests (${label}) ===" echo "=== Running tests (${label}) ==="
for cfg in "${configs[@]}"; do for cfg in "${configs[@]}"; do
echo "-> Running with ${cfg} ${extra_env:+and ${extra_env}}" echo "-> Running with ${cfg} ${extra_args:+and ${extra_args}}"
# Use 'env' to safely set variables without eval # Use 'env' to safely set variables without eval
if ! env ${extra_env} ${cfg} bash "${SCRIPT}"; then if ! env ${cfg} bash "${SCRIPT}" ${extra_args}; then
echo "❌ Test failed for config: ${cfg} ${extra_env:+(${extra_env})}" echo "❌ Test failed for config: ${cfg} ${extra_args:+(${extra_args})}"
exit 1 exit 1
fi fi
done done
...@@ -34,8 +34,8 @@ run_tests "default backend" "" ...@@ -34,8 +34,8 @@ run_tests "default backend" ""
# Check if FLASHINFER is set (non-empty) # Check if FLASHINFER is set (non-empty)
if [[ -n "${FLASHINFER:-}" ]]; then if [[ -n "${FLASHINFER:-}" ]]; then
echo "FLASHINFER is set, rerunning with VLLM_ATTENTION_BACKEND=FLASHINFER" echo "FLASHINFER is set, rerunning with --attention-backend FLASHINFER"
run_tests "FLASHINFER backend" "VLLM_ATTENTION_BACKEND=FLASHINFER" run_tests "FLASHINFER backend" "--attention-backend FLASHINFER"
else else
echo "FLASHINFER not set, skipping FLASHINFER runs." echo "FLASHINFER not set, skipping FLASHINFER runs."
fi fi
...@@ -1132,7 +1132,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int): ...@@ -1132,7 +1132,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
"TRITON_ATTN", "TRITON_ATTN",
], ],
) )
def test_register_kv_caches(dist_init, attn_backend, monkeypatch): def test_register_kv_caches(dist_init, attn_backend):
""" """
Test that register_kv_caches() properly calls nixl_wrapper methods with Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data. correct data.
...@@ -1144,9 +1144,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch): ...@@ -1144,9 +1144,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
block layout info block layout info
""" """
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) vllm_config = create_vllm_config(attention_backend=attn_backend)
vllm_config = create_vllm_config()
# Import the appropriate backend based on the parameter # Import the appropriate backend based on the parameter
if attn_backend == "FLASH_ATTN": if attn_backend == "FLASH_ATTN":
......
...@@ -11,6 +11,7 @@ import torch ...@@ -11,6 +11,7 @@ import torch
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import ( from vllm.config import (
AttentionConfig,
CacheConfig, CacheConfig,
DeviceConfig, DeviceConfig,
KVTransferConfig, KVTransferConfig,
...@@ -94,6 +95,7 @@ def create_vllm_config( ...@@ -94,6 +95,7 @@ def create_vllm_config(
dtype: str = "float16", dtype: str = "float16",
cache_dtype: str = "auto", cache_dtype: str = "auto",
hf_overrides: dict[str, Any] | None = None, hf_overrides: dict[str, Any] | None = None,
attention_backend: str | None = None,
) -> VllmConfig: ) -> VllmConfig:
"""Initialize VllmConfig For Testing.""" """Initialize VllmConfig For Testing."""
model_config = ModelConfig( model_config = ModelConfig(
...@@ -124,12 +126,14 @@ def create_vllm_config( ...@@ -124,12 +126,14 @@ def create_vllm_config(
enable_permute_local_kv=enable_permute_local_kv, enable_permute_local_kv=enable_permute_local_kv,
kv_connector_extra_config=kv_connector_extra_config or {}, kv_connector_extra_config=kv_connector_extra_config or {},
) )
attention_config = AttentionConfig(backend=attention_backend)
return VllmConfig( return VllmConfig(
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
model_config=model_config, model_config=model_config,
cache_config=cache_config, cache_config=cache_config,
kv_transfer_config=kv_transfer_config, kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"), device_config=DeviceConfig("cpu"),
attention_config=attention_config,
) )
......
...@@ -13,7 +13,6 @@ from vllm import LLM, SamplingParams, TokensPrompt ...@@ -13,7 +13,6 @@ from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVEventsConfig, KVTransferConfig from vllm.config import KVEventsConfig, KVTransferConfig
from vllm.distributed.kv_events import BlockStored, KVEventBatch from vllm.distributed.kv_events import BlockStored, KVEventBatch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.system_utils import set_env_var
CPU_BLOCK_SIZES = [48] CPU_BLOCK_SIZES = [48]
ATTN_BACKENDS = ["FLASH_ATTN"] ATTN_BACKENDS = ["FLASH_ATTN"]
...@@ -180,12 +179,12 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None: ...@@ -180,12 +179,12 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
topic="test", topic="test",
) )
with set_env_var("VLLM_ATTENTION_BACKEND", attn_backend):
llm = LLM( llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct", model="meta-llama/Llama-3.2-1B-Instruct",
gpu_memory_utilization=0.5, gpu_memory_utilization=0.5,
kv_events_config=kv_events_config, kv_events_config=kv_events_config,
kv_transfer_config=kv_transfer_config, kv_transfer_config=kv_transfer_config,
attention_config={"backend": attn_backend},
) )
events_endpoint = events_endpoint.replace("*", "127.0.0.1") events_endpoint = events_endpoint.replace("*", "127.0.0.1")
......
...@@ -15,6 +15,7 @@ from tests.v1.attention.utils import ( ...@@ -15,6 +15,7 @@ from tests.v1.attention.utils import (
) )
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ( from vllm.config import (
AttentionConfig,
CacheConfig, CacheConfig,
DeviceConfig, DeviceConfig,
ModelConfig, ModelConfig,
...@@ -38,6 +39,7 @@ eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" ...@@ -38,6 +39,7 @@ eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
def _create_proposer( def _create_proposer(
method: str, method: str,
num_speculative_tokens: int, num_speculative_tokens: int,
attention_backend: str | None = None,
speculative_token_tree: list[tuple[int, ...]] | None = None, speculative_token_tree: list[tuple[int, ...]] | None = None,
) -> EagleProposer: ) -> EagleProposer:
model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100) model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
...@@ -70,6 +72,7 @@ def _create_proposer( ...@@ -70,6 +72,7 @@ def _create_proposer(
max_model_len=model_config.max_model_len, max_model_len=model_config.max_model_len,
is_encoder_decoder=model_config.is_encoder_decoder, is_encoder_decoder=model_config.is_encoder_decoder,
), ),
attention_config=AttentionConfig(backend=attention_backend),
) )
return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type) return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)
...@@ -331,8 +334,6 @@ def test_load_model( ...@@ -331,8 +334,6 @@ def test_load_model(
use_distinct_lm_head, use_distinct_lm_head,
monkeypatch, monkeypatch,
): ):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip( pytest.skip(
"TRITON_ATTN does not support " "TRITON_ATTN does not support "
...@@ -394,7 +395,9 @@ def test_load_model( ...@@ -394,7 +395,9 @@ def test_load_model(
assert not isinstance(target_model, SupportsMultiModal) assert not isinstance(target_model, SupportsMultiModal)
# Create proposer using the helper function # Create proposer using the helper function
proposer = _create_proposer(method, num_speculative_tokens=8) proposer = _create_proposer(
method, num_speculative_tokens=8, attention_backend=attn_backend
)
# Call the method under test # Call the method under test
proposer.load_model(target_model) proposer.load_model(target_model)
...@@ -420,8 +423,6 @@ def test_load_model( ...@@ -420,8 +423,6 @@ def test_load_model(
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8]) @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip( pytest.skip(
"TRITON_ATTN does not support " "TRITON_ATTN does not support "
...@@ -449,7 +450,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): ...@@ -449,7 +450,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
seq_lens = [seq_len_1, seq_len_2] seq_lens = [seq_len_1, seq_len_2]
# Create proposer first so we can use its actual hidden_size # Create proposer first so we can use its actual hidden_size
proposer = _create_proposer("eagle", num_speculative_tokens) proposer = _create_proposer(
"eagle", num_speculative_tokens, attention_backend=attn_backend
)
# Get the hidden_size from the proposer to ensure consistency # Get the hidden_size from the proposer to ensure consistency
hidden_size = proposer.hidden_size hidden_size = proposer.hidden_size
...@@ -622,7 +625,9 @@ def test_propose_tree(spec_token_tree): ...@@ -622,7 +625,9 @@ def test_propose_tree(spec_token_tree):
# Create proposer first so we can use its actual hidden_size. # Create proposer first so we can use its actual hidden_size.
proposer = _create_proposer( proposer = _create_proposer(
"eagle", num_speculative_tokens, speculative_token_tree=spec_token_tree "eagle",
num_speculative_tokens,
speculative_token_tree=spec_token_tree,
) )
# Get the hidden_size from the proposer to ensure consistency. # Get the hidden_size from the proposer to ensure consistency.
hidden_size = proposer.hidden_size hidden_size = proposer.hidden_size
......
...@@ -38,9 +38,6 @@ def test_ngram_max_len(num_speculative_tokens: int): ...@@ -38,9 +38,6 @@ def test_ngram_max_len(num_speculative_tokens: int):
def test_eagle_max_len( def test_eagle_max_len(
monkeypatch: pytest.MonkeyPatch, num_speculative_tokens: int, attn_backend: str monkeypatch: pytest.MonkeyPatch, num_speculative_tokens: int, attn_backend: str
): ):
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip( pytest.skip(
"TRITON_ATTN does not support " "TRITON_ATTN does not support "
...@@ -48,7 +45,7 @@ def test_eagle_max_len( ...@@ -48,7 +45,7 @@ def test_eagle_max_len(
) )
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm(): if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
m.setenv("VLLM_ROCM_USE_AITER", "1") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
llm = LLM( llm = LLM(
model="meta-llama/Meta-Llama-3-8B-Instruct", model="meta-llama/Meta-Llama-3-8B-Instruct",
...@@ -60,20 +57,18 @@ def test_eagle_max_len( ...@@ -60,20 +57,18 @@ def test_eagle_max_len(
"max_model_len": 80, "max_model_len": 80,
}, },
max_model_len=200, max_model_len=200,
attention_config={"backend": attn_backend},
) )
sampling_params = SamplingParams(max_tokens=200, ignore_eos=True) sampling_params = SamplingParams(max_tokens=200, ignore_eos=True)
outputs = llm.generate(_PROMPTS, sampling_params) outputs = llm.generate(_PROMPTS, sampling_params)
for o in outputs: for o in outputs:
assert o.outputs[0].finish_reason == "length", ( assert o.outputs[0].finish_reason == "length", (
"This test is only meaningful if the output " "This test is only meaningful if the output is truncated due to max length"
"is truncated due to max length"
) )
sampling_params = SamplingParams( sampling_params = SamplingParams(
max_tokens=200, max_tokens=200,
structured_outputs=StructuredOutputsParams( structured_outputs=StructuredOutputsParams(regex="^" + "a b c d e " * 15 + "$"),
regex="^" + "a b c d e " * 15 + "$"
),
) )
output = llm.generate(_PROMPTS, sampling_params) output = llm.generate(_PROMPTS, sampling_params)
for o in output: for o in output:
......
...@@ -165,7 +165,7 @@ class RocmAttentionBackend(AttentionBackend): ...@@ -165,7 +165,7 @@ class RocmAttentionBackend(AttentionBackend):
raise ValueError( raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. " f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {cls.get_supported_head_sizes()}. " f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
"Set --attention-config.backend=FLEX_ATTENTION to use " "Set --attention-backend=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes." "FlexAttention backend which supports all head sizes."
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment