"tests/vscode:/vscode.git/clone" did not exist on "d565e0976fb5ffd353727066ac8aa98272e318af"
test_async_scheduling.py 13.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from itertools import repeat
4
5
6
from typing import Any

import pytest
7
import torch._dynamo.config as dynamo_config
8
9

from vllm import SamplingParams
10
from vllm.logprobs import Logprob
11
from vllm.platforms import current_platform
12
from vllm.sampling_params import StructuredOutputsParams
13
from vllm.v1.metrics.reader import Metric
14
15
16
17
18

from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal

# Model used by test_without_spec_decoding.
MODEL = "Qwen/Qwen3-0.6B"
# Target model for test_with_spec_decoding; it is paired there with the
# EAGLE-3 drafter "nm-testing/Llama3_2_1B_speculator.eagle3".
MTP_MODEL = "meta-llama/Llama-3.2-1B-Instruct"

first_prompt = (
    "The following numbers of the sequence "
    + ", ".join(str(i) for i in range(10))
    + " are:"
)
# 34 prompts in total: two fixed prompts plus 32 numbered ones, so that many
# requests are in flight at once.
example_prompts = [first_prompt, "In one word, the capital of France is "] + [
    f"Tell me about the number {i}: " for i in range(32)
]

# Sampling params shared by every run; per-test overrides are merged on top
# of these in run_test via SamplingParams(**default_params, **override_params).
default_params = dict(
    temperature=0.0,  # greedy
    max_tokens=23,
    min_tokens=18,
)
36

37
38
39
40
41
42
43
44
45

def test_without_spec_decoding(
    sample_json_schema,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test consistency of combos of async scheduling, preemption,
    uni/multiproc executor, prefill chunking.

    The first entry of ``test_configs`` (sync scheduling, no preemption, no
    chunked prefill) is the baseline that run_tests compares every other
    config against, for each entry of ``test_sampling_params``.
    """
    struct_outputs = StructuredOutputsParams(json=sample_json_schema)
    test_sampling_params: list[dict[str, Any]] = [
        dict(),
        # dict(min_tokens=20),
        dict(presence_penalty=-1.0),
        dict(bad_words=["the", " the"]),
        dict(logprobs=2),
        dict(logprobs=2, presence_penalty=-1.0),
        dict(structured_outputs=struct_outputs),
        dict(
            structured_outputs=struct_outputs,
            logprobs=2,
            presence_penalty=-1.0,
        ),
    ]

    # test_preemption, executor, async_scheduling,
    # spec_config, test_prefill_chunking
    test_configs = [
        (False, "mp", False, None, False),
        (True, "mp", False, None, True),
        (False, "mp", True, None, False),
        (False, "uni", True, None, False),
        (True, "mp", True, None, False),
        (True, "uni", True, None, False),
        (False, "mp", True, None, True),
        (True, "mp", True, None, True),
        (True, "uni", True, None, True),
    ]

    if current_platform.is_rocm():
        # On ROCm, only test with structured_outputs (deterministic)
        # and skip chunked prefill (more variable).
        test_configs = [
            cfg
            for cfg in test_configs
            # Index 4 is test_prefill_chunking (see tuple layout above).
            if not cfg[4]
        ]
        test_sampling_params = [
            p for p in test_sampling_params if p.get("structured_outputs") is not None
        ]

    run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
87

88
89
90
91
92
93
94
95

def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
    """Test consistency and acceptance rates with some different combos of
    preemption, executor, async scheduling, prefill chunking,
    spec decoding model length.

    The first config (no spec decoding) is the baseline for generated text
    and logprobs; run_tests uses the first spec-decoding config as the
    baseline for acceptance rates.
    """

    spec_config = {
        "method": "eagle3",
        "num_speculative_tokens": 2,
        "model": "nm-testing/Llama3_2_1B_speculator.eagle3",
    }
    # Set small draft model len to force doesn't-fit-in-drafter case.
    spec_config_short = spec_config | {"max_model_len": 50}

    test_sampling_params = [
        dict(),
        dict(logprobs=2),
    ]

    # test_preemption, executor, async_scheduling,
    # spec_config, test_prefill_chunking
    test_configs = [
        (False, "mp", False, None, False),
        (False, "mp", False, spec_config, False),
        (True, "mp", False, spec_config, True),
        (True, "uni", False, spec_config_short, True),
        (False, "mp", True, spec_config, False),
        (True, "mp", True, spec_config, False),
        (False, "mp", True, spec_config_short, True),
        (True, "uni", True, spec_config, False),
        (True, "uni", True, spec_config_short, False),
        (True, "mp", True, spec_config, True),
        (True, "uni", True, spec_config_short, True),
    ]

    # With is_testing_with_spec_decoding=True, ROCm runs use the TRITON_ATTN
    # backend (set in run_tests) and float32 (set in run_test) for better
    # numerical consistency.
    run_tests(
        monkeypatch,
        MTP_MODEL,
        test_configs,
        test_sampling_params,
        is_testing_with_spec_decoding=True,
    )
132
133
134
135
136
137
138
139


@dynamo_config.patch(cache_size_limit=16)
def run_tests(
    monkeypatch: pytest.MonkeyPatch,
    model: str,
    test_configs: list[tuple],
    test_sampling_params: list[dict[str, Any]],
    is_testing_with_spec_decoding: bool = False,
):
    """Run ``model`` under every engine config and check the configs agree.

    Each entry of ``test_configs`` is a tuple
    ``(test_preemption, executor, async_scheduling, spec_config,
    test_prefill_chunking)``. Each config is run once, via ``run_test``,
    for every entry of ``test_sampling_params``.

    Baselines:
      * ``test_configs[0]`` is the baseline for generated text and logprobs.
      * The first config that reports acceptance rates is the baseline for
        spec-decode acceptance rates. This is the first one whose
        ``spec_config`` is not None.

    Every (config, params) pair is checked and logged as PASSED/FAILED before
    anything is raised. The first ``AssertionError`` encountered is re-raised
    at the end.
    """
    with monkeypatch.context() as m:
        # Pin the attention backend per platform to avoid precision errors.
        if current_platform.is_rocm():
            if is_testing_with_spec_decoding:
                # Use TRITON_ATTN for spec decoding test for consistency
                m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
            else:
                m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
        else:
            m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
        # lock matmul precision to full FP32
        m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "highest")
        # m.setenv("VLLM_BATCH_INVARIANT", "1")

        # One (config description, per-params results, acceptance rates)
        # entry per config; acceptance rates are None without spec decoding.
        outputs: list[tuple[str, list, list[float] | None]] = []
        for n, (
            test_preemption,
            executor,
            async_scheduling,
            spec_config,
            test_prefill_chunking,
        ) in enumerate(test_configs, 1):
            test_str = f"{n}/{len(test_configs)}"
            test_results = run_test(
                model,
                test_str,
                test_sampling_params,
                test_preemption,
                executor,
                async_scheduling,
                spec_config,
                test_prefill_chunking=test_prefill_chunking,
                is_testing_with_spec_decoding=is_testing_with_spec_decoding,
            )
            outputs.append(test_results)

    baseline_config, baseline_tests, _ = outputs[0]
    # The acceptance-rate baseline comes from the first config that actually
    # ran with spec decoding (its rates list is not None).
    _, _, baseline_acceptances = next(
        (o for o in outputs if o[2] is not None), (None, None, None)
    )

    print(f"BASELINE: config=[{baseline_config}], accept_rates={baseline_acceptances}")

    failure = None
    for test_config, test_outputs, test_acceptance_rates in outputs[1:]:
        for (base_outs, base_logprobs), base_acceptance_rate, (
            test_outs,
            test_logprobs,
        ), test_acceptance_rate, params in zip(
            baseline_tests,
            baseline_acceptances or repeat(None),
            test_outputs,
            test_acceptance_rates or repeat(None),
            test_sampling_params,
        ):
            try:
                check_outputs_equal(
                    outputs_0_lst=base_outs,
                    outputs_1_lst=test_outs,
                    name_0=f"baseline=[{baseline_config}], params={params}",
                    name_1=f"config=[{test_config}], params={params}",
                )

                # On ROCm with TRITON_ATTN (spec decoding test), skip strict
                # logprobs comparison when logprobs are requested
                skip_logprobs_check = (
                    current_platform.is_rocm()
                    and params.get("logprobs")
                    and is_testing_with_spec_decoding
                )
                if not skip_logprobs_check:
                    assert _all_logprobs_match(base_logprobs, test_logprobs)

                if (
                    base_acceptance_rate is not None
                    and test_acceptance_rate is not None
                ):
                    if "spec_mml=None" in test_config:
                        # The drafter has no max_model_len override, so its
                        # acceptance rate must not fall below the baseline's
                        # (within a relative tolerance).
                        # Preemption causes more variance in acceptance rates
                        if (
                            current_platform.is_rocm()
                            and "preemption=True" in test_config
                        ):
                            tolerance = 0.10
                        else:
                            tolerance = 0.05
                        assert (
                            test_acceptance_rate > base_acceptance_rate
                            or test_acceptance_rate
                            == pytest.approx(base_acceptance_rate, rel=tolerance)
                        )
                    else:
                        # Currently the reported acceptance rate is expected to be
                        # lower when we sometimes skip drafting altogether.
                        assert test_acceptance_rate > 0.1
                print(
                    f"PASSED: config=[{test_config}], params={params}"
                    f" accept_rate={test_acceptance_rate}"
                )
            except AssertionError as e:
                print(
                    f"FAILED: config=[{test_config}], params={params}"
                    f" accept_rate={test_acceptance_rate}"
                )
                # Remember only the first failure, but keep checking the
                # remaining pairs so every mismatch is printed.
                if failure is None:
                    failure = e

    if failure is not None:
        raise failure


def run_test(
    model: str,
    test_str: str,
    sampling_param_tests: list[dict[str, Any]],
    test_preemption: bool,
    executor: str,
    async_scheduling: bool,
    spec_config: dict[str, Any] | None,
    test_prefill_chunking: bool,
    is_testing_with_spec_decoding: bool = False,
):
    """Generate ``example_prompts`` once per sampling-param set under one config.

    Args:
        model: Model name passed to ``VllmRunner``.
        test_str: Progress label (e.g. ``"3/11"``), used only for logging.
        sampling_param_tests: Overrides merged into ``default_params``. One
            ``generate`` call is made per entry.
        test_preemption: If True, cap the KV cache with
            ``num_gpu_blocks_override=32`` to force preemptions. Also assert
            that at least one preemption happened for every entry.
        executor: Value for ``distributed_executor_backend`` (e.g. "mp" or
            "uni").
        async_scheduling: Whether to enable async scheduling.
        spec_config: Speculative decoding config, or None to disable it.
        test_prefill_chunking: If True, enable chunked prefill with
            ``max_num_batched_tokens=48``.
        is_testing_with_spec_decoding: On ROCm, selects float32 (True) or
            float16 (False). Other platforms always use float32.

    Returns:
        ``(test_config, results, acceptance_rates)``:
          * ``test_config`` is a human-readable description of this config.
          * ``results`` holds one ``generate(..., return_logprobs=True)``
            result per entry of ``sampling_param_tests``. Callers unpack
            each result as ``(outputs, logprobs)``.
          * ``acceptance_rates`` holds the per-entry spec-decode acceptance
            rates, or None when ``spec_config`` is None.
    """
    spec_decoding = spec_config is not None
    cache_arg: dict[str, Any] = (
        # Force preemptions
        dict(num_gpu_blocks_override=32)
        if test_preemption
        else dict(gpu_memory_utilization=0.9)
    )
    spec_mml = (spec_config or {}).get("max_model_len")
    test_config = (
        f"executor={executor}, preemption={test_preemption}, "
        f"async_sched={async_scheduling}, "
        f"chunk_prefill={test_prefill_chunking}, "
        f"spec_decoding={spec_decoding}, spec_mml={spec_mml}"
    )
    print("-" * 80)
    print(f"---- TESTING {test_str}: {test_config}")
    print("-" * 80)

    # On ROCm: use float16 for first test (ROCM_AITER_FA), but float32 for
    # spec decoding test (TRITON_ATTN) for better precision.
    # On others: always use float32.
    if current_platform.is_rocm() and not is_testing_with_spec_decoding:
        dtype = "float16"
    else:
        dtype = "float32"

    with VllmRunner(
        model,
        max_model_len=512,
        enable_chunked_prefill=test_prefill_chunking,
        # Force prefill chunking
        max_num_batched_tokens=48 if test_prefill_chunking else None,
        # enforce_eager=True,
        async_scheduling=async_scheduling,
        distributed_executor_backend=executor,
        dtype=dtype,
        speculative_config=spec_config,
        # Stats logging must stay on so get_metrics() reports the counters
        # read below.
        disable_log_stats=False,
        **cache_arg,
    ) as vllm_model:
        results = []
        acceptance_rates: list[float] | None = [] if spec_decoding else None
        for override_params in sampling_param_tests:
            # Metrics are cumulative, so diff snapshots around each generate.
            metrics_before = vllm_model.llm.get_metrics()
            print(f"----------- RUNNING PARAMS: {override_params}")
            results.append(
                vllm_model.generate(
                    example_prompts,
                    sampling_params=SamplingParams(**default_params, **override_params),
                    return_logprobs=True,
                )
            )
            metrics_after = vllm_model.llm.get_metrics()
            if acceptance_rates is not None:
                acceptance_rate = _get_acceptance_rate(metrics_before, metrics_after)
                acceptance_rates.append(acceptance_rate)
                print(f"ACCEPTANCE RATE {acceptance_rate}")

            if test_preemption:
                preemptions = _get_count(
                    metrics_before, metrics_after, "vllm:num_preemptions"
                )
                assert preemptions > 0, "preemption test had no preemptions"

    if len(results) > 1:
        # First check that the different parameter configs
        # actually result in different output.
        baseline_params = sampling_param_tests[0]
        for (other_test_outs, other_test_logprobs), params in zip(
            results[1:], sampling_param_tests[1:]
        ):
            with pytest.raises(AssertionError):
                check_outputs_equal(
                    outputs_0_lst=results[0][0],
                    outputs_1_lst=other_test_outs,
                    # Label the baseline with its own params, not the
                    # params of the run it is compared against.
                    name_0=f"baseline params={baseline_params}",
                    name_1=f"other params={params}",
                )
                assert _all_logprobs_match(results[0][1], other_test_logprobs)

    return test_config, results, acceptance_rates
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360


def _all_logprobs_match(req_a, req_b) -> bool:
    return (
        req_a == req_b
        or len(req_a) == len(req_b)
        and all(
            len(seq_a) == len(seq_b)
            and all(_logprobs_match(a, b) for a, b in zip(seq_a, seq_b))
            for seq_a, seq_b in zip(req_a, req_b)
        )
    )


def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
    """Return whether two per-position logprob dicts agree.

    Both dicts must have the same set of keys. For each key, the two
    ``Logprob`` entries must have equal ``decoded_token`` and ``rank``, and
    their ``logprob`` values must be approximately equal within a
    platform-dependent tolerance.
    """
    if current_platform.is_rocm():
        # ROCm gets a looser tolerance: run_test uses float16 there for the
        # non-spec-decoding test, and run_tests selects different attention
        # backends than on other platforms.
        rel_tol, abs_tol = 5e-2, 1e-5
    else:
        rel_tol, abs_tol = 1e-3, 1e-6
    # Equal key sets imply equal sizes, so no separate len() check is needed.
    if lps_a.keys() != lps_b.keys():
        return False
    return all(
        a.decoded_token == b.decoded_token
        and a.rank == b.rank
        and a.logprob == pytest.approx(b.logprob, rel=rel_tol, abs=abs_tol)
        for a, b in ((lps_a[token_id], lps_b[token_id]) for token_id in lps_a)
    )
377
378
379
380
381
382
383
384
385
386
387
388


def _get_acceptance_rate(before: list[Metric], after: list[Metric]) -> float:
    """Return the fraction of drafted spec-decode tokens that were accepted.

    Counts are taken as the difference between the ``after`` and ``before``
    metric snapshots. If no tokens were drafted in that window, return 0.0.
    """
    num_drafted = _get_count(before, after, "vllm:spec_decode_num_draft_tokens")
    num_accepted = _get_count(before, after, "vllm:spec_decode_num_accepted_tokens")
    if num_drafted > 0:
        return num_accepted / num_drafted
    return 0.0


def _get_count(before: list[Metric], after: list[Metric], name: str) -> int:
    """Return how much metric ``name`` changed between two snapshots.

    Args:
        before: Metrics snapshot taken before the measured work.
        after: Metrics snapshot taken after the measured work.
        name: Metric name to look up, e.g. ``"vllm:num_preemptions"``.

    Returns:
        ``after`` value minus ``before`` value for the first metric named
        ``name`` in each snapshot.

    Raises:
        LookupError: if ``name`` is absent from either snapshot. This gives a
            clear message instead of a bare ``StopIteration`` from ``next()``.
    """

    def _value_of(metrics: list[Metric]) -> int:
        # First metric with a matching name, as the original next() did.
        for metric in metrics:
            if metric.name == name:
                return metric.value
        raise LookupError(f"metric {name!r} not found in metrics snapshot")

    # Look up 'before' first so a missing metric is reported in the same
    # order as the original lookups.
    before_val = _value_of(before)
    after_val = _value_of(after)
    return after_val - before_val