# test_async_scheduling.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import pytest
6
import torch._dynamo.config as dynamo_config
7
8

from vllm import SamplingParams
9
from vllm.logprobs import Logprob
10
from vllm.sampling_params import StructuredOutputsParams
11
12
13
14
15
16
17

from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal

MODEL = "Qwen/Qwen3-0.6B"


18
@dynamo_config.patch(cache_size_limit=16)
def test_preempt_and_async_scheduling_e2e(
    sample_json_schema, monkeypatch: pytest.MonkeyPatch
):
    """Test consistency of combos of async scheduling, preemption,
    uni/multiproc executor, and various sampling parameters
    including structured outputs.

    The model is run once per (preemption, executor, async scheduling)
    combination and, within each run, once per sampling-parameter set. The
    first combination is the baseline; every later combination must
    reproduce its outputs and logprobs for every parameter set.

    Args:
        sample_json_schema: JSON schema handed to ``StructuredOutputsParams``
            for the structured-output cases.
        monkeypatch: Used to set ``VLLM_ATTENTION_BACKEND`` while the
            models run.
    """

    first_prompt = (
        "The following numbers of the sequence "
        + ", ".join(str(i) for i in range(10))
        + " are:"
    )
    example_prompts = [first_prompt, "In one word, the capital of France is "] + [
        f"Tell me about the number {i}: " for i in range(32)
    ]

    # Per-case overrides layered on top of `default_params`. The first
    # (empty) entry is the reference the other entries are contrasted with.
    sampling_param_tests: list[dict[str, Any]] = [
        dict(),
        # dict(min_tokens=20),
        dict(presence_penalty=-1.0),
        dict(bad_words=["the", " the"]),
        dict(logprobs=2),
        dict(logprobs=2, presence_penalty=-1.0),
        dict(structured_outputs=StructuredOutputsParams(json=sample_json_schema)),
        dict(
            structured_outputs=StructuredOutputsParams(json=sample_json_schema),
            logprobs=2,
            presence_penalty=-1.0,
        ),
    ]

    default_params = dict(
        temperature=0.0,  # greedy
        max_tokens=20,
    )

    with monkeypatch.context() as m:
        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
        # m.setenv("VLLM_BATCH_INVARIANT", "1")

        # One (config description, per-param-set results) entry per engine
        # configuration, in execution order; entry 0 is the baseline.
        outputs: list[tuple[str, list]] = []
        for test_preemption in [False, True]:
            for executor in ["mp", "uni"]:
                for async_scheduling in [False, True]:
                    # The preemption runs get a small fixed block budget;
                    # presumably this is what forces preemption to occur.
                    cache_arg: dict[str, Any] = (
                        dict(num_gpu_blocks_override=32)
                        if test_preemption
                        else dict(gpu_memory_utilization=0.7)
                    )
                    test_config = (
                        f"executor={executor}, preemption={test_preemption},"
                        f" async_sched={async_scheduling}"
                    )
                    print("-" * 80)
                    print(f"---- TESTING: {test_config}")
                    print("-" * 80)
                    with VllmRunner(
                        MODEL,
                        max_model_len=512,
                        enforce_eager=True,
                        async_scheduling=async_scheduling,
                        distributed_executor_backend=executor,
                        dtype="float32",  # avoid precision errors
                        **cache_arg,
                    ) as vllm_model:
                        results = []
                        for override_params in sampling_param_tests:
                            print(f"----------- RUNNING PARAMS: {override_params}")
                            results.append(
                                vllm_model.generate(
                                    example_prompts,
                                    sampling_params=SamplingParams(
                                        **default_params, **override_params
                                    ),
                                    return_logprobs=True,
                                )
                            )

                        if not outputs:
                            # First check that the different parameter configs
                            # actually result in different output.
                            # Each pytest.raises block is satisfied when the
                            # output comparison raises AssertionError or, if
                            # it did not, when the logprob assertion fails.
                            for (other_test_outs, other_test_logprobs), params in zip(
                                results[1:], sampling_param_tests[1:]
                            ):
                                with pytest.raises(AssertionError):
                                    check_outputs_equal(
                                        outputs_0_lst=results[0][0],
                                        outputs_1_lst=other_test_outs,
                                        name_0=f"baseline params={params}",
                                        name_1=f"other params={params}",
                                    )
                                    assert _all_logprobs_match(
                                        results[0][1], other_test_logprobs
                                    )

                        outputs.append((test_config, results))

    baseline_config, baseline_tests = outputs[0]

    # Every non-baseline configuration must agree with the baseline for each
    # sampling-parameter set, both in generated outputs and in logprobs.
    for test_config, test_outputs in outputs[1:]:
        for (base_outs, base_logprobs), (test_outs, test_logprobs), params in zip(
            baseline_tests, test_outputs, sampling_param_tests
        ):
            check_outputs_equal(
                outputs_0_lst=base_outs,
                outputs_1_lst=test_outs,
                name_0=f"baseline=[{baseline_config}], params={params}",
                name_1=f"config=[{test_config}], params={params}",
            )
            assert _all_logprobs_match(base_logprobs, test_logprobs), (
                f"logprobs differ: baseline=[{baseline_config}], "
                f"config=[{test_config}], params={params}"
            )

            print(f"PASSED: config=[{test_config}], params={params}")
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151


def _all_logprobs_match(req_a, req_b) -> bool:
    """Return whether two nested logprob collections agree.

    The collections agree when they compare equal outright, or when they
    hold the same number of sequences, each pair of sequences has the same
    length, and every aligned pair of entries satisfies ``_logprobs_match``.
    """
    # Fast path: plain equality needs no element-wise tolerance check.
    if req_a == req_b:
        return True

    if len(req_a) != len(req_b):
        return False

    for seq_a, seq_b in zip(req_a, req_b):
        if len(seq_a) != len(seq_b):
            return False
        for entry_a, entry_b in zip(seq_a, seq_b):
            if not _logprobs_match(entry_a, entry_b):
                return False
    return True


def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
    """Return whether two logprob dicts describe the same entries.

    The dicts match when they have exactly the same keys and, for every key,
    the two entries have equal ``decoded_token`` and ``rank`` and their
    ``logprob`` values agree within a relative tolerance of 1e-3 (absolute
    tolerance 1e-6).

    Returns:
        True if the dicts match, False otherwise.
    """
    # Compare the key sets first: two dicts of equal size may still hold
    # different keys, and indexing lps_b with a key it lacks would raise
    # KeyError instead of reporting a mismatch.
    if lps_a.keys() != lps_b.keys():
        return False
    return all(
        a.decoded_token == b.decoded_token
        and a.rank == b.rank
        and a.logprob == pytest.approx(b.logprob, rel=1e-3, abs=1e-6)
        for a, b in ((lps_a[key], lps_b[key]) for key in lps_a)
    )