Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7eb6cb6c
Unverified
Commit
7eb6cb6c
authored
Dec 17, 2025
by
Matthew Bonanni
Committed by
GitHub
Dec 17, 2025
Browse files
[Attention] Update tests to remove deprecated env vars (#30563)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
9ca8cb38
Changes
34
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
167 additions
and
159 deletions
+167
-159
tests/v1/cudagraph/test_cudagraph_mode.py
tests/v1/cudagraph/test_cudagraph_mode.py
+7
-26
tests/v1/determinism/test_batch_invariance.py
tests/v1/determinism/test_batch_invariance.py
+13
-12
tests/v1/determinism/test_online_batch_invariance.py
tests/v1/determinism/test_online_batch_invariance.py
+2
-3
tests/v1/e2e/test_async_scheduling.py
tests/v1/e2e/test_async_scheduling.py
+13
-9
tests/v1/e2e/test_cascade_attention.py
tests/v1/e2e/test_cascade_attention.py
+16
-17
tests/v1/e2e/test_spec_decode.py
tests/v1/e2e/test_spec_decode.py
+24
-19
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+20
-2
tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh
...nnector/nixl_integration/tp_config_sweep_accuracy_test.sh
+6
-6
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+2
-4
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+4
-0
tests/v1/kv_offload/test_cpu_offloading.py
tests/v1/kv_offload/test_cpu_offloading.py
+7
-8
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+12
-7
tests/v1/spec_decode/test_max_len.py
tests/v1/spec_decode/test_max_len.py
+40
-45
vllm/v1/attention/backends/rocm_attn.py
vllm/v1/attention/backends/rocm_attn.py
+1
-1
No files found.
tests/v1/cudagraph/test_cudagraph_mode.py
View file @
7eb6cb6c
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextlib
import
os
import
weakref
import
weakref
from
contextlib
import
ExitStack
from
contextlib
import
ExitStack
...
@@ -13,26 +11,6 @@ from vllm import LLM
...
@@ -13,26 +11,6 @@ from vllm import LLM
from
vllm.config
import
CompilationConfig
,
CompilationMode
from
vllm.config
import
CompilationConfig
,
CompilationMode
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
@
contextlib
.
contextmanager
def
temporary_environ
(
env_vars
):
"""
Temporarily set environment variables and restore them afterward.
We have to do this vs monkeypatch because monkeypatch doesn't work
with "module" scoped fixtures.
"""
original_env
=
{
k
:
os
.
environ
.
get
(
k
)
for
k
in
env_vars
}
try
:
os
.
environ
.
update
(
env_vars
)
yield
finally
:
for
k
,
v
in
original_env
.
items
():
if
v
is
None
:
os
.
environ
.
pop
(
k
,
None
)
else
:
os
.
environ
[
k
]
=
v
# test attention backend and cudagraph_mode combo
# test attention backend and cudagraph_mode combo
# (backend_name, cudagraph_mode, supported)
# (backend_name, cudagraph_mode, supported)
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
...
@@ -68,9 +46,9 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
...
@@ -68,9 +46,9 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
):
):
pytest
.
skip
(
"Only Hopper GPUs support FA3 and FlashMLA"
)
pytest
.
skip
(
"Only Hopper GPUs support FA3 and FlashMLA"
)
env_vars
=
backend_config
s
[
backend_name
].
env_vars
attention_config
=
backend_config
.
attention_config
with
temporary_environ
(
env_vars
),
ExitStack
()
as
stack
:
with
ExitStack
()
as
stack
:
if
not
supported
:
if
not
supported
:
stack
.
enter_context
(
pytest
.
raises
(
Exception
))
stack
.
enter_context
(
pytest
.
raises
(
Exception
))
...
@@ -80,6 +58,7 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
...
@@ -80,6 +58,7 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
trust_remote_code
=
True
,
trust_remote_code
=
True
,
gpu_memory_utilization
=
0.45
,
gpu_memory_utilization
=
0.45
,
max_model_len
=
1024
,
max_model_len
=
1024
,
attention_config
=
attention_config
,
compilation_config
=
CompilationConfig
(
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
cudagraph_mode
=
cudagraph_mode
mode
=
CompilationMode
.
VLLM_COMPILE
,
cudagraph_mode
=
cudagraph_mode
),
),
...
@@ -122,9 +101,10 @@ combo_cases_2 = [
...
@@ -122,9 +101,10 @@ combo_cases_2 = [
def
test_cudagraph_compilation_combo
(
def
test_cudagraph_compilation_combo
(
backend_name
,
cudagraph_mode
,
compilation_mode
,
supported
backend_name
,
cudagraph_mode
,
compilation_mode
,
supported
):
):
env_vars
=
backend_configs
[
backend_name
].
env_vars
backend_config
=
backend_configs
[
backend_name
]
attention_config
=
backend_config
.
attention_config
with
temporary_environ
(
env_vars
),
ExitStack
()
as
stack
:
with
ExitStack
()
as
stack
:
if
not
supported
:
if
not
supported
:
stack
.
enter_context
(
pytest
.
raises
(
Exception
))
stack
.
enter_context
(
pytest
.
raises
(
Exception
))
...
@@ -134,6 +114,7 @@ def test_cudagraph_compilation_combo(
...
@@ -134,6 +114,7 @@ def test_cudagraph_compilation_combo(
trust_remote_code
=
True
,
trust_remote_code
=
True
,
gpu_memory_utilization
=
0.45
,
gpu_memory_utilization
=
0.45
,
max_model_len
=
1024
,
max_model_len
=
1024
,
attention_config
=
attention_config
,
compilation_config
=
CompilationConfig
(
compilation_config
=
CompilationConfig
(
mode
=
compilation_mode
,
cudagraph_mode
=
cudagraph_mode
mode
=
compilation_mode
,
cudagraph_mode
=
cudagraph_mode
),
),
...
...
tests/v1/determinism/test_batch_invariance.py
View file @
7eb6cb6c
...
@@ -28,7 +28,7 @@ IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
...
@@ -28,7 +28,7 @@ IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
BACKENDS
,
BACKENDS
,
)
)
def
test_v1_generation_is_deterministic_across_batch_sizes_with_needle
(
def
test_v1_generation_is_deterministic_across_batch_sizes_with_needle
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
backend
,
):
):
"""
"""
Ensures that the same request (the 'needle' prompt) yields identical output
Ensures that the same request (the 'needle' prompt) yields identical output
...
@@ -54,7 +54,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
...
@@ -54,7 +54,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
random
.
seed
(
seed
)
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
attention_config
=
{
"backend"
:
backend
}
# Allow overrides from environment (useful for CI tuning)
# Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism
# "facebook/opt-125m" is too small, doesn't reliably test determinism
model
=
resolve_model_name
(
backend
)
model
=
resolve_model_name
(
backend
)
...
@@ -92,6 +92,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
...
@@ -92,6 +92,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
max_num_seqs
=
max_batch_size
,
max_num_seqs
=
max_batch_size
,
gpu_memory_utilization
=
gpu_mem_util
,
gpu_memory_utilization
=
gpu_mem_util
,
max_model_len
=
max_model_len
,
max_model_len
=
max_model_len
,
attention_config
=
attention_config
,
)
)
# Baseline generation for the needle prompt alone.
# Baseline generation for the needle prompt alone.
...
@@ -106,6 +107,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
...
@@ -106,6 +107,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
max_num_seqs
=
max_batch_size
,
max_num_seqs
=
max_batch_size
,
gpu_memory_utilization
=
gpu_mem_util
,
gpu_memory_utilization
=
gpu_mem_util
,
max_model_len
=
max_model_len
,
max_model_len
=
max_model_len
,
attention_config
=
attention_config
,
)
)
mismatches
=
0
mismatches
=
0
...
@@ -163,10 +165,8 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
...
@@ -163,10 +165,8 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
BACKENDS
,
BACKENDS
,
)
)
def
test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
(
def
test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
backend
,
):
):
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
random
.
seed
(
seed
)
model_name
=
resolve_model_name
(
backend
)
model_name
=
resolve_model_name
(
backend
)
...
@@ -193,6 +193,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
...
@@ -193,6 +193,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
dtype
=
"bfloat16"
,
# not everything is supported
dtype
=
"bfloat16"
,
# not everything is supported
gpu_memory_utilization
=
0.9
,
gpu_memory_utilization
=
0.9
,
enforce_eager
=
IS_DEVICE_CAPABILITY_BELOW_90
,
enforce_eager
=
IS_DEVICE_CAPABILITY_BELOW_90
,
attention_config
=
{
"backend"
:
backend
},
)
)
# Use more realistic prompts for better token generation
# Use more realistic prompts for better token generation
...
@@ -381,12 +382,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
...
@@ -381,12 +382,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
"backend"
,
"backend"
,
BACKENDS
,
BACKENDS
,
)
)
def
test_simple_generation
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
):
def
test_simple_generation
(
backend
):
"""
"""
Simple test that runs the model with a basic prompt and prints the output.
Simple test that runs the model with a basic prompt and prints the output.
Useful for quick smoke testing and debugging.
Useful for quick smoke testing and debugging.
"""
"""
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
model
=
resolve_model_name
(
backend
)
model
=
resolve_model_name
(
backend
)
llm
=
LLM
(
llm
=
LLM
(
...
@@ -398,6 +398,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
...
@@ -398,6 +398,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
dtype
=
"bfloat16"
,
dtype
=
"bfloat16"
,
enable_prefix_caching
=
False
,
enable_prefix_caching
=
False
,
enforce_eager
=
IS_DEVICE_CAPABILITY_BELOW_90
,
enforce_eager
=
IS_DEVICE_CAPABILITY_BELOW_90
,
attention_config
=
{
"backend"
:
backend
},
)
)
prompt
=
"the capital of france is"
prompt
=
"the capital of france is"
...
@@ -444,8 +445,6 @@ def test_logprobs_without_batch_invariance_should_fail(
...
@@ -444,8 +445,6 @@ def test_logprobs_without_batch_invariance_should_fail(
The test will PASS if we detect differences (proving batch invariance matters).
The test will PASS if we detect differences (proving batch invariance matters).
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
"""
"""
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
# CRITICAL: Disable batch invariance for this test
# CRITICAL: Disable batch invariance for this test
monkeypatch
.
setenv
(
"VLLM_BATCH_INVARIANT"
,
"0"
)
monkeypatch
.
setenv
(
"VLLM_BATCH_INVARIANT"
,
"0"
)
monkeypatch
.
setattr
(
batch_invariant
,
"VLLM_BATCH_INVARIANT"
,
False
)
monkeypatch
.
setattr
(
batch_invariant
,
"VLLM_BATCH_INVARIANT"
,
False
)
...
@@ -465,6 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
...
@@ -465,6 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
max_model_len
=
8192
,
max_model_len
=
8192
,
dtype
=
"bfloat16"
,
dtype
=
"bfloat16"
,
enforce_eager
=
IS_DEVICE_CAPABILITY_BELOW_90
,
enforce_eager
=
IS_DEVICE_CAPABILITY_BELOW_90
,
attention_config
=
{
"backend"
:
backend
},
)
)
# build ragged prompts to change shapes significantly across BS=1 vs BS=N
# build ragged prompts to change shapes significantly across BS=1 vs BS=N
...
@@ -649,7 +649,7 @@ def test_logprobs_without_batch_invariance_should_fail(
...
@@ -649,7 +649,7 @@ def test_logprobs_without_batch_invariance_should_fail(
@
skip_unsupported
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
])
def
test_decode_logprobs_match_prefill_logprobs
(
def
test_decode_logprobs_match_prefill_logprobs
(
backend
,
monkeypatch
:
pytest
.
MonkeyPatch
backend
,
):
):
"""
"""
Test that verifies decode logprobs match prefill logprobs.
Test that verifies decode logprobs match prefill logprobs.
...
@@ -664,8 +664,6 @@ def test_decode_logprobs_match_prefill_logprobs(
...
@@ -664,8 +664,6 @@ def test_decode_logprobs_match_prefill_logprobs(
This ensures that the logprobs from decode are consistent with what
This ensures that the logprobs from decode are consistent with what
we would get if we ran prefill on each prefix.
we would get if we ran prefill on each prefix.
"""
"""
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
seed
=
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
))
random
.
seed
(
seed
)
random
.
seed
(
seed
)
model_name
=
resolve_model_name
(
backend
)
model_name
=
resolve_model_name
(
backend
)
...
@@ -689,6 +687,7 @@ def test_decode_logprobs_match_prefill_logprobs(
...
@@ -689,6 +687,7 @@ def test_decode_logprobs_match_prefill_logprobs(
max_model_len
=
8192
,
max_model_len
=
8192
,
dtype
=
"bfloat16"
,
dtype
=
"bfloat16"
,
enforce_eager
=
IS_DEVICE_CAPABILITY_BELOW_90
,
enforce_eager
=
IS_DEVICE_CAPABILITY_BELOW_90
,
attention_config
=
{
"backend"
:
backend
},
)
)
# Use a few test prompts
# Use a few test prompts
...
@@ -920,6 +919,7 @@ def LLM_with_max_seqs(
...
@@ -920,6 +919,7 @@ def LLM_with_max_seqs(
max_num_seqs
:
int
,
max_num_seqs
:
int
,
gpu_memory_utilization
:
float
,
gpu_memory_utilization
:
float
,
max_model_len
:
int
,
max_model_len
:
int
,
attention_config
:
dict
|
None
=
None
,
)
->
LLM
:
)
->
LLM
:
"""
"""
Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
...
@@ -934,6 +934,7 @@ def LLM_with_max_seqs(
...
@@ -934,6 +934,7 @@ def LLM_with_max_seqs(
tensor_parallel_size
=
int
(
os
.
getenv
(
"VLLM_TP_SIZE"
,
"1"
)),
tensor_parallel_size
=
int
(
os
.
getenv
(
"VLLM_TP_SIZE"
,
"1"
)),
enable_prefix_caching
=
False
,
enable_prefix_caching
=
False
,
enforce_eager
=
IS_DEVICE_CAPABILITY_BELOW_90
,
enforce_eager
=
IS_DEVICE_CAPABILITY_BELOW_90
,
attention_config
=
attention_config
,
# Enable for MOE models
# Enable for MOE models
# enable_expert_parallel=True,
# enable_expert_parallel=True,
)
)
tests/v1/determinism/test_online_batch_invariance.py
View file @
7eb6cb6c
...
@@ -136,11 +136,9 @@ def _compare_bs1_vs_bsn_single_process(
...
@@ -136,11 +136,9 @@ def _compare_bs1_vs_bsn_single_process(
@
skip_unsupported
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
"backend"
,
BACKENDS
)
@
pytest
.
mark
.
parametrize
(
"backend"
,
BACKENDS
)
def
test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
(
def
test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
(
backend
:
str
,
monkeypatch
:
pytest
.
MonkeyPatch
backend
:
str
,
)
->
None
:
)
->
None
:
random
.
seed
(
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
)))
random
.
seed
(
int
(
os
.
getenv
(
"VLLM_TEST_SEED"
,
"12345"
)))
# Override backend for this test (and the RemoteOpenAIServer child process).
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
backend
)
model_name
=
resolve_model_name
(
backend
)
model_name
=
resolve_model_name
(
backend
)
prompts_all
=
[
_random_prompt
(
10
,
50
)
for
_
in
range
(
32
)]
prompts_all
=
[
_random_prompt
(
10
,
50
)
for
_
in
range
(
32
)]
...
@@ -156,6 +154,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
...
@@ -156,6 +154,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
server_args
:
list
[
str
]
=
[
server_args
:
list
[
str
]
=
[
"--max-model-len=8192"
,
"--max-model-len=8192"
,
"--max-num-seqs=32"
,
"--max-num-seqs=32"
,
f
"--attention-backend=
{
backend
}
"
,
]
]
if
tp_size
:
if
tp_size
:
server_args
+=
[
"-tp"
,
tp_size
]
server_args
+=
[
"-tp"
,
tp_size
]
...
...
tests/v1/e2e/test_async_scheduling.py
View file @
7eb6cb6c
...
@@ -142,16 +142,17 @@ def run_tests(
...
@@ -142,16 +142,17 @@ def run_tests(
"""Test consistency of combos of async scheduling, preemption,
"""Test consistency of combos of async scheduling, preemption,
uni/multiproc executor with spec decoding."""
uni/multiproc executor with spec decoding."""
with
monkeypatch
.
context
()
as
m
:
# Determine attention config based on platform
# avoid precision errors
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
if
is_testing_with_spec_decoding
:
if
is_testing_with_spec_decoding
:
# Use TRITON_ATTN for spec decoding test for consistency
# Use TRITON_ATTN for spec decoding test for consistency
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"TRITON_ATTN"
)
attention_config
=
{
"backend"
:
"TRITON_ATTN"
}
else
:
else
:
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"ROCM_AITER_FA"
)
attention_config
=
{
"backend"
:
"ROCM_AITER_FA"
}
else
:
else
:
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLEX_ATTENTION"
)
attention_config
=
{
"backend"
:
"FLEX_ATTENTION"
}
with
monkeypatch
.
context
()
as
m
:
# lock matmul precision to full FP32 (IEEE)
# lock matmul precision to full FP32 (IEEE)
m
.
setenv
(
"VLLM_FLOAT32_MATMUL_PRECISION"
,
"ieee"
)
m
.
setenv
(
"VLLM_FLOAT32_MATMUL_PRECISION"
,
"ieee"
)
# m.setenv("VLLM_BATCH_INVARIANT", "1")
# m.setenv("VLLM_BATCH_INVARIANT", "1")
...
@@ -174,6 +175,7 @@ def run_tests(
...
@@ -174,6 +175,7 @@ def run_tests(
spec_config
,
spec_config
,
test_prefill_chunking
=
test_prefill_chunking
,
test_prefill_chunking
=
test_prefill_chunking
,
is_testing_with_spec_decoding
=
is_testing_with_spec_decoding
,
is_testing_with_spec_decoding
=
is_testing_with_spec_decoding
,
attention_config
=
attention_config
,
)
)
outputs
.
append
(
test_results
)
outputs
.
append
(
test_results
)
...
@@ -262,6 +264,7 @@ def run_test(
...
@@ -262,6 +264,7 @@ def run_test(
spec_config
:
dict
[
str
,
Any
]
|
None
,
spec_config
:
dict
[
str
,
Any
]
|
None
,
test_prefill_chunking
:
bool
,
test_prefill_chunking
:
bool
,
is_testing_with_spec_decoding
:
bool
=
False
,
is_testing_with_spec_decoding
:
bool
=
False
,
attention_config
:
dict
[
str
,
Any
]
|
None
=
None
,
):
):
spec_decoding
=
spec_config
is
not
None
spec_decoding
=
spec_config
is
not
None
cache_arg
:
dict
[
str
,
Any
]
=
(
cache_arg
:
dict
[
str
,
Any
]
=
(
...
@@ -301,6 +304,7 @@ def run_test(
...
@@ -301,6 +304,7 @@ def run_test(
dtype
=
dtype
,
dtype
=
dtype
,
speculative_config
=
spec_config
,
speculative_config
=
spec_config
,
disable_log_stats
=
False
,
disable_log_stats
=
False
,
attention_config
=
attention_config
,
**
cache_arg
,
**
cache_arg
,
)
as
vllm_model
:
)
as
vllm_model
:
results
=
[]
results
=
[]
...
...
tests/v1/e2e/test_cascade_attention.py
View file @
7eb6cb6c
...
@@ -10,7 +10,7 @@ from ...utils import create_new_process_for_each_test
...
@@ -10,7 +10,7 @@ from ...utils import create_new_process_for_each_test
@
create_new_process_for_each_test
()
@
create_new_process_for_each_test
()
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
def
test_cascade_attention
(
example_system_message
,
monkeypatch
,
attn_backend
):
def
test_cascade_attention
(
example_system_message
,
attn_backend
):
prompt
=
"
\n
<User>: Implement fibonacci sequence in Python.
\n
<Claude>:"
prompt
=
"
\n
<User>: Implement fibonacci sequence in Python.
\n
<Claude>:"
if
attn_backend
==
"FLASHINFER"
:
if
attn_backend
==
"FLASHINFER"
:
...
@@ -19,10 +19,9 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
...
@@ -19,10 +19,9 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
"needs investigation. See issue #25679."
"needs investigation. See issue #25679."
)
)
with
monkeypatch
.
context
()
as
m
:
llm
=
LLM
(
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
model
=
"Qwen/Qwen2-1.5B-Instruct"
,
attention_config
=
{
"backend"
:
attn_backend
}
)
llm
=
LLM
(
model
=
"Qwen/Qwen2-1.5B-Instruct"
)
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
100
)
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
100
)
# No cascade attention.
# No cascade attention.
...
...
tests/v1/e2e/test_spec_decode.py
View file @
7eb6cb6c
...
@@ -438,19 +438,17 @@ def test_eagle_correctness(
...
@@ -438,19 +438,17 @@ def test_eagle_correctness(
should be the same when using eagle speculative decoding.
should be the same when using eagle speculative decoding.
model_setup: (method, model_name, eagle_model_name, tp_size)
model_setup: (method, model_name, eagle_model_name, tp_size)
"""
"""
with
monkeypatch
.
context
()
as
m
:
# Determine attention config
# Scout requires default backend selection because vision encoder has
# head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
# to Flex Attn
if
"Llama-4-Scout"
in
model_setup
[
1
]
and
attn_backend
==
"FLASH_ATTN"
:
if
"Llama-4-Scout"
in
model_setup
[
1
]
and
attn_backend
==
"FLASH_ATTN"
:
# Scout requires default backend selection
# because vision encoder has head_dim 88 being incompatible
# with FLASH_ATTN and needs to fall back to Flex Attn
# pass if not ROCm
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
# TODO: Enable Flex Attn for spec_decode on ROCm
# TODO: Enable Flex Attn for spec_decode on ROCm
pytest
.
skip
(
"Flex Attn for spec_decode not supported on ROCm currently"
)
pytest
.
skip
(
"Flex Attn for spec_decode not supported on ROCm currently"
)
attention_config
=
None
# Let it fall back to default
else
:
else
:
m
.
setenv
(
"VLLM_MLA_DISABLE"
,
"1"
)
attention_config
=
{
"backend"
:
attn_backend
}
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
if
attn_backend
==
"TRITON_ATTN"
and
not
current_platform
.
is_rocm
():
if
attn_backend
==
"TRITON_ATTN"
and
not
current_platform
.
is_rocm
():
pytest
.
skip
(
pytest
.
skip
(
...
@@ -458,6 +456,9 @@ def test_eagle_correctness(
...
@@ -458,6 +456,9 @@ def test_eagle_correctness(
"multi-token eagle spec decode on current platform"
"multi-token eagle spec decode on current platform"
)
)
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_MLA_DISABLE"
,
"1"
)
if
attn_backend
==
"ROCM_AITER_FA"
and
current_platform
.
is_rocm
():
if
attn_backend
==
"ROCM_AITER_FA"
and
current_platform
.
is_rocm
():
if
"deepseek"
in
model_setup
[
1
].
lower
():
if
"deepseek"
in
model_setup
[
1
].
lower
():
pytest
.
skip
(
"ROCM_AITER_FA for deepseek not supported on ROCm platform"
)
pytest
.
skip
(
"ROCM_AITER_FA for deepseek not supported on ROCm platform"
)
...
@@ -471,7 +472,10 @@ def test_eagle_correctness(
...
@@ -471,7 +472,10 @@ def test_eagle_correctness(
max_num_batched_tokens
=
128
if
enable_chunked_prefill
else
max_model_len
max_num_batched_tokens
=
128
if
enable_chunked_prefill
else
max_model_len
ref_llm
=
LLM
(
ref_llm
=
LLM
(
model
=
model_name
,
max_model_len
=
max_model_len
,
tensor_parallel_size
=
tp_size
model
=
model_name
,
max_model_len
=
max_model_len
,
tensor_parallel_size
=
tp_size
,
attention_config
=
attention_config
,
)
)
ref_outputs
=
ref_llm
.
chat
(
test_prompts
,
sampling_config
)
ref_outputs
=
ref_llm
.
chat
(
test_prompts
,
sampling_config
)
del
ref_llm
del
ref_llm
...
@@ -492,6 +496,7 @@ def test_eagle_correctness(
...
@@ -492,6 +496,7 @@ def test_eagle_correctness(
max_num_batched_tokens
=
max_num_batched_tokens
,
max_num_batched_tokens
=
max_num_batched_tokens
,
enable_chunked_prefill
=
enable_chunked_prefill
,
enable_chunked_prefill
=
enable_chunked_prefill
,
model_impl
=
model_impl
,
model_impl
=
model_impl
,
attention_config
=
attention_config
,
)
)
spec_outputs
=
spec_llm
.
chat
(
test_prompts
,
sampling_config
)
spec_outputs
=
spec_llm
.
chat
(
test_prompts
,
sampling_config
)
matches
=
0
matches
=
0
...
...
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
View file @
7eb6cb6c
...
@@ -3,21 +3,29 @@ set -xe
...
@@ -3,21 +3,29 @@ set -xe
# Parse command line arguments
# Parse command line arguments
KV_BUFFER_DEVICE
=
"cuda"
# Default to cuda
KV_BUFFER_DEVICE
=
"cuda"
# Default to cuda
ATTENTION_BACKEND
=
""
# Default to empty (use vllm default)
while
[[
$#
-gt
0
]]
;
do
while
[[
$#
-gt
0
]]
;
do
case
$1
in
case
$1
in
--kv_buffer_device
)
--kv_buffer_device
)
KV_BUFFER_DEVICE
=
"
$2
"
KV_BUFFER_DEVICE
=
"
$2
"
shift
2
shift
2
;;
;;
--attention-backend
)
ATTENTION_BACKEND
=
"
$2
"
shift
2
;;
*
)
*
)
echo
"Unknown option
$1
"
echo
"Unknown option
$1
"
echo
"Usage:
$0
[--kv_buffer_device <cuda|cpu>]"
echo
"Usage:
$0
[--kv_buffer_device <cuda|cpu>]
[--attention-backend <backend>]
"
exit
1
exit
1
;;
;;
esac
esac
done
done
echo
"Running accuracy tests with kv_buffer_device=
$KV_BUFFER_DEVICE
"
echo
"Running accuracy tests with kv_buffer_device=
$KV_BUFFER_DEVICE
"
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
echo
"Using attention backend:
$ATTENTION_BACKEND
"
fi
DECODER_KV_LAYOUT
=
${
DECODER_KV_LAYOUT
:-
"HND"
}
# Default to HND, optional NHD
DECODER_KV_LAYOUT
=
${
DECODER_KV_LAYOUT
:-
"HND"
}
# Default to HND, optional NHD
if
[[
"
$DECODER_KV_LAYOUT
"
==
"NHD"
]]
;
then
if
[[
"
$DECODER_KV_LAYOUT
"
==
"NHD"
]]
;
then
...
@@ -148,6 +156,11 @@ run_tests_for_model() {
...
@@ -148,6 +156,11 @@ run_tests_for_model() {
--tensor-parallel-size
$PREFILLER_TP_SIZE
\
--tensor-parallel-size
$PREFILLER_TP_SIZE
\
--kv-transfer-config '
$KV_CONFIG
'"
--kv-transfer-config '
$KV_CONFIG
'"
# Add attention backend config if specified
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
--attention-backend=
$ATTENTION_BACKEND
"
fi
if
[
-n
"
$model_args
"
]
;
then
if
[
-n
"
$model_args
"
]
;
then
FULL_CMD
=
"
$BASE_CMD
$model_args
"
FULL_CMD
=
"
$BASE_CMD
$model_args
"
else
else
...
@@ -189,6 +202,11 @@ run_tests_for_model() {
...
@@ -189,6 +202,11 @@ run_tests_for_model() {
--gpu-memory-utilization
$GPU_MEMORY_UTILIZATION
\
--gpu-memory-utilization
$GPU_MEMORY_UTILIZATION
\
--kv-transfer-config '
$KV_CONFIG
'"
--kv-transfer-config '
$KV_CONFIG
'"
# Add attention backend config if specified
if
[[
-n
"
$ATTENTION_BACKEND
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
--attention-backend=
$ATTENTION_BACKEND
"
fi
# DP-EP attention mode
# DP-EP attention mode
if
[[
-z
"
$DP_EP
"
]]
;
then
if
[[
-z
"
$DP_EP
"
]]
;
then
BASE_CMD
=
"
${
BASE_CMD
}
--tensor-parallel-size
$DECODER_TP_SIZE
"
BASE_CMD
=
"
${
BASE_CMD
}
--tensor-parallel-size
$DECODER_TP_SIZE
"
...
...
tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh
View file @
7eb6cb6c
...
@@ -15,14 +15,14 @@ configs=(
...
@@ -15,14 +15,14 @@ configs=(
run_tests
()
{
run_tests
()
{
local
label
=
$1
local
label
=
$1
local
extra_
env
=
$2
local
extra_
args
=
$2
echo
"=== Running tests (
${
label
}
) ==="
echo
"=== Running tests (
${
label
}
) ==="
for
cfg
in
"
${
configs
[@]
}
"
;
do
for
cfg
in
"
${
configs
[@]
}
"
;
do
echo
"-> Running with
${
cfg
}
${
extra_
env
:+and
${
extra_
env
}}
"
echo
"-> Running with
${
cfg
}
${
extra_
args
:+and
${
extra_
args
}}
"
# Use 'env' to safely set variables without eval
# Use 'env' to safely set variables without eval
if
!
env
${
extra_env
}
${
cfg
}
bash
"
${
SCRIPT
}
"
;
then
if
!
env
${
cfg
}
bash
"
${
SCRIPT
}
"
${
extra_args
}
;
then
echo
"❌ Test failed for config:
${
cfg
}
${
extra_
env
:+
(
${
extra_
env
}
)
}
"
echo
"❌ Test failed for config:
${
cfg
}
${
extra_
args
:+
(
${
extra_
args
}
)
}
"
exit
1
exit
1
fi
fi
done
done
...
@@ -34,8 +34,8 @@ run_tests "default backend" ""
...
@@ -34,8 +34,8 @@ run_tests "default backend" ""
# Check if FLASHINFER is set (non-empty)
# Check if FLASHINFER is set (non-empty)
if
[[
-n
"
${
FLASHINFER
:-}
"
]]
;
then
if
[[
-n
"
${
FLASHINFER
:-}
"
]]
;
then
echo
"FLASHINFER is set, rerunning with
VLLM_ATTENTION_BACKEND=
FLASHINFER"
echo
"FLASHINFER is set, rerunning with
--attention-backend
FLASHINFER"
run_tests
"FLASHINFER backend"
"
VLLM_ATTENTION_BACKEND=
FLASHINFER"
run_tests
"FLASHINFER backend"
"
--attention-backend
FLASHINFER"
else
else
echo
"FLASHINFER not set, skipping FLASHINFER runs."
echo
"FLASHINFER not set, skipping FLASHINFER runs."
fi
fi
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
7eb6cb6c
...
@@ -1132,7 +1132,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
...
@@ -1132,7 +1132,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
"TRITON_ATTN"
,
"TRITON_ATTN"
,
],
],
)
)
def
test_register_kv_caches
(
dist_init
,
attn_backend
,
monkeypatch
):
def
test_register_kv_caches
(
dist_init
,
attn_backend
):
"""
"""
Test that register_kv_caches() properly calls nixl_wrapper methods with
Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data.
correct data.
...
@@ -1144,9 +1144,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
...
@@ -1144,9 +1144,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
block layout info
block layout info
"""
"""
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
vllm_config
=
create_vllm_config
(
attention_backend
=
attn_backend
)
vllm_config
=
create_vllm_config
()
# Import the appropriate backend based on the parameter
# Import the appropriate backend based on the parameter
if
attn_backend
==
"FLASH_ATTN"
:
if
attn_backend
==
"FLASH_ATTN"
:
...
...
tests/v1/kv_connector/unit/utils.py
View file @
7eb6cb6c
...
@@ -11,6 +11,7 @@ import torch
...
@@ -11,6 +11,7 @@ import torch
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.config
import
(
from
vllm.config
import
(
AttentionConfig
,
CacheConfig
,
CacheConfig
,
DeviceConfig
,
DeviceConfig
,
KVTransferConfig
,
KVTransferConfig
,
...
@@ -94,6 +95,7 @@ def create_vllm_config(
...
@@ -94,6 +95,7 @@ def create_vllm_config(
dtype
:
str
=
"float16"
,
dtype
:
str
=
"float16"
,
cache_dtype
:
str
=
"auto"
,
cache_dtype
:
str
=
"auto"
,
hf_overrides
:
dict
[
str
,
Any
]
|
None
=
None
,
hf_overrides
:
dict
[
str
,
Any
]
|
None
=
None
,
attention_backend
:
str
|
None
=
None
,
)
->
VllmConfig
:
)
->
VllmConfig
:
"""Initialize VllmConfig For Testing."""
"""Initialize VllmConfig For Testing."""
model_config
=
ModelConfig
(
model_config
=
ModelConfig
(
...
@@ -124,12 +126,14 @@ def create_vllm_config(
...
@@ -124,12 +126,14 @@ def create_vllm_config(
enable_permute_local_kv
=
enable_permute_local_kv
,
enable_permute_local_kv
=
enable_permute_local_kv
,
kv_connector_extra_config
=
kv_connector_extra_config
or
{},
kv_connector_extra_config
=
kv_connector_extra_config
or
{},
)
)
attention_config
=
AttentionConfig
(
backend
=
attention_backend
)
return
VllmConfig
(
return
VllmConfig
(
scheduler_config
=
scheduler_config
,
scheduler_config
=
scheduler_config
,
model_config
=
model_config
,
model_config
=
model_config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
kv_transfer_config
=
kv_transfer_config
,
kv_transfer_config
=
kv_transfer_config
,
device_config
=
DeviceConfig
(
"cpu"
),
device_config
=
DeviceConfig
(
"cpu"
),
attention_config
=
attention_config
,
)
)
...
...
tests/v1/kv_offload/test_cpu_offloading.py
View file @
7eb6cb6c
...
@@ -13,7 +13,6 @@ from vllm import LLM, SamplingParams, TokensPrompt
...
@@ -13,7 +13,6 @@ from vllm import LLM, SamplingParams, TokensPrompt
from
vllm.config
import
KVEventsConfig
,
KVTransferConfig
from
vllm.config
import
KVEventsConfig
,
KVTransferConfig
from
vllm.distributed.kv_events
import
BlockStored
,
KVEventBatch
from
vllm.distributed.kv_events
import
BlockStored
,
KVEventBatch
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.system_utils
import
set_env_var
CPU_BLOCK_SIZES
=
[
48
]
CPU_BLOCK_SIZES
=
[
48
]
ATTN_BACKENDS
=
[
"FLASH_ATTN"
]
ATTN_BACKENDS
=
[
"FLASH_ATTN"
]
...
@@ -180,12 +179,12 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
...
@@ -180,12 +179,12 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
topic
=
"test"
,
topic
=
"test"
,
)
)
with
set_env_var
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
):
llm
=
LLM
(
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
,
model
=
"meta-llama/Llama-3.2-1B-Instruct"
,
gpu_memory_utilization
=
0.5
,
gpu_memory_utilization
=
0.5
,
kv_events_config
=
kv_events_config
,
kv_events_config
=
kv_events_config
,
kv_transfer_config
=
kv_transfer_config
,
kv_transfer_config
=
kv_transfer_config
,
attention_config
=
{
"backend"
:
attn_backend
},
)
)
events_endpoint
=
events_endpoint
.
replace
(
"*"
,
"127.0.0.1"
)
events_endpoint
=
events_endpoint
.
replace
(
"*"
,
"127.0.0.1"
)
...
...
tests/v1/spec_decode/test_eagle.py
View file @
7eb6cb6c
...
@@ -15,6 +15,7 @@ from tests.v1.attention.utils import (
...
@@ -15,6 +15,7 @@ from tests.v1.attention.utils import (
)
)
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.config
import
(
from
vllm.config
import
(
AttentionConfig
,
CacheConfig
,
CacheConfig
,
DeviceConfig
,
DeviceConfig
,
ModelConfig
,
ModelConfig
,
...
@@ -38,6 +39,7 @@ eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
...
@@ -38,6 +39,7 @@ eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
def
_create_proposer
(
def
_create_proposer
(
method
:
str
,
method
:
str
,
num_speculative_tokens
:
int
,
num_speculative_tokens
:
int
,
attention_backend
:
str
|
None
=
None
,
speculative_token_tree
:
list
[
tuple
[
int
,
...]]
|
None
=
None
,
speculative_token_tree
:
list
[
tuple
[
int
,
...]]
|
None
=
None
,
)
->
EagleProposer
:
)
->
EagleProposer
:
model_config
=
ModelConfig
(
model
=
model_dir
,
runner
=
"generate"
,
max_model_len
=
100
)
model_config
=
ModelConfig
(
model
=
model_dir
,
runner
=
"generate"
,
max_model_len
=
100
)
...
@@ -70,6 +72,7 @@ def _create_proposer(
...
@@ -70,6 +72,7 @@ def _create_proposer(
max_model_len
=
model_config
.
max_model_len
,
max_model_len
=
model_config
.
max_model_len
,
is_encoder_decoder
=
model_config
.
is_encoder_decoder
,
is_encoder_decoder
=
model_config
.
is_encoder_decoder
,
),
),
attention_config
=
AttentionConfig
(
backend
=
attention_backend
),
)
)
return
EagleProposer
(
vllm_config
=
vllm_config
,
device
=
current_platform
.
device_type
)
return
EagleProposer
(
vllm_config
=
vllm_config
,
device
=
current_platform
.
device_type
)
...
@@ -331,8 +334,6 @@ def test_load_model(
...
@@ -331,8 +334,6 @@ def test_load_model(
use_distinct_lm_head
,
use_distinct_lm_head
,
monkeypatch
,
monkeypatch
,
):
):
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
if
attn_backend
==
"TRITON_ATTN"
and
not
current_platform
.
is_rocm
():
if
attn_backend
==
"TRITON_ATTN"
and
not
current_platform
.
is_rocm
():
pytest
.
skip
(
pytest
.
skip
(
"TRITON_ATTN does not support "
"TRITON_ATTN does not support "
...
@@ -394,7 +395,9 @@ def test_load_model(
...
@@ -394,7 +395,9 @@ def test_load_model(
assert
not
isinstance
(
target_model
,
SupportsMultiModal
)
assert
not
isinstance
(
target_model
,
SupportsMultiModal
)
# Create proposer using the helper function
# Create proposer using the helper function
proposer
=
_create_proposer
(
method
,
num_speculative_tokens
=
8
)
proposer
=
_create_proposer
(
method
,
num_speculative_tokens
=
8
,
attention_backend
=
attn_backend
)
# Call the method under test
# Call the method under test
proposer
.
load_model
(
target_model
)
proposer
.
load_model
(
target_model
)
...
@@ -420,8 +423,6 @@ def test_load_model(
...
@@ -420,8 +423,6 @@ def test_load_model(
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
get_attn_backend_list_based_on_platform
())
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
get_attn_backend_list_based_on_platform
())
@
pytest
.
mark
.
parametrize
(
"num_speculative_tokens"
,
[
1
,
3
,
8
])
@
pytest
.
mark
.
parametrize
(
"num_speculative_tokens"
,
[
1
,
3
,
8
])
def
test_propose
(
method
,
attn_backend
,
num_speculative_tokens
,
monkeypatch
):
def
test_propose
(
method
,
attn_backend
,
num_speculative_tokens
,
monkeypatch
):
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
if
attn_backend
==
"TRITON_ATTN"
and
not
current_platform
.
is_rocm
():
if
attn_backend
==
"TRITON_ATTN"
and
not
current_platform
.
is_rocm
():
pytest
.
skip
(
pytest
.
skip
(
"TRITON_ATTN does not support "
"TRITON_ATTN does not support "
...
@@ -449,7 +450,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
...
@@ -449,7 +450,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
seq_lens
=
[
seq_len_1
,
seq_len_2
]
seq_lens
=
[
seq_len_1
,
seq_len_2
]
# Create proposer first so we can use its actual hidden_size
# Create proposer first so we can use its actual hidden_size
proposer
=
_create_proposer
(
"eagle"
,
num_speculative_tokens
)
proposer
=
_create_proposer
(
"eagle"
,
num_speculative_tokens
,
attention_backend
=
attn_backend
)
# Get the hidden_size from the proposer to ensure consistency
# Get the hidden_size from the proposer to ensure consistency
hidden_size
=
proposer
.
hidden_size
hidden_size
=
proposer
.
hidden_size
...
@@ -622,7 +625,9 @@ def test_propose_tree(spec_token_tree):
...
@@ -622,7 +625,9 @@ def test_propose_tree(spec_token_tree):
# Create proposer first so we can use its actual hidden_size.
# Create proposer first so we can use its actual hidden_size.
proposer
=
_create_proposer
(
proposer
=
_create_proposer
(
"eagle"
,
num_speculative_tokens
,
speculative_token_tree
=
spec_token_tree
"eagle"
,
num_speculative_tokens
,
speculative_token_tree
=
spec_token_tree
,
)
)
# Get the hidden_size from the proposer to ensure consistency.
# Get the hidden_size from the proposer to ensure consistency.
hidden_size
=
proposer
.
hidden_size
hidden_size
=
proposer
.
hidden_size
...
...
tests/v1/spec_decode/test_max_len.py
View file @
7eb6cb6c
...
@@ -38,9 +38,6 @@ def test_ngram_max_len(num_speculative_tokens: int):
...
@@ -38,9 +38,6 @@ def test_ngram_max_len(num_speculative_tokens: int):
def
test_eagle_max_len
(
def
test_eagle_max_len
(
monkeypatch
:
pytest
.
MonkeyPatch
,
num_speculative_tokens
:
int
,
attn_backend
:
str
monkeypatch
:
pytest
.
MonkeyPatch
,
num_speculative_tokens
:
int
,
attn_backend
:
str
):
):
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
if
attn_backend
==
"TRITON_ATTN"
and
not
current_platform
.
is_rocm
():
if
attn_backend
==
"TRITON_ATTN"
and
not
current_platform
.
is_rocm
():
pytest
.
skip
(
pytest
.
skip
(
"TRITON_ATTN does not support "
"TRITON_ATTN does not support "
...
@@ -48,7 +45,7 @@ def test_eagle_max_len(
...
@@ -48,7 +45,7 @@ def test_eagle_max_len(
)
)
if
attn_backend
==
"ROCM_AITER_FA"
and
current_platform
.
is_rocm
():
if
attn_backend
==
"ROCM_AITER_FA"
and
current_platform
.
is_rocm
():
m
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
llm
=
LLM
(
llm
=
LLM
(
model
=
"meta-llama/Meta-Llama-3-8B-Instruct"
,
model
=
"meta-llama/Meta-Llama-3-8B-Instruct"
,
...
@@ -60,20 +57,18 @@ def test_eagle_max_len(
...
@@ -60,20 +57,18 @@ def test_eagle_max_len(
"max_model_len"
:
80
,
"max_model_len"
:
80
,
},
},
max_model_len
=
200
,
max_model_len
=
200
,
attention_config
=
{
"backend"
:
attn_backend
},
)
)
sampling_params
=
SamplingParams
(
max_tokens
=
200
,
ignore_eos
=
True
)
sampling_params
=
SamplingParams
(
max_tokens
=
200
,
ignore_eos
=
True
)
outputs
=
llm
.
generate
(
_PROMPTS
,
sampling_params
)
outputs
=
llm
.
generate
(
_PROMPTS
,
sampling_params
)
for
o
in
outputs
:
for
o
in
outputs
:
assert
o
.
outputs
[
0
].
finish_reason
==
"length"
,
(
assert
o
.
outputs
[
0
].
finish_reason
==
"length"
,
(
"This test is only meaningful if the output "
"This test is only meaningful if the output is truncated due to max length"
"is truncated due to max length"
)
)
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
max_tokens
=
200
,
max_tokens
=
200
,
structured_outputs
=
StructuredOutputsParams
(
structured_outputs
=
StructuredOutputsParams
(
regex
=
"^"
+
"a b c d e "
*
15
+
"$"
),
regex
=
"^"
+
"a b c d e "
*
15
+
"$"
),
)
)
output
=
llm
.
generate
(
_PROMPTS
,
sampling_params
)
output
=
llm
.
generate
(
_PROMPTS
,
sampling_params
)
for
o
in
output
:
for
o
in
output
:
...
...
vllm/v1/attention/backends/rocm_attn.py
View file @
7eb6cb6c
...
@@ -165,7 +165,7 @@ class RocmAttentionBackend(AttentionBackend):
...
@@ -165,7 +165,7 @@ class RocmAttentionBackend(AttentionBackend):
raise
ValueError
(
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by
{
attn_type
}
. "
f
"Head size
{
head_size
}
is not supported by
{
attn_type
}
. "
f
"Supported head sizes are:
{
cls
.
get_supported_head_sizes
()
}
. "
f
"Supported head sizes are:
{
cls
.
get_supported_head_sizes
()
}
. "
"Set --attention-
config.
backend=FLEX_ATTENTION to use "
"Set --attention-backend=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
"FlexAttention backend which supports all head sizes."
)
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment