test_async_scheduling.py 14.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from itertools import repeat
4
5
6
from typing import Any

import pytest
7
import torch._dynamo.config as dynamo_config
8

9
from tests.utils import large_gpu_mark, single_gpu_only
10
from vllm import SamplingParams
11
from vllm.logprobs import Logprob
12
from vllm.platforms import current_platform
13
from vllm.sampling_params import StructuredOutputsParams
14
from vllm.v1.metrics.reader import Metric
15
16
17
18
19

from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal

# Model used by the tests that run without a separate draft model.
MODEL = "Qwen/Qwen3-0.6B"
# Target model used by the EAGLE3 speculative decoding test.
MTP_MODEL = "meta-llama/Llama-3.2-1B-Instruct"


# The first prompt lists the digits 0-9 inline; the remaining prompts are
# short and only differ by a number, giving 34 prompts in total.
first_prompt = (
    "The following numbers of the sequence "
    + ", ".join(str(i) for i in range(10))
    + " are:"
)
example_prompts = [first_prompt, "In one word, the capital of France is "] + [
    f"Tell me about the number {i}: " for i in range(32)
]

# Sampling defaults shared by every run; each test case merges its own
# overrides on top of these when building SamplingParams.
default_params = dict(
    temperature=0.0,  # greedy
    max_tokens=30,
    min_tokens=28,
)
37

38

39
@single_gpu_only
def test_without_spec_decoding(
    sample_json_schema,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test consistency of combos of async scheduling, preemption,
    uni/multiproc executor, prefill chunking.

    Builds a list of sampling-parameter overrides (plain, penalties,
    bad words, logprobs, structured outputs and combinations of them) and a
    list of engine configurations without speculative decoding, then hands
    both to ``run_tests`` using ``MODEL``. The first configuration is the
    one ``run_tests`` treats as the baseline.

    Args:
        sample_json_schema: JSON schema fixture used to build the
            structured-output sampling parameters.
        monkeypatch: pytest fixture forwarded to ``run_tests``.
    """
    struct_outputs = StructuredOutputsParams(json=sample_json_schema)
    test_sampling_params: list[dict[str, Any]] = [
        dict(),
        # dict(min_tokens=20),
        dict(presence_penalty=-1.0),
        dict(bad_words=["the", " the"]),
        dict(logprobs=2),
        dict(logprobs=2, presence_penalty=-1.0),
        dict(structured_outputs=struct_outputs),
        dict(
            structured_outputs=struct_outputs,
            logprobs=2,
        ),
        dict(
            structured_outputs=struct_outputs,
            presence_penalty=-1.0,
        ),
        dict(
            structured_outputs=struct_outputs,
            logprobs=2,
            presence_penalty=-1.0,
        ),
    ]

    # test_preemption, executor, async_scheduling,
    # spec_config, test_prefill_chunking
    test_configs = [
        (False, "mp", False, None, False),
        (True, "mp", False, None, True),
        (False, "mp", True, None, False),
        (False, "uni", True, None, False),
        (True, "mp", True, None, False),
        (True, "uni", True, None, False),
        (False, "mp", True, None, True),
        (True, "mp", True, None, True),
        (True, "uni", True, None, True),
    ]

    if current_platform.is_rocm():
        # On ROCm, Only test with structured_outputs (deterministic)
        # and skip chunk_prefill (more variable).
        test_configs = [
            cfg
            for cfg in test_configs
            if not cfg[4]  # skip chunk_prefill=True
        ]
        test_sampling_params = [
            p for p in test_sampling_params if p.get("structured_outputs") is not None
        ]

    run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
97

98

99
100
@single_gpu_only
@large_gpu_mark(min_gb=16)
def test_with_eagle3_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch):
    """Test consistency and acceptance rates with some different combos of
    preemption, executor, async scheduling, prefill chunking,
    spec decoding model length.

    The first configuration runs without speculative decoding and is the
    one ``run_tests`` treats as the output baseline; every other
    configuration uses the EAGLE3 speculator, either as-is or with a
    shortened draft ``max_model_len``.

    Args:
        sample_json_schema: JSON schema fixture used to build the
            structured-output sampling parameters.
        monkeypatch: pytest fixture forwarded to ``run_tests``.
    """

    spec_config = {
        "method": "eagle3",
        "num_speculative_tokens": 2,
        "model": "nm-testing/Llama3_2_1B_speculator.eagle3",
    }
    # Set small draft model len to force doesn't-fit-in-drafter case.
    spec_config_short = spec_config | {"max_model_len": 50}

    struct_outputs = StructuredOutputsParams(json=sample_json_schema)

    test_sampling_params = [
        dict(),
        dict(presence_penalty=-1.0),
        dict(bad_words=["the", " the"]),
        dict(logprobs=2),
        dict(logprobs=2, presence_penalty=-1.0),
        dict(structured_outputs=struct_outputs),
        dict(
            structured_outputs=struct_outputs,
            logprobs=2,
            presence_penalty=-1.0,
        ),
    ]

    # test_preemption, executor, async_scheduling,
    # spec_config, test_prefill_chunking
    test_configs = [
        (False, "mp", False, None, False),
        (False, "mp", False, spec_config, False),
        (True, "mp", False, spec_config, True),
        (True, "uni", False, spec_config_short, True),
        (False, "mp", True, spec_config, False),
        (True, "mp", True, spec_config, False),
        (False, "mp", True, spec_config_short, True),
        (True, "uni", True, spec_config, False),
        (True, "uni", True, spec_config_short, False),
        (True, "mp", True, spec_config, True),
        (True, "uni", True, spec_config_short, True),
    ]

    # No platform-specific setup happens here: the attention backend and
    # dtype are chosen by the shared runner helpers for every configuration.
    run_tests(
        monkeypatch,
        MTP_MODEL,
        test_configs,
        test_sampling_params,
        is_testing_with_spec_decoding=True,
    )
155
156


157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
def test_with_ngram_gpu_spec_decoding(monkeypatch: pytest.MonkeyPatch):
    """Check ngram_gpu speculative decoding across engine configurations.

    A single ngram_gpu speculative config is used (3 speculative tokens,
    prompt lookup window of 2 to 3). It is combined with different
    preemption, executor, async scheduling and prefill chunking settings,
    preceded by one configuration without speculative decoding that
    ``run_tests`` uses as the baseline. Only the default sampling
    parameters are exercised.
    """
    ngram_spec = dict(
        method="ngram_gpu",
        num_speculative_tokens=3,
        prompt_lookup_max=3,
        prompt_lookup_min=2,
    )

    # Columns: test_preemption, executor, async_scheduling,
    # test_prefill_chunking. Every row below runs with the ngram_gpu config.
    speculative_rows = [
        (False, "mp", False, False),
        (True, "mp", False, True),
        (False, "mp", True, False),
        (True, "mp", True, False),
        (True, "uni", True, False),
        (True, "mp", True, True),
    ]

    # Baseline first: same tuple layout as the rows built below, but with
    # no speculative config.
    test_configs = [(False, "mp", False, None, False)]
    for preemption, executor, async_sched, chunked in speculative_rows:
        test_configs.append((preemption, executor, async_sched, ngram_spec, chunked))

    # MODEL is used here because ngram_gpu needs no separate draft model.
    run_tests(monkeypatch, MODEL, test_configs, [dict()])


193
194
195
196
197
198
@dynamo_config.patch(cache_size_limit=16)
def run_tests(
    monkeypatch: pytest.MonkeyPatch,
    model: str,
    test_configs: list[tuple],
    test_sampling_params: list[dict[str, Any]],
    is_testing_with_spec_decoding: bool = False,
):
    """Test consistency of combos of async scheduling, preemption,
    uni/multiproc executor with spec decoding.

    Runs ``run_test`` once per entry of ``test_configs`` and then compares
    every run against the first one (the baseline): generated outputs must
    be equal, logprobs must match, and, when both sides report one, the
    speculative acceptance rate must be acceptable.

    Args:
        monkeypatch: pytest fixture used to set environment variables for
            the duration of the runs.
        model: Model name passed through to ``run_test``.
        test_configs: Tuples of ``(test_preemption, executor,
            async_scheduling, spec_config, test_prefill_chunking)``.
        test_sampling_params: Sampling-parameter overrides; each run
            executes all of them in order.
        is_testing_with_spec_decoding: Forwarded unchanged to ``run_test``.

    Raises:
        AssertionError: The first comparison failure, re-raised after all
            comparisons have been attempted and reported.
    """

    # The attention backend is the same for every configuration.
    attention_config = {"backend": "FLEX_ATTENTION"}

    with monkeypatch.context() as m:
        # lock matmul precision to full FP32 (IEEE)
        m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "highest")
        # m.setenv("VLLM_BATCH_INVARIANT", "1")
        outputs: list[tuple[str, list, list | None]] = []
        for n, (
            test_preemption,
            executor,
            async_scheduling,
            spec_config,
            test_prefill_chunking,
        ) in enumerate(test_configs, 1):
            test_str = f"{n}/{len(test_configs)}"
            test_results = run_test(
                model,
                test_str,
                test_sampling_params,
                test_preemption,
                executor,
                async_scheduling,
                spec_config,
                test_prefill_chunking=test_prefill_chunking,
                is_testing_with_spec_decoding=is_testing_with_spec_decoding,
                attention_config=attention_config,
            )
            outputs.append(test_results)

    # Outputs are compared against the first run. Acceptance rates are
    # compared against the first run that reported any (None if no run did).
    baseline_config, baseline_tests, _ = outputs[0]
    _, _, baseline_acceptances = next(
        (o for o in outputs if o[2] is not None), (None, None, None)
    )

    print(f"BASELINE: config=[{baseline_config}], accept_rates={baseline_acceptances}")

    failure = None
    for test_config, test_outputs, test_acceptance_rates in outputs[1:]:
        # repeat(None) stands in for missing acceptance rates so that zip
        # still yields one tuple per sampling-parameter set.
        for (base_outs, base_logprobs), base_acceptance_rate, (
            test_outs,
            test_logprobs,
        ), test_acceptance_rate, params in zip(
            baseline_tests,
            baseline_acceptances or repeat(None),
            test_outputs,
            test_acceptance_rates or repeat(None),
            test_sampling_params,
        ):
            try:
                check_outputs_equal(
                    outputs_0_lst=base_outs,
                    outputs_1_lst=test_outs,
                    name_0=f"baseline=[{baseline_config}], params={params}",
                    name_1=f"config=[{test_config}], params={params}",
                )

                assert _all_logprobs_match(base_logprobs, test_logprobs)

                if (
                    base_acceptance_rate is not None
                    and test_acceptance_rate is not None
                ):
                    # The config description string is used to tell whether
                    # the draft model length was shortened for this run.
                    if "spec_mml=None" in test_config:
                        # Preemption causes more variance in acceptance rates
                        if (
                            current_platform.is_rocm()
                            and "preemption=True" in test_config
                        ):
                            tolerance = 0.10
                        else:
                            tolerance = 0.05
                        assert (
                            test_acceptance_rate > base_acceptance_rate
                            or test_acceptance_rate
                            == pytest.approx(base_acceptance_rate, rel=tolerance)
                        )
                    else:
                        # Currently the reported acceptance rate is expected to be
                        # lower when we sometimes skip drafting altogether.
                        assert test_acceptance_rate > 0.1
                print(
                    f"PASSED: config=[{test_config}], params={params}"
                    f" accept_rate={test_acceptance_rate}"
                )
            except AssertionError as e:
                print(
                    f"FAILED: config=[{test_config}], params={params}"
                    f" accept_rate={test_acceptance_rate}"
                )
                # Keep going so every combination is reported, but remember
                # the first failure to raise at the end.
                if failure is None:
                    failure = e

    if failure is not None:
        raise failure


def run_test(
    model: str,
    test_str: str,
    sampling_param_tests: list[dict[str, Any]],
    test_preemption: bool,
    executor: str,
    async_scheduling: bool,
    spec_config: dict[str, Any] | None,
    test_prefill_chunking: bool,
    is_testing_with_spec_decoding: bool = False,
    attention_config: dict[str, Any] | None = None,
) -> tuple[str, list, list[float] | None]:
    """Run every sampling-parameter set against one engine configuration.

    Args:
        model: Model name passed to ``VllmRunner``.
        test_str: Progress label (e.g. ``"3/9"``) printed in the banner.
        sampling_param_tests: Overrides merged on top of ``default_params``
            for each generation pass.
        test_preemption: When true, the GPU block count is overridden to a
            small value and each pass must record at least one preemption.
        executor: Distributed executor backend name.
        async_scheduling: Whether async scheduling is enabled.
        spec_config: Speculative decoding config, or ``None`` to disable it.
        test_prefill_chunking: When true, chunked prefill is enabled with a
            small batched-token budget.
        is_testing_with_spec_decoding: Not read by this function; accepted
            so callers can pass it through.
        attention_config: Attention config passed to ``VllmRunner``.

    Returns:
        A tuple ``(test_config, results, acceptance_rates)`` where
        ``test_config`` is a human-readable description of the
        configuration, ``results`` holds one ``generate`` result per
        sampling-parameter set, and ``acceptance_rates`` holds one rate per
        set, or is ``None`` when ``spec_config`` is ``None``.
    """
    spec_decoding = spec_config is not None
    cache_arg: dict[str, Any] = (
        # Force preemptions
        dict(num_gpu_blocks_override=32)
        if test_preemption
        else dict(gpu_memory_utilization=0.9)
    )
    spec_mml = (spec_config or {}).get("max_model_len")
    spec_method = (spec_config or {}).get("method", "none")
    # This description is also inspected by the caller via substring checks,
    # so its exact wording matters.
    test_config = (
        f"executor={executor}, preemption={test_preemption}, "
        f"async_sched={async_scheduling}, "
        f"chunk_prefill={test_prefill_chunking}, "
        f"spec_decoding={spec_decoding}, spec_method={spec_method}, spec_mml={spec_mml}"
    )
    print("-" * 80)
    print(f"---- TESTING {test_str}: {test_config}")
    print("-" * 80)

    with VllmRunner(
        model,
        max_model_len=4096,
        enable_chunked_prefill=test_prefill_chunking,
        # Force prefill chunking
        max_num_batched_tokens=48 if test_prefill_chunking else None,
        # enforce_eager=True,
        async_scheduling=async_scheduling,
        distributed_executor_backend=executor,
        dtype="float32",
        speculative_config=spec_config,
        disable_log_stats=False,
        attention_config=attention_config,
        **cache_arg,
    ) as vllm_model:
        results = []
        acceptance_rates: list[float] | None = [] if spec_decoding else None
        for override_params in sampling_param_tests:
            # Metrics are read before and after each pass so that counts
            # can be attributed to that pass alone.
            metrics_before = vllm_model.llm.get_metrics()
            print(f"----------- RUNNING PARAMS: {override_params}")
            results.append(
                vllm_model.generate(
                    example_prompts,
                    sampling_params=SamplingParams(**default_params, **override_params),
                    return_logprobs=True,
                )
            )
            metrics_after = vllm_model.llm.get_metrics()
            if acceptance_rates is not None:
                acceptance_rate = _get_acceptance_rate(metrics_before, metrics_after)
                acceptance_rates.append(acceptance_rate)
                print(f"ACCEPTANCE RATE {acceptance_rate}")

            if test_preemption:
                preemptions = _get_count(
                    metrics_before, metrics_after, "vllm:num_preemptions"
                )
                assert preemptions > 0, "preemption test had no preemptions"

    if len(results) > 1:
        # First check that the different parameter configs
        # actually result in different output.
        for (other_test_outs, other_test_logprobs), params in zip(
            results[1:], sampling_param_tests[1:]
        ):
            # Either the outputs or the logprobs must differ from the
            # first parameter set; both checks sit inside the same
            # expected-failure block.
            with pytest.raises(AssertionError):
                check_outputs_equal(
                    outputs_0_lst=results[0][0],
                    outputs_1_lst=other_test_outs,
                    name_0=f"baseline params={params}",
                    name_1=f"other params={params}",
                )
                assert _all_logprobs_match(results[0][1], other_test_logprobs)

    return test_config, results, acceptance_rates
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401


def _all_logprobs_match(req_a, req_b) -> bool:
    """Return whether two per-request logprob collections agree.

    Equal inputs match immediately. Otherwise both must have the same
    number of sequences, each pair of sequences must have the same length,
    and every aligned pair of entries must satisfy ``_logprobs_match``.
    """
    if req_a == req_b:
        return True
    if len(req_a) != len(req_b):
        return False
    for seq_a, seq_b in zip(req_a, req_b):
        if len(seq_a) != len(seq_b):
            return False
        for entry_a, entry_b in zip(seq_a, seq_b):
            if not _logprobs_match(entry_a, entry_b):
                return False
    return True


def _logprobs_match(
    lps_a: dict[int, Logprob],
    lps_b: dict[int, Logprob],
    *,
    rel_tol: float = 1e-3,
    abs_tol: float = 1e-6,
) -> bool:
    """Return whether two logprob dicts for one position agree.

    The dicts must have exactly the same keys, and for each key the two
    entries must have the same ``decoded_token`` and ``rank`` and
    approximately equal ``logprob`` values.

    Args:
        lps_a: First mapping from key to logprob entry.
        lps_b: Second mapping from key to logprob entry.
        rel_tol: Relative tolerance for the ``logprob`` comparison.
        abs_tol: Absolute tolerance for the ``logprob`` comparison.
    """
    # Equal key views imply equal sizes, so no separate length check.
    if lps_a.keys() != lps_b.keys():
        return False
    return all(
        a.decoded_token == b.decoded_token
        and a.rank == b.rank
        and a.logprob == pytest.approx(b.logprob, rel=rel_tol, abs=abs_tol)
        for a, b in ((lps_a[x], lps_b[x]) for x in lps_a)
    )
413
414
415
416
417
418
419
420
421
422
423
424


def _get_acceptance_rate(before: list[Metric], after: list[Metric]) -> float:
    """Return accepted/drafted speculative tokens between two snapshots.

    Both counts are taken as the increase between ``before`` and
    ``after``. When no tokens were drafted in that interval the rate is
    reported as ``0.0``.
    """
    num_drafted = _get_count(before, after, "vllm:spec_decode_num_draft_tokens")
    num_accepted = _get_count(before, after, "vllm:spec_decode_num_accepted_tokens")
    if num_drafted > 0:
        return num_accepted / num_drafted
    return 0.0


def _get_count(before: list[Metric], after: list[Metric], name: str) -> int:
    """Return how much the metric called ``name`` grew between snapshots.

    The first metric with a matching name is used in each snapshot.

    Args:
        before: Metrics read before the measured work.
        after: Metrics read after the measured work.
        name: Name of the metric to look up.

    Returns:
        ``after`` value minus ``before`` value.

    Raises:
        LookupError: If either snapshot has no metric called ``name``.
    """
    # A private sentinel distinguishes "metric absent" from any real value.
    missing = object()
    readings = []
    for label, snapshot in (("before", before), ("after", after)):
        value = next((m.value for m in snapshot if m.name == name), missing)
        if value is missing:
            # Name the metric and snapshot instead of leaking a bare
            # StopIteration out of next().
            raise LookupError(f"metric {name!r} not found in the {label!r} snapshot")
        readings.append(value)
    before_val, after_val = readings
    return after_val - before_val