Unverified Commit 7eb6cb6c authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Attention] Update tests to remove deprecated env vars (#30563)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 9ca8cb38
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import weakref
from contextlib import ExitStack
......@@ -13,26 +11,6 @@ from vllm import LLM
from vllm.config import CompilationConfig, CompilationMode
from vllm.platforms import current_platform
@contextlib.contextmanager
def temporary_environ(env_vars):
"""
Temporarily set environment variables and restore them afterward.
We have to do this vs monkeypatch because monkeypatch doesn't work
with "module" scoped fixtures.
"""
original_env = {k: os.environ.get(k) for k in env_vars}
try:
os.environ.update(env_vars)
yield
finally:
for k, v in original_env.items():
if v is None:
os.environ.pop(k, None)
else:
os.environ[k] = v
# test attention backend and cudagraph_mode combo
# (backend_name, cudagraph_mode, supported)
if current_platform.is_rocm():
......@@ -68,9 +46,9 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
):
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
env_vars = backend_configs[backend_name].env_vars
attention_config = backend_config.attention_config
with temporary_environ(env_vars), ExitStack() as stack:
with ExitStack() as stack:
if not supported:
stack.enter_context(pytest.raises(Exception))
......@@ -80,6 +58,7 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
trust_remote_code=True,
gpu_memory_utilization=0.45,
max_model_len=1024,
attention_config=attention_config,
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=cudagraph_mode
),
......@@ -122,9 +101,10 @@ combo_cases_2 = [
def test_cudagraph_compilation_combo(
backend_name, cudagraph_mode, compilation_mode, supported
):
env_vars = backend_configs[backend_name].env_vars
backend_config = backend_configs[backend_name]
attention_config = backend_config.attention_config
with temporary_environ(env_vars), ExitStack() as stack:
with ExitStack() as stack:
if not supported:
stack.enter_context(pytest.raises(Exception))
......@@ -134,6 +114,7 @@ def test_cudagraph_compilation_combo(
trust_remote_code=True,
gpu_memory_utilization=0.45,
max_model_len=1024,
attention_config=attention_config,
compilation_config=CompilationConfig(
mode=compilation_mode, cudagraph_mode=cudagraph_mode
),
......
......@@ -28,7 +28,7 @@ IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
BACKENDS,
)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
backend, monkeypatch: pytest.MonkeyPatch
backend,
):
"""
Ensures that the same request (the 'needle' prompt) yields identical output
......@@ -54,7 +54,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
attention_config = {"backend": backend}
# Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism
model = resolve_model_name(backend)
......@@ -92,6 +92,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
max_num_seqs=max_batch_size,
gpu_memory_utilization=gpu_mem_util,
max_model_len=max_model_len,
attention_config=attention_config,
)
# Baseline generation for the needle prompt alone.
......@@ -106,6 +107,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
max_num_seqs=max_batch_size,
gpu_memory_utilization=gpu_mem_util,
max_model_len=max_model_len,
attention_config=attention_config,
)
mismatches = 0
......@@ -163,10 +165,8 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
BACKENDS,
)
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
backend, monkeypatch: pytest.MonkeyPatch
backend,
):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
model_name = resolve_model_name(backend)
......@@ -193,6 +193,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
dtype="bfloat16", # not everything is supported
gpu_memory_utilization=0.9,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config={"backend": backend},
)
# Use more realistic prompts for better token generation
......@@ -381,12 +382,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
"backend",
BACKENDS,
)
def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
def test_simple_generation(backend):
"""
Simple test that runs the model with a basic prompt and prints the output.
Useful for quick smoke testing and debugging.
"""
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
model = resolve_model_name(backend)
llm = LLM(
......@@ -398,6 +398,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
dtype="bfloat16",
enable_prefix_caching=False,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config={"backend": backend},
)
prompt = "the capital of france is"
......@@ -444,8 +445,6 @@ def test_logprobs_without_batch_invariance_should_fail(
The test will PASS if we detect differences (proving batch invariance matters).
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
"""
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
# CRITICAL: Disable batch invariance for this test
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
......@@ -465,6 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
max_model_len=8192,
dtype="bfloat16",
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config={"backend": backend},
)
# build ragged prompts to change shapes significantly across BS=1 vs BS=N
......@@ -649,7 +649,7 @@ def test_logprobs_without_batch_invariance_should_fail(
@skip_unsupported
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
def test_decode_logprobs_match_prefill_logprobs(
backend, monkeypatch: pytest.MonkeyPatch
backend,
):
"""
Test that verifies decode logprobs match prefill logprobs.
......@@ -664,8 +664,6 @@ def test_decode_logprobs_match_prefill_logprobs(
This ensures that the logprobs from decode are consistent with what
we would get if we ran prefill on each prefix.
"""
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
model_name = resolve_model_name(backend)
......@@ -689,6 +687,7 @@ def test_decode_logprobs_match_prefill_logprobs(
max_model_len=8192,
dtype="bfloat16",
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config={"backend": backend},
)
# Use a few test prompts
......@@ -920,6 +919,7 @@ def LLM_with_max_seqs(
max_num_seqs: int,
gpu_memory_utilization: float,
max_model_len: int,
attention_config: dict | None = None,
) -> LLM:
"""
Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
......@@ -934,6 +934,7 @@ def LLM_with_max_seqs(
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
enable_prefix_caching=False,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config=attention_config,
# Enable for MOE models
# enable_expert_parallel=True,
)
......@@ -136,11 +136,9 @@ def _compare_bs1_vs_bsn_single_process(
@skip_unsupported
@pytest.mark.parametrize("backend", BACKENDS)
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
backend: str, monkeypatch: pytest.MonkeyPatch
backend: str,
) -> None:
random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
# Override backend for this test (and the RemoteOpenAIServer child process).
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
model_name = resolve_model_name(backend)
prompts_all = [_random_prompt(10, 50) for _ in range(32)]
......@@ -156,6 +154,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
server_args: list[str] = [
"--max-model-len=8192",
"--max-num-seqs=32",
f"--attention-backend={backend}",
]
if tp_size:
server_args += ["-tp", tp_size]
......
......@@ -142,16 +142,17 @@ def run_tests(
"""Test consistency of combos of async scheduling, preemption,
uni/multiproc executor with spec decoding."""
with monkeypatch.context() as m:
# avoid precision errors
if current_platform.is_rocm():
if is_testing_with_spec_decoding:
# Use TRITON_ATTN for spec decoding test for consistency
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
else:
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
# Determine attention config based on platform
if current_platform.is_rocm():
if is_testing_with_spec_decoding:
# Use TRITON_ATTN for spec decoding test for consistency
attention_config = {"backend": "TRITON_ATTN"}
else:
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
attention_config = {"backend": "ROCM_AITER_FA"}
else:
attention_config = {"backend": "FLEX_ATTENTION"}
with monkeypatch.context() as m:
# lock matmul precision to full FP32 (IEEE)
m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee")
# m.setenv("VLLM_BATCH_INVARIANT", "1")
......@@ -174,6 +175,7 @@ def run_tests(
spec_config,
test_prefill_chunking=test_prefill_chunking,
is_testing_with_spec_decoding=is_testing_with_spec_decoding,
attention_config=attention_config,
)
outputs.append(test_results)
......@@ -262,6 +264,7 @@ def run_test(
spec_config: dict[str, Any] | None,
test_prefill_chunking: bool,
is_testing_with_spec_decoding: bool = False,
attention_config: dict[str, Any] | None = None,
):
spec_decoding = spec_config is not None
cache_arg: dict[str, Any] = (
......@@ -301,6 +304,7 @@ def run_test(
dtype=dtype,
speculative_config=spec_config,
disable_log_stats=False,
attention_config=attention_config,
**cache_arg,
) as vllm_model:
results = []
......
......@@ -10,7 +10,7 @@ from ...utils import create_new_process_for_each_test
@create_new_process_for_each_test()
@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"])
def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
def test_cascade_attention(example_system_message, attn_backend):
prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
if attn_backend == "FLASHINFER":
......@@ -19,19 +19,18 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
"needs investigation. See issue #25679."
)
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
# No cascade attention.
single_prompt = [example_system_message + prompt]
responses = llm.generate(single_prompt, sampling_params)
ref_output = responses[0].outputs[0].text
# (Probably) Use cascade attention.
prompts = [example_system_message + prompt] * 64
responses = llm.generate(prompts, sampling_params)
for response in responses:
assert response.outputs[0].text == ref_output
llm = LLM(
model="Qwen/Qwen2-1.5B-Instruct", attention_config={"backend": attn_backend}
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
# No cascade attention.
single_prompt = [example_system_message + prompt]
responses = llm.generate(single_prompt, sampling_params)
ref_output = responses[0].outputs[0].text
# (Probably) Use cascade attention.
prompts = [example_system_message + prompt] * 64
responses = llm.generate(prompts, sampling_params)
for response in responses:
assert response.outputs[0].text == ref_output
......@@ -438,25 +438,26 @@ def test_eagle_correctness(
should be the same when using eagle speculative decoding.
model_setup: (method, model_name, eagle_model_name, tp_size)
"""
with monkeypatch.context() as m:
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
# Scout requires default backend selection
# because vision encoder has head_dim 88 being incompatible
# with FLASH_ATTN and needs to fall back to Flex Attn
# pass if not ROCm
if current_platform.is_rocm():
# TODO: Enable Flex Attn for spec_decode on ROCm
pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
else:
m.setenv("VLLM_MLA_DISABLE", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
# Determine attention config
# Scout requires default backend selection because vision encoder has
# head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
# to Flex Attn
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
if current_platform.is_rocm():
# TODO: Enable Flex Attn for spec_decode on ROCm
pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
attention_config = None # Let it fall back to default
else:
attention_config = {"backend": attn_backend}
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
with monkeypatch.context() as m:
m.setenv("VLLM_MLA_DISABLE", "1")
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
if "deepseek" in model_setup[1].lower():
......@@ -471,7 +472,10 @@ def test_eagle_correctness(
max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len
ref_llm = LLM(
model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
model=model_name,
max_model_len=max_model_len,
tensor_parallel_size=tp_size,
attention_config=attention_config,
)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
......@@ -492,6 +496,7 @@ def test_eagle_correctness(
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
model_impl=model_impl,
attention_config=attention_config,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
......
......@@ -3,21 +3,29 @@ set -xe
# Parse command line arguments
KV_BUFFER_DEVICE="cuda" # Default to cuda
ATTENTION_BACKEND="" # Default to empty (use vllm default)
while [[ $# -gt 0 ]]; do
case $1 in
--kv_buffer_device)
KV_BUFFER_DEVICE="$2"
shift 2
;;
--attention-backend)
ATTENTION_BACKEND="$2"
shift 2
;;
*)
echo "Unknown option $1"
echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]"
echo "Usage: $0 [--kv_buffer_device <cuda|cpu>] [--attention-backend <backend>]"
exit 1
;;
esac
done
echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
if [[ -n "$ATTENTION_BACKEND" ]]; then
echo "Using attention backend: $ATTENTION_BACKEND"
fi
DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
......@@ -148,6 +156,11 @@ run_tests_for_model() {
--tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '$KV_CONFIG'"
# Add attention backend config if specified
if [[ -n "$ATTENTION_BACKEND" ]]; then
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
fi
if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args"
else
......@@ -188,7 +201,12 @@ run_tests_for_model() {
--block-size ${DECODE_BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--kv-transfer-config '$KV_CONFIG'"
# Add attention backend config if specified
if [[ -n "$ATTENTION_BACKEND" ]]; then
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
fi
# DP-EP attention mode
if [[ -z "$DP_EP" ]]; then
BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE"
......
......@@ -15,14 +15,14 @@ configs=(
run_tests() {
local label=$1
local extra_env=$2
local extra_args=$2
echo "=== Running tests (${label}) ==="
for cfg in "${configs[@]}"; do
echo "-> Running with ${cfg} ${extra_env:+and ${extra_env}}"
echo "-> Running with ${cfg} ${extra_args:+and ${extra_args}}"
# Use 'env' to safely set variables without eval
if ! env ${extra_env} ${cfg} bash "${SCRIPT}"; then
echo "❌ Test failed for config: ${cfg} ${extra_env:+(${extra_env})}"
if ! env ${cfg} bash "${SCRIPT}" ${extra_args}; then
echo "❌ Test failed for config: ${cfg} ${extra_args:+(${extra_args})}"
exit 1
fi
done
......@@ -34,8 +34,8 @@ run_tests "default backend" ""
# Check if FLASHINFER is set (non-empty)
if [[ -n "${FLASHINFER:-}" ]]; then
echo "FLASHINFER is set, rerunning with VLLM_ATTENTION_BACKEND=FLASHINFER"
run_tests "FLASHINFER backend" "VLLM_ATTENTION_BACKEND=FLASHINFER"
echo "FLASHINFER is set, rerunning with --attention-backend FLASHINFER"
run_tests "FLASHINFER backend" "--attention-backend FLASHINFER"
else
echo "FLASHINFER not set, skipping FLASHINFER runs."
fi
......@@ -1132,7 +1132,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
"TRITON_ATTN",
],
)
def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
def test_register_kv_caches(dist_init, attn_backend):
"""
Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data.
......@@ -1144,9 +1144,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
block layout info
"""
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
vllm_config = create_vllm_config()
vllm_config = create_vllm_config(attention_backend=attn_backend)
# Import the appropriate backend based on the parameter
if attn_backend == "FLASH_ATTN":
......
......@@ -11,6 +11,7 @@ import torch
from vllm import SamplingParams
from vllm.config import (
AttentionConfig,
CacheConfig,
DeviceConfig,
KVTransferConfig,
......@@ -94,6 +95,7 @@ def create_vllm_config(
dtype: str = "float16",
cache_dtype: str = "auto",
hf_overrides: dict[str, Any] | None = None,
attention_backend: str | None = None,
) -> VllmConfig:
"""Initialize VllmConfig For Testing."""
model_config = ModelConfig(
......@@ -124,12 +126,14 @@ def create_vllm_config(
enable_permute_local_kv=enable_permute_local_kv,
kv_connector_extra_config=kv_connector_extra_config or {},
)
attention_config = AttentionConfig(backend=attention_backend)
return VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"),
attention_config=attention_config,
)
......
......@@ -13,7 +13,6 @@ from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVEventsConfig, KVTransferConfig
from vllm.distributed.kv_events import BlockStored, KVEventBatch
from vllm.platforms import current_platform
from vllm.utils.system_utils import set_env_var
CPU_BLOCK_SIZES = [48]
ATTN_BACKENDS = ["FLASH_ATTN"]
......@@ -180,13 +179,13 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
topic="test",
)
with set_env_var("VLLM_ATTENTION_BACKEND", attn_backend):
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
gpu_memory_utilization=0.5,
kv_events_config=kv_events_config,
kv_transfer_config=kv_transfer_config,
)
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
gpu_memory_utilization=0.5,
kv_events_config=kv_events_config,
kv_transfer_config=kv_transfer_config,
attention_config={"backend": attn_backend},
)
events_endpoint = events_endpoint.replace("*", "127.0.0.1")
subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic)
......
......@@ -15,6 +15,7 @@ from tests.v1.attention.utils import (
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (
AttentionConfig,
CacheConfig,
DeviceConfig,
ModelConfig,
......@@ -38,6 +39,7 @@ eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
def _create_proposer(
method: str,
num_speculative_tokens: int,
attention_backend: str | None = None,
speculative_token_tree: list[tuple[int, ...]] | None = None,
) -> EagleProposer:
model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
......@@ -70,6 +72,7 @@ def _create_proposer(
max_model_len=model_config.max_model_len,
is_encoder_decoder=model_config.is_encoder_decoder,
),
attention_config=AttentionConfig(backend=attention_backend),
)
return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)
......@@ -331,8 +334,6 @@ def test_load_model(
use_distinct_lm_head,
monkeypatch,
):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
......@@ -394,7 +395,9 @@ def test_load_model(
assert not isinstance(target_model, SupportsMultiModal)
# Create proposer using the helper function
proposer = _create_proposer(method, num_speculative_tokens=8)
proposer = _create_proposer(
method, num_speculative_tokens=8, attention_backend=attn_backend
)
# Call the method under test
proposer.load_model(target_model)
......@@ -420,8 +423,6 @@ def test_load_model(
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
......@@ -449,7 +450,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
seq_lens = [seq_len_1, seq_len_2]
# Create proposer first so we can use its actual hidden_size
proposer = _create_proposer("eagle", num_speculative_tokens)
proposer = _create_proposer(
"eagle", num_speculative_tokens, attention_backend=attn_backend
)
# Get the hidden_size from the proposer to ensure consistency
hidden_size = proposer.hidden_size
......@@ -622,7 +625,9 @@ def test_propose_tree(spec_token_tree):
# Create proposer first so we can use its actual hidden_size.
proposer = _create_proposer(
"eagle", num_speculative_tokens, speculative_token_tree=spec_token_tree
"eagle",
num_speculative_tokens,
speculative_token_tree=spec_token_tree,
)
# Get the hidden_size from the proposer to ensure consistency.
hidden_size = proposer.hidden_size
......
......@@ -38,53 +38,48 @@ def test_ngram_max_len(num_speculative_tokens: int):
def test_eagle_max_len(
monkeypatch: pytest.MonkeyPatch, num_speculative_tokens: int, attn_backend: str
):
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
m.setenv("VLLM_ROCM_USE_AITER", "1")
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
llm = LLM(
model="meta-llama/Meta-Llama-3-8B-Instruct",
enforce_eager=True, # For faster initialization.
speculative_config={
"method": "eagle",
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": num_speculative_tokens,
"max_model_len": 80,
},
max_model_len=200,
llm = LLM(
model="meta-llama/Meta-Llama-3-8B-Instruct",
enforce_eager=True, # For faster initialization.
speculative_config={
"method": "eagle",
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": num_speculative_tokens,
"max_model_len": 80,
},
max_model_len=200,
attention_config={"backend": attn_backend},
)
sampling_params = SamplingParams(max_tokens=200, ignore_eos=True)
outputs = llm.generate(_PROMPTS, sampling_params)
for o in outputs:
assert o.outputs[0].finish_reason == "length", (
"This test is only meaningful if the output is truncated due to max length"
)
sampling_params = SamplingParams(max_tokens=200, ignore_eos=True)
outputs = llm.generate(_PROMPTS, sampling_params)
for o in outputs:
assert o.outputs[0].finish_reason == "length", (
"This test is only meaningful if the output "
"is truncated due to max length"
)
sampling_params = SamplingParams(
max_tokens=200,
structured_outputs=StructuredOutputsParams(
regex="^" + "a b c d e " * 15 + "$"
),
sampling_params = SamplingParams(
max_tokens=200,
structured_outputs=StructuredOutputsParams(regex="^" + "a b c d e " * 15 + "$"),
)
output = llm.generate(_PROMPTS, sampling_params)
for o in output:
assert o.prompt_token_ids is not None
assert (
len(o.prompt_token_ids)
< 80
< len(o.prompt_token_ids) + len(o.outputs[0].token_ids)
<= 200
), (
"This test is only meaningful if the output "
"is longer than the eagle max length"
)
output = llm.generate(_PROMPTS, sampling_params)
for o in output:
assert o.prompt_token_ids is not None
assert (
len(o.prompt_token_ids)
< 80
< len(o.prompt_token_ids) + len(o.outputs[0].token_ids)
<= 200
), (
"This test is only meaningful if the output "
"is longer than the eagle max length"
)
assert o.outputs[0].text == "a b c d e " * 15
assert o.outputs[0].text == "a b c d e " * 15
......@@ -165,7 +165,7 @@ class RocmAttentionBackend(AttentionBackend):
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
"Set --attention-config.backend=FLEX_ATTENTION to use "
"Set --attention-backend=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment