Unverified commit 2a68464c, authored by Nick Hill, committed by GitHub
Browse files

[Test] `test_async_scheduling.py` improvements (#36340)


Signed-off-by: Nick Hill <nickhill123@gmail.com>
parent bdd8981d
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from itertools import repeat from itertools import repeat
from typing import Any from typing import Any
...@@ -19,6 +20,8 @@ from ...models.utils import check_outputs_equal ...@@ -19,6 +20,8 @@ from ...models.utils import check_outputs_equal
MODEL = "Qwen/Qwen3-0.6B" MODEL = "Qwen/Qwen3-0.6B"
MTP_MODEL = "meta-llama/Llama-3.2-1B-Instruct" MTP_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
# Need to enforce eager for MRV2 while we sort out cudagraph issues.
ENFORCE_EAGER = os.getenv("ENFORCE_EAGER", "0") == "1"
first_prompt = ( first_prompt = (
"The following numbers of the sequence " "The following numbers of the sequence "
...@@ -47,10 +50,10 @@ def test_without_spec_decoding( ...@@ -47,10 +50,10 @@ def test_without_spec_decoding(
test_sampling_params: list[dict[str, Any]] = [ test_sampling_params: list[dict[str, Any]] = [
dict(), dict(),
# dict(min_tokens=20), # dict(min_tokens=20),
dict(presence_penalty=-1.0), dict(frequency_penalty=-1.0),
dict(bad_words=["the", " the"]), dict(bad_words=["the", " the"]),
dict(logprobs=2), dict(logprobs=2),
dict(logprobs=2, presence_penalty=-1.0), dict(logprobs=2, frequency_penalty=-1.0),
dict(structured_outputs=struct_outputs), dict(structured_outputs=struct_outputs),
dict( dict(
structured_outputs=struct_outputs, structured_outputs=struct_outputs,
...@@ -58,12 +61,12 @@ def test_without_spec_decoding( ...@@ -58,12 +61,12 @@ def test_without_spec_decoding(
), ),
dict( dict(
structured_outputs=struct_outputs, structured_outputs=struct_outputs,
presence_penalty=-1.0, frequency_penalty=-1.0,
), ),
dict( dict(
structured_outputs=struct_outputs, structured_outputs=struct_outputs,
logprobs=2, logprobs=2,
presence_penalty=-1.0, frequency_penalty=-1.0,
), ),
] ]
...@@ -116,15 +119,15 @@ def test_with_eagle3_spec_decoding(sample_json_schema, monkeypatch: pytest.Monke ...@@ -116,15 +119,15 @@ def test_with_eagle3_spec_decoding(sample_json_schema, monkeypatch: pytest.Monke
test_sampling_params = [ test_sampling_params = [
dict(), dict(),
dict(presence_penalty=-1.0), dict(frequency_penalty=-1.0),
dict(bad_words=["the", " the"]), dict(bad_words=["the", " the"]),
dict(logprobs=2), dict(logprobs=2),
dict(logprobs=2, presence_penalty=-1.0), dict(logprobs=2, frequency_penalty=-1.0),
dict(structured_outputs=struct_outputs), dict(structured_outputs=struct_outputs),
dict( dict(
structured_outputs=struct_outputs, structured_outputs=struct_outputs,
logprobs=2, logprobs=2,
presence_penalty=-1.0, frequency_penalty=-1.0,
), ),
] ]
...@@ -144,14 +147,7 @@ def test_with_eagle3_spec_decoding(sample_json_schema, monkeypatch: pytest.Monke ...@@ -144,14 +147,7 @@ def test_with_eagle3_spec_decoding(sample_json_schema, monkeypatch: pytest.Monke
(True, "uni", True, spec_config_short, True), (True, "uni", True, spec_config_short, True),
] ]
# On ROCm, use TRITON_ATTN + float32 for better numerical consistency run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params)
run_tests(
monkeypatch,
MTP_MODEL,
test_configs,
test_sampling_params,
is_testing_with_spec_decoding=True,
)
def test_with_ngram_gpu_spec_decoding(monkeypatch: pytest.MonkeyPatch): def test_with_ngram_gpu_spec_decoding(monkeypatch: pytest.MonkeyPatch):
...@@ -196,12 +192,11 @@ def run_tests( ...@@ -196,12 +192,11 @@ def run_tests(
model: str, model: str,
test_configs: list[tuple], test_configs: list[tuple],
test_sampling_params: list[dict[str, Any]], test_sampling_params: list[dict[str, Any]],
is_testing_with_spec_decoding: bool = False,
): ):
"""Test consistency of combos of async scheduling, preemption, """Test consistency of combos of async scheduling, preemption,
uni/multiproc executor with spec decoding.""" uni/multiproc executor with spec decoding."""
# Determine attention config based on platform # Flex attention supports float32.
attention_config = {"backend": "FLEX_ATTENTION"} attention_config = {"backend": "FLEX_ATTENTION"}
with monkeypatch.context() as m: with monkeypatch.context() as m:
...@@ -226,7 +221,6 @@ def run_tests( ...@@ -226,7 +221,6 @@ def run_tests(
async_scheduling, async_scheduling,
spec_config, spec_config,
test_prefill_chunking=test_prefill_chunking, test_prefill_chunking=test_prefill_chunking,
is_testing_with_spec_decoding=is_testing_with_spec_decoding,
attention_config=attention_config, attention_config=attention_config,
) )
outputs.append(test_results) outputs.append(test_results)
...@@ -250,6 +244,7 @@ def run_tests( ...@@ -250,6 +244,7 @@ def run_tests(
test_acceptance_rates or repeat(None), test_acceptance_rates or repeat(None),
test_sampling_params, test_sampling_params,
): ):
reason = None
try: try:
check_outputs_equal( check_outputs_equal(
outputs_0_lst=base_outs, outputs_0_lst=base_outs,
...@@ -257,42 +252,57 @@ def run_tests( ...@@ -257,42 +252,57 @@ def run_tests(
name_0=f"baseline=[{baseline_config}], params={params}", name_0=f"baseline=[{baseline_config}], params={params}",
name_1=f"config=[{test_config}], params={params}", name_1=f"config=[{test_config}], params={params}",
) )
except AssertionError as e:
assert _all_logprobs_match(base_logprobs, test_logprobs) reason = "outputs ", e
if ( if reason is None:
base_acceptance_rate is not None try:
and test_acceptance_rate is not None assert _all_logprobs_match(base_logprobs, test_logprobs)
): except AssertionError as e:
if "spec_mml=None" in test_config: reason = "logprobs", e
# Preemption causes more variance in acceptance rates
if ( if reason is None:
current_platform.is_rocm() try:
and "preemption=True" in test_config if (
): base_acceptance_rate is not None
tolerance = 0.10 and test_acceptance_rate is not None
):
if "spec_mml=None" in test_config:
# Preemption causes more variance in acceptance rates
if (
current_platform.is_rocm()
and "preemption=True" in test_config
):
tolerance = 0.10
else:
tolerance = 0.05
assert (
test_acceptance_rate > base_acceptance_rate
or test_acceptance_rate
== pytest.approx(base_acceptance_rate, rel=tolerance)
)
else: else:
tolerance = 0.05 # Currently the reported acceptance rate is expected to be
assert ( # lower when we sometimes skip drafting altogether.
test_acceptance_rate > base_acceptance_rate assert test_acceptance_rate > 0.1
or test_acceptance_rate except AssertionError as e:
== pytest.approx(base_acceptance_rate, rel=tolerance) reason = "accept ", e
)
else: if reason is None:
# Currently the reported acceptance rate is expected to be
# lower when we sometimes skip drafting altogether.
assert test_acceptance_rate > 0.1
print( print(
f"PASSED: config=[{test_config}], params={params}" f"\033[32mPASSED\033[0m: "
f"config=[{test_config}], params={params}"
f" accept_rate={test_acceptance_rate}" f" accept_rate={test_acceptance_rate}"
) )
except AssertionError as e: else:
reason_str, _ = reason
print( print(
f"FAILED: config=[{test_config}], params={params}" f"\033[31mFAILED\033[0m({reason_str}): "
f"config=[{test_config}], params={params}"
f" accept_rate={test_acceptance_rate}" f" accept_rate={test_acceptance_rate}"
) )
if failure is None: if failure is None:
failure = e _, failure = reason
if failure is not None: if failure is not None:
raise failure raise failure
...@@ -307,7 +317,6 @@ def run_test( ...@@ -307,7 +317,6 @@ def run_test(
async_scheduling: bool, async_scheduling: bool,
spec_config: dict[str, Any] | None, spec_config: dict[str, Any] | None,
test_prefill_chunking: bool, test_prefill_chunking: bool,
is_testing_with_spec_decoding: bool = False,
attention_config: dict[str, Any] | None = None, attention_config: dict[str, Any] | None = None,
): ):
spec_decoding = spec_config is not None spec_decoding = spec_config is not None
...@@ -335,7 +344,7 @@ def run_test( ...@@ -335,7 +344,7 @@ def run_test(
enable_chunked_prefill=test_prefill_chunking, enable_chunked_prefill=test_prefill_chunking,
# Force prefill chunking # Force prefill chunking
max_num_batched_tokens=48 if test_prefill_chunking else None, max_num_batched_tokens=48 if test_prefill_chunking else None,
# enforce_eager=True, enforce_eager=ENFORCE_EAGER,
async_scheduling=async_scheduling, async_scheduling=async_scheduling,
distributed_executor_backend=executor, distributed_executor_backend=executor,
dtype="float32", dtype="float32",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment