# test_async_scheduling.py
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
from itertools import repeat
5
6
7
from typing import Any

import pytest
8
import torch._dynamo.config as dynamo_config
9

10
from tests.utils import large_gpu_mark, single_gpu_only
11
from vllm import SamplingParams
12
from vllm.logprobs import Logprob
13
from vllm.platforms import current_platform
14
from vllm.sampling_params import StructuredOutputsParams
15
from vllm.v1.metrics.reader import Metric
16
17
18
19
20

from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal

# Model used by the tests that run without a draft model (and by ngram_gpu).
MODEL = "Qwen/Qwen3-0.6B"
# Target model used by the EAGLE3 speculative decoding test.
MTP_MODEL = "meta-llama/Llama-3.2-1B-Instruct"

# Need to enforce eager for MRV2 while we sort out cudagraph issues.
# Enabled only when the ENFORCE_EAGER environment variable is exactly "1".
ENFORCE_EAGER = os.environ.get("ENFORCE_EAGER", "0") == "1"
25

26
27
28
29
30
31
32
33
# First prompt lists the numbers 0..9 so the continuation is highly predictable.
first_prompt = (
    f"The following numbers of the sequence {', '.join(map(str, range(10)))}"
    " are:"
)
# 34 prompts in total: the two fixed ones plus 32 numbered variations.
example_prompts = [
    first_prompt,
    "In one word, the capital of France is ",
    *(f"Tell me about the number {i}: " for i in range(32)),
]
34

35
36
# Sampling settings shared by every run; run_test layers each test's
# override dict on top via SamplingParams(**default_params, **override_params).
default_params = dict(
    temperature=0.0,  # greedy
    max_tokens=30,
    min_tokens=28,
)
40

41

42
@single_gpu_only
def test_without_spec_decoding(
    sample_json_schema,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test consistency of combos of async scheduling, preemption,
    uni/multiproc executor, prefill chunking.

    No speculative decoding is configured here. The first entry of
    ``test_configs`` acts as the baseline that ``run_tests`` compares every
    other entry against, once per sampling-params override.
    """
    struct_outputs = StructuredOutputsParams(json=sample_json_schema)
    # Overrides applied on top of default_params, one generation run each.
    test_sampling_params: list[dict[str, Any]] = [
        dict(),
        # dict(min_tokens=20),
        dict(frequency_penalty=-1.0),
        dict(bad_words=["the", " the"]),
        dict(logprobs=2),
        dict(logprobs=2, frequency_penalty=-1.0),
        dict(structured_outputs=struct_outputs),
        dict(
            structured_outputs=struct_outputs,
            logprobs=2,
        ),
        dict(
            structured_outputs=struct_outputs,
            frequency_penalty=-1.0,
        ),
        dict(
            structured_outputs=struct_outputs,
            logprobs=2,
            frequency_penalty=-1.0,
        ),
    ]

    # test_preemption, executor, async_scheduling,
    # spec_config, test_prefill_chunking
    test_configs = [
        (False, "mp", False, None, False),
        (True, "mp", False, None, True),
        (False, "mp", True, None, False),
        (False, "uni", True, None, False),
        (True, "mp", True, None, False),
        (True, "uni", True, None, False),
        (False, "mp", True, None, True),
        (True, "mp", True, None, True),
        (True, "uni", True, None, True),
    ]

    if current_platform.is_rocm():
        # On ROCm, Only test with structured_outputs (deterministic)
        # and skip chunk_prefill (more variable).
        test_configs = [
            cfg
            for cfg in test_configs
            if not cfg[4]  # skip chunk_prefill=True
        ]
        test_sampling_params = [
            p for p in test_sampling_params if p.get("structured_outputs") is not None
        ]

    run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
100

101

102
103
@single_gpu_only
@large_gpu_mark(min_gb=16)
def test_with_eagle3_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch):
    """Test consistency and acceptance rates with some different combos of
    preemption, executor, async scheduling, prefill chunking,
    spec decoding model length.

    The first entry of ``test_configs`` (no speculative config) is the
    baseline that ``run_tests`` compares the EAGLE3 runs against.
    """

    spec_config = {
        "method": "eagle3",
        "num_speculative_tokens": 2,
        "model": "nm-testing/Llama3_2_1B_speculator.eagle3",
    }
    # Set small draft model len to force doesn't-fit-in-drafter case.
    spec_config_short = spec_config | {"max_model_len": 50}

    struct_outputs = StructuredOutputsParams(json=sample_json_schema)

    # Overrides applied on top of default_params, one generation run each.
    test_sampling_params: list[dict[str, Any]] = [
        dict(),
        dict(frequency_penalty=-1.0),
        dict(bad_words=["the", " the"]),
        dict(logprobs=2),
        dict(logprobs=2, frequency_penalty=-1.0),
        dict(structured_outputs=struct_outputs),
        dict(
            structured_outputs=struct_outputs,
            logprobs=2,
            frequency_penalty=-1.0,
        ),
    ]

    # test_preemption, executor, async_scheduling,
    # spec_config, test_prefill_chunking
    test_configs = [
        (False, "mp", False, None, False),
        (False, "mp", False, spec_config, False),
        (True, "mp", False, spec_config, True),
        (True, "uni", False, spec_config_short, True),
        (False, "mp", True, spec_config, False),
        (True, "mp", True, spec_config, False),
        (False, "mp", True, spec_config_short, True),
        (True, "uni", True, spec_config, False),
        (True, "uni", True, spec_config_short, False),
        (True, "mp", True, spec_config, True),
        (True, "uni", True, spec_config_short, True),
    ]

    run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params)
151
152


153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
def test_with_ngram_gpu_spec_decoding(monkeypatch: pytest.MonkeyPatch):
    """Check ngram_gpu speculative decoding against a non-speculative baseline.

    A single ngram_gpu setup (3 speculative tokens, prompt lookup window of
    2 to 3) is exercised across combinations of preemption, executor,
    async scheduling and prefill chunking. Only the default sampling
    parameters are used (one empty override dict).
    """
    ngram_spec = {
        "method": "ngram_gpu",
        "num_speculative_tokens": 3,
        "prompt_lookup_max": 3,
        "prompt_lookup_min": 2,
    }

    # The reference run: multiproc executor, nothing else enabled, no spec decoding.
    baseline = (False, "mp", False, None, False)

    # (test_preemption, executor, async_scheduling, test_prefill_chunking)
    # for each run that uses the ngram_gpu speculative config.
    spec_variants = [
        (False, "mp", False, False),
        (True, "mp", False, True),
        (False, "mp", True, False),
        (True, "mp", True, False),
        (True, "uni", True, False),
        (True, "mp", True, True),
    ]

    # Expand to the 5-tuple layout run_tests expects:
    # test_preemption, executor, async_scheduling,
    # spec_config, test_prefill_chunking
    configs = [baseline] + [
        (preempt, executor, async_sched, ngram_spec, chunked)
        for preempt, executor, async_sched, chunked in spec_variants
    ]

    # ngram_gpu needs no separate draft model, so the lighter MODEL (Qwen)
    # is used as the target.
    run_tests(monkeypatch, MODEL, configs, [{}])


189
190
191
192
193
194
195
196
197
198
@dynamo_config.patch(cache_size_limit=16)
def run_tests(
    monkeypatch: pytest.MonkeyPatch,
    model: str,
    test_configs: list[tuple],
    test_sampling_params: list[dict[str, Any]],
):
    """Test consistency of combos of async scheduling, preemption,
    uni/multiproc executor with spec decoding.

    Each entry of ``test_configs`` is run through ``run_test`` with every
    entry of ``test_sampling_params``. The first config is the baseline; each
    later config is compared against it, per sampling-params set, on
    generated outputs, then logprobs, then acceptance rate. Every comparison
    is printed as PASSED/FAILED and the first recorded failure is re-raised
    at the end, so one mismatch does not hide the others.

    Args:
        monkeypatch: pytest fixture used to scope the environment override.
        model: model name forwarded to ``run_test``.
        test_configs: tuples of (test_preemption, executor, async_scheduling,
            spec_config, test_prefill_chunking).
        test_sampling_params: sampling-parameter override dicts forwarded to
            ``run_test``.

    Raises:
        AssertionError: the first mismatch recorded against the baseline.
    """

    # Flex attention supports float32.
    attention_config = {"backend": "FLEX_ATTENTION"}

    with monkeypatch.context() as m:
        # lock matmul precision to full FP32 (IEEE)
        m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "highest")
        # m.setenv("VLLM_BATCH_INVARIANT", "1")

        # One (config description, per-params results, acceptance rates)
        # tuple per config; acceptance rates are None without spec decoding.
        outputs: list[tuple[str, list, list[float] | None]] = []
        for n, (
            test_preemption,
            executor,
            async_scheduling,
            spec_config,
            test_prefill_chunking,
        ) in enumerate(test_configs, 1):
            test_str = f"{n}/{len(test_configs)}"
            test_results = run_test(
                model,
                test_str,
                test_sampling_params,
                test_preemption,
                executor,
                async_scheduling,
                spec_config,
                test_prefill_chunking=test_prefill_chunking,
                attention_config=attention_config,
            )
            outputs.append(test_results)

    baseline_config, baseline_tests, _ = outputs[0]
    # The baseline run may not use spec decoding, so take the reference
    # acceptance rates from the first config that reported any.
    _, _, baseline_acceptances = next(
        (o for o in outputs if o[2] is not None), (None, None, None)
    )

    print(f"BASELINE: config=[{baseline_config}], accept_rates={baseline_acceptances}")

    failure = None
    for test_config, test_outputs, test_acceptance_rates in outputs[1:]:
        # repeat(None) stands in for missing acceptance rates so zip still
        # walks one item per sampling-params set.
        for (base_outs, base_logprobs), base_acceptance_rate, (
            test_outs,
            test_logprobs,
        ), test_acceptance_rate, params in zip(
            baseline_tests,
            baseline_acceptances or repeat(None),
            test_outputs,
            test_acceptance_rates or repeat(None),
            test_sampling_params,
        ):
            # (label, exception) of the first check that failed, if any.
            reason = None
            try:
                check_outputs_equal(
                    outputs_0_lst=base_outs,
                    outputs_1_lst=test_outs,
                    name_0=f"baseline=[{baseline_config}], params={params}",
                    name_1=f"config=[{test_config}], params={params}",
                )
            except AssertionError as e:
                reason = "outputs ", e

            if reason is None:
                try:
                    assert _all_logprobs_match(base_logprobs, test_logprobs)
                except AssertionError as e:
                    reason = "logprobs", e

            if reason is None:
                try:
                    if (
                        base_acceptance_rate is not None
                        and test_acceptance_rate is not None
                    ):
                        if "spec_mml=None" in test_config:
                            # Preemption causes more variance in acceptance rates
                            if (
                                current_platform.is_rocm()
                                and "preemption=True" in test_config
                            ):
                                tolerance = 0.10
                            else:
                                tolerance = 0.05
                            # A higher rate than the baseline is always fine;
                            # a lower one must be within the tolerance.
                            assert (
                                test_acceptance_rate > base_acceptance_rate
                                or test_acceptance_rate
                                == pytest.approx(base_acceptance_rate, rel=tolerance)
                            )
                        else:
                            # Currently the reported acceptance rate is expected to be
                            # lower when we sometimes skip drafting altogether.
                            assert test_acceptance_rate > 0.1
                except AssertionError as e:
                    reason = "accept  ", e

            if reason is None:
                print(
                    f"\033[32mPASSED\033[0m:           "
                    f"config=[{test_config}], params={params}"
                    f" accept_rate={test_acceptance_rate}"
                )
            else:
                reason_str, _ = reason
                print(
                    f"\033[31mFAILED\033[0m({reason_str}): "
                    f"config=[{test_config}], params={params}"
                    f" accept_rate={test_acceptance_rate}"
                )
                # Remember only the first failure; keep going so every
                # combination still gets reported.
                if failure is None:
                    _, failure = reason

    if failure is not None:
        raise failure


def run_test(
    model: str,
    test_str: str,
    sampling_param_tests: list[dict[str, Any]],
    test_preemption: bool,
    executor: str,
    async_scheduling: bool,
    spec_config: dict[str, Any] | None,
    test_prefill_chunking: bool,
    attention_config: dict[str, Any] | None = None,
):
    """Run one engine configuration over every sampling-params override.

    Starts a ``VllmRunner`` for ``model`` with the given settings, generates
    for ``example_prompts`` once per entry of ``sampling_param_tests``, and
    collects the results. When ``test_preemption`` is set, each run must
    record at least one preemption. Afterwards, every non-first params set is
    checked to have produced outputs or logprobs that differ from the first.

    Args:
        model: model name passed to ``VllmRunner``.
        test_str: progress label (e.g. "3/9") used in the printed banner.
        sampling_param_tests: override dicts merged over ``default_params``.
        test_preemption: limit GPU blocks to force preemptions.
        executor: value for ``distributed_executor_backend``.
        async_scheduling: whether async scheduling is enabled.
        spec_config: speculative decoding config, or None to disable it.
        test_prefill_chunking: enable chunked prefill with a small token
            budget to force chunking.
        attention_config: forwarded to ``VllmRunner`` unchanged.

    Returns:
        A ``(test_config, results, acceptance_rates)`` tuple: the config
        description string, one ``generate`` result per params set, and one
        acceptance rate per params set (None when ``spec_config`` is None).
    """
    spec_decoding = spec_config is not None
    # Force preemptions
    cache_arg: dict[str, Any] = (
        dict(num_gpu_blocks_override=32)
        if test_preemption
        else dict(gpu_memory_utilization=0.9)
    )
    spec_mml = (spec_config or {}).get("max_model_len")
    spec_method = (spec_config or {}).get("method", "none")
    # NOTE: run_tests matches substrings of this description
    # ("spec_mml=None", "preemption=True"), so keep the format stable.
    test_config = (
        f"executor={executor}, preemption={test_preemption}, "
        f"async_sched={async_scheduling}, "
        f"chunk_prefill={test_prefill_chunking}, "
        f"spec_decoding={spec_decoding}, spec_method={spec_method}, spec_mml={spec_mml}"
    )
    print("-" * 80)
    print(f"---- TESTING {test_str}: {test_config}")
    print("-" * 80)

    with VllmRunner(
        model,
        max_model_len=4096,
        enable_chunked_prefill=test_prefill_chunking,
        # Force prefill chunking
        max_num_batched_tokens=48 if test_prefill_chunking else None,
        enforce_eager=ENFORCE_EAGER,
        async_scheduling=async_scheduling,
        distributed_executor_backend=executor,
        dtype="float32",
        speculative_config=spec_config,
        disable_log_stats=False,
        attention_config=attention_config,
        **cache_arg,
    ) as vllm_model:
        results = []
        acceptance_rates: list[float] | None = [] if spec_decoding else None
        for override_params in sampling_param_tests:
            # Snapshot metrics around each run so counts are per-run deltas.
            metrics_before = vllm_model.llm.get_metrics()
            print(f"----------- RUNNING PARAMS: {override_params}")
            results.append(
                vllm_model.generate(
                    example_prompts,
                    sampling_params=SamplingParams(**default_params, **override_params),
                    return_logprobs=True,
                )
            )
            metrics_after = vllm_model.llm.get_metrics()
            if acceptance_rates is not None:
                acceptance_rate = _get_acceptance_rate(metrics_before, metrics_after)
                acceptance_rates.append(acceptance_rate)
                print(f"ACCEPTANCE RATE {acceptance_rate}")

            if test_preemption:
                preemptions = _get_count(
                    metrics_before, metrics_after, "vllm:num_preemptions"
                )
                assert preemptions > 0, "preemption test had no preemptions"

    if len(results) > 1:
        # First check that the different parameter configs
        # actually result in different output.
        for (other_test_outs, other_test_logprobs), params in zip(
            results[1:], sampling_param_tests[1:]
        ):
            # Both checks sit inside one raises-block on purpose: identical
            # outputs are acceptable as long as the logprobs then differ.
            with pytest.raises(AssertionError):
                check_outputs_equal(
                    outputs_0_lst=results[0][0],
                    outputs_1_lst=other_test_outs,
                    name_0=f"baseline params={sampling_param_tests[0]}",
                    name_1=f"other params={params}",
                )
                assert _all_logprobs_match(results[0][1], other_test_logprobs)

    return test_config, results, acceptance_rates
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410


def _all_logprobs_match(req_a, req_b) -> bool:
    """Return True when two collections of per-sequence logprobs agree.

    Equal inputs match immediately. Otherwise both must have the same number
    of sequences, each pair of sequences the same length, and every pair of
    per-position entries must satisfy ``_logprobs_match``.

    A ``None`` collection or sequence only matches another ``None``; it is
    reported as a mismatch instead of raising ``TypeError`` from ``len()``,
    so callers that only handle ``AssertionError`` see a normal failure.
    NOTE(review): presumably a sequence is None when logprobs were not
    requested for that run - confirm against ``VllmRunner.generate``.
    """
    if req_a == req_b:
        return True
    # Not equal, so at most one side is None here: that is a mismatch.
    if req_a is None or req_b is None:
        return False
    if len(req_a) != len(req_b):
        return False

    for seq_a, seq_b in zip(req_a, req_b):
        if seq_a is None or seq_b is None:
            if seq_a is not seq_b:
                return False
            continue
        if len(seq_a) != len(seq_b):
            return False
        if not all(_logprobs_match(a, b) for a, b in zip(seq_a, seq_b)):
            return False
    return True


def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
    """Return True when two logprob dicts describe the same candidates.

    Both dicts must have exactly the same keys, and for every key the decoded
    token and rank must be equal while the logprob values may differ only
    within a small relative/absolute tolerance.
    """
    rel_tol, abs_tol = 1e-3, 1e-6
    # Equal key views already imply equal sizes, so no separate len() check.
    return lps_a.keys() == lps_b.keys() and all(
        a.decoded_token == b.decoded_token
        and a.rank == b.rank
        and a.logprob == pytest.approx(b.logprob, rel=rel_tol, abs=abs_tol)
        for a, b in ((lps_a[x], lps_b[x]) for x in lps_a)
    )
422
423
424
425
426
427
428
429
430
431
432
433


def _get_acceptance_rate(before: list[Metric], after: list[Metric]) -> float:
    """Return accepted/drafted speculative tokens between two metric snapshots.

    Both counts are deltas between ``before`` and ``after``. Returns 0.0 when
    the draft-token delta is not positive, which avoids dividing by zero.
    """
    drafted = _get_count(before, after, "vllm:spec_decode_num_draft_tokens")
    accepted = _get_count(before, after, "vllm:spec_decode_num_accepted_tokens")
    if drafted > 0:
        return accepted / drafted
    return 0.0


def _get_count(before: list[Metric], after: list[Metric], name: str) -> int:
    """Return how much the metric called ``name`` changed between snapshots.

    Only the first entry with a matching name in each snapshot is used.

    Args:
        before: metrics captured before the measured work.
        after: metrics captured after the measured work.
        name: metric name to look up in both snapshots.

    Returns:
        ``after`` value minus ``before`` value.

    Raises:
        LookupError: if either snapshot has no metric named ``name``
            (instead of leaking a bare ``StopIteration`` from ``next``).
    """
    missing = object()  # sentinel, so any real metric value is accepted
    values = []
    for label, snapshot in (("before", before), ("after", after)):
        value = next((m.value for m in snapshot if m.name == name), missing)
        if value is missing:
            raise LookupError(f"metric {name!r} not found in {label!r} snapshot")
        values.append(value)
    before_val, after_val = values
    return after_val - before_val