test_spec_decode.py 29.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import random
from collections.abc import Iterable
from dataclasses import dataclass, replace
from typing import Any

import pytest
import torch

from tests.evals.gsm8k.gsm8k_eval import _build_gsm8k_prompts, evaluate_gsm8k_offline
from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.benchmarks.datasets import InstructCoderDataset
from vllm.config import VllmConfig
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.v1.metrics.reader import Metric
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model
23

24
25
# Fraction of prompts used as the pass threshold in test_mtp_correctness:
# strictly more than this share of MTP-speculative outputs must exactly match
# the non-speculative reference outputs.
MTP_SIMILARITY_RATE: float = 0.8

26

27
28
def _skip_if_insufficient_gpus_for_tp(tp_size: int):
    """Skip test if available GPUs < tp_size on ROCm."""
29
30
31
32
33
    available_gpus = torch.cuda.device_count()
    if available_gpus < tp_size:
        pytest.skip(
            f"Test requires {tp_size} GPUs, but only {available_gpus} available"
        )
34
35


36
37
38
# One chat conversation as passed to LLM.chat(): a list of message dicts such
# as {"role": "user", "content": ...}. The content is a plain string for text
# prompts and a list of typed parts for multimodal prompts.
Messages = list[dict[str, Any]]


39
40
def get_test_prompts(mm_enabled: bool, num_prompts: int = 100) -> list[Messages]:
    """Build a deterministic, mixed batch of chat prompts for spec-decode tests.

    The batch mixes prompts that are easy to predict by n-gram matching
    ("repeat the word" prompts) with 5-shot GSM8K questions that likely are not.
    When ``mm_enabled`` is set, image+text prompts are added as well.

    Args:
        mm_enabled: Whether to include multimodal (image URL) prompts.
        num_prompts: Total number of conversations requested.

    Returns:
        Conversations ordered as repeat prompts, then GSM8K prompts, then
        multimodal prompts. The total is ``num_prompts``, assuming
        ``_build_gsm8k_prompts`` returns as many questions as requested.
    """
    prompt_types = ["repeat", "gsm8k"]
    if mm_enabled:
        prompt_types.append("mm")

    # Repeat prompts (and GSM8K prompts when multimodal is enabled) get an
    # equal share; the last category absorbs the remainder.
    num_repeat_prompts = num_prompts // len(prompt_types)
    if mm_enabled:
        num_gsm8k_prompts = num_prompts // len(prompt_types)
        num_mm_prompts = num_prompts - num_repeat_prompts - num_gsm8k_prompts
    else:
        num_mm_prompts = 0
        num_gsm8k_prompts = num_prompts - num_repeat_prompts

    prompts: list[Messages] = []

    # A private RNG seeded with 0 yields the same word sequence as
    # `random.seed(0)` + `random.choice`. Unlike those calls, it does not
    # reseed the process-global RNG shared with every other test.
    rng = random.Random(0)
    word_choices = ["test", "temp", "hello", "where"]
    for _ in range(num_repeat_prompts):
        word = rng.choice(word_choices)
        # The exact whitespace is part of the prompt text; it is spelled out
        # explicitly so re-indenting this code cannot change the prompts.
        content = (
            f"\n        please repeat the word '{word}' 10 times.\n"
            "        give no other output than the word at least ten times in a row,\n"
            "        in lowercase with spaces between each word and without quotes.\n"
            "        "
        )
        prompts.append([{"role": "user", "content": content}])

    gsm8k_questions = _build_gsm8k_prompts(
        num_questions=num_gsm8k_prompts, num_shots=5
    )[0]
    prompts.extend(
        [{"role": "user", "content": question}] for question in gsm8k_questions
    )

    image_url = f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
    for _ in range(num_mm_prompts):
        # Fresh dicts per prompt so no conversation shares mutable state.
        mm_content = [
            {"type": "image_url", "image_url": {"url": image_url}},
            {"type": "text", "text": "The meaning of the image is"},
        ]
        prompts.append([{"role": "user", "content": mm_content}])

    return prompts
93
94


95
96
97
98
99
100
101
102
def get_instruct_coder_messages(n: int) -> list[Messages]:
    """Sample prompts from the InstructCoder train split as chat conversations.

    ``InstructCoderDataset.sample_prompts(n=n)`` supplies the prompts; each one
    becomes a single-turn conversation containing one user message.
    """
    split = "train"
    coder_dataset = InstructCoderDataset(
        dataset_path="likaixin/InstructCoder", dataset_split=split
    )
    sampled_prompts: Iterable[str] = coder_dataset.sample_prompts(n=n)
    return [[{"role": "user", "content": text}] for text in sampled_prompts]


103
104
@pytest.fixture
def sampling_config():
    """Default sampling parameters for the tests, from ``greedy_sampling()``."""
    return greedy_sampling()


def greedy_sampling() -> SamplingParams:
    """Greedy decoding (temperature 0), at most 10 new tokens, EOS not ignored."""
    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
110
111


112
113
114
115
def stochastic_sampling() -> SamplingParams:
    """Random sampling at temperature 1.0, at most 10 new tokens, EOS not ignored."""
    params = SamplingParams(
        temperature=1.0,
        max_tokens=10,
        ignore_eos=False,
    )
    return params


116
117
@pytest.fixture
def model_name():
    """Default target model for the n-gram / suffix decoding tests."""
    return "meta-llama/Llama-3.1-8B-Instruct"
119
120


121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
def evaluate_llm_for_gsm8k(llm: LLM, expected_accuracy_threshold: float = 0.70) -> None:
    """Run the offline GSM8K eval on ``llm`` and assert a minimum accuracy.

    The 0.70 default assumes the same target model as the ``model_name`` fixture,
    served with ``max_model_len=4096``. Its precomputed greedy GSM8K accuracy is
    75%-80%, so 70% is a sanity floor that verifies the model is correct.

    A threshold ``<= 0.0`` disables the check (used for dummy/random models); in
    that case the eval is not run at all.

    Raises:
        AssertionError: if the measured accuracy is below the threshold.
    """
    if not expected_accuracy_threshold > 0.0:
        print("Skipping GSM8K evaluation")
        return None

    eval_results = evaluate_gsm8k_offline(llm)
    accuracy = eval_results["accuracy"]
    print(f"GSM8K accuracy: {accuracy:.3f}")
    assert accuracy >= expected_accuracy_threshold, (
        f"Expected GSM8K accuracy >= {expected_accuracy_threshold}, got {accuracy:.3f}"
    )


140
141
142
143
144
145
146
147
@pytest.fixture(autouse=True)
def reset_torch_dynamo():
    """Reset the torch dynamo cache after each test.

    This is an autouse fixture. The reset runs after the test body (after
    ``yield``), so dynamo state from one test is cleared before the next begins.
    """
    yield
    # Cleanup after test
    torch._dynamo.reset()


148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
@pytest.mark.parametrize(
    "speculative_config",
    [
        {
            "method": "ngram",
            "prompt_lookup_max": 5,
            "prompt_lookup_min": 3,
            "num_speculative_tokens": 3,
        },
        {
            "method": "suffix",
            "suffix_decoding_max_spec_factor": 2.0,
        },
    ],
)
def test_ngram_and_suffix_correctness(
    speculative_config: dict,
    model_name: str,
):
    """Check that n-gram and suffix speculative decoding keep GSM8K accuracy
    above the sanity threshold for the target model."""
    spec_llm = LLM(
        model=model_name,
        speculative_config=speculative_config,
        max_model_len=4096,
    )
    try:
        evaluate_llm_for_gsm8k(spec_llm)
    finally:
        # Release the engine even when the accuracy check fails, so a failure
        # here does not leave GPU memory held for the following tests.
        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()


178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
def test_suffix_decoding_acceptance(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    """
    Check that suffix decoding caching takes effect and improves acceptance
    lengths and acceptance rates over multiple runs of the same prompts.
    """
    test_prompts = get_test_prompts(mm_enabled=False)

    spec_llm = LLM(
        model=model_name,
        speculative_config={
            "method": "suffix",
            "suffix_decoding_max_spec_factor": 2.0,
            "suffix_decoding_max_cached_requests": 1000,
        },
        max_model_len=1024,
        disable_log_stats=False,
    )

    # Cumulative draft / accepted token counters, sampled once after each run.
    num_draft = []
    num_accept = []
    try:
        # Run the same prompts several times to warm up the suffix cache.
        for _ in range(10):
            spec_llm.chat(test_prompts, sampling_config)
            for metric in spec_llm.get_metrics():
                if metric.name == "vllm:spec_decode_num_draft_tokens":
                    num_draft.append(metric.value)
                elif metric.name == "vllm:spec_decode_num_accepted_tokens":
                    num_accept.append(metric.value)
    finally:
        # Release the engine before asserting, so a failed heuristic does not
        # leave GPU memory held for the following tests.
        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    # First run: counters start with this LLM, so the cumulative value is the
    # per-run value.
    first_accept_tokens = num_accept[0]
    first_draft_tokens = num_draft[0]
    first_accept_rate = first_accept_tokens / first_draft_tokens

    # Last run: take the diff since the stats are cumulative.
    last_accept_tokens = num_accept[-1] - num_accept[-2]
    last_draft_tokens = num_draft[-1] - num_draft[-2]
    last_accept_rate = last_accept_tokens / last_draft_tokens

    # Expect the acceptance length to improve.
    assert first_accept_tokens < last_accept_tokens, (
        f"accepted tokens did not improve: first={first_accept_tokens}, "
        f"last={last_accept_tokens}"
    )

    # Expect the acceptance rate to improve.
    assert first_accept_rate < last_accept_rate, (
        f"acceptance rate did not improve: first={first_accept_rate:.3f}, "
        f"last={last_accept_rate:.3f}"
    )

    # Heuristic: expect at least 80.0% acceptance rate at the end.
    assert last_accept_rate > 0.80, (
        f"final acceptance rate {last_accept_rate:.3f} is not above 0.80"
    )


237
@pytest.mark.parametrize(
    ["model_path", "expected_accuracy_threshold"],
    [
        ("RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3", 0.7),  # ref: 75%-80%
        ("RedHatAI/Qwen3-8B-speculator.eagle3", 0.8),  # ref: 87%-92%
    ],
    ids=["llama3_eagle3_speculator", "qwen3_eagle3_speculator"],
)
def test_speculators_model_integration(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_path: str,
    expected_accuracy_threshold: float,
):
    """
    Test that speculators models work with the simplified integration.

    This verifies the `vllm serve <speculator-model>` use case where
    speculative config is automatically detected from the model config
    without requiring explicit --speculative-config argument.

    Tests:
    1. Speculator model is correctly detected
    2. Verifier model is extracted from speculator config
    3. Speculative decoding is automatically enabled
    4. Text generation works correctly
    5. GSM8k accuracy passes a sanity check with speculative decoding enabled
    6. Output matches reference (non-speculative) generation
    """
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    # Generate test prompts
    test_prompts = get_test_prompts(mm_enabled=False)

    # First run: Direct speculator model (simplified integration)
    spec_llm = LLM(model=model_path, max_model_len=4096)
    try:
        evaluate_llm_for_gsm8k(
            spec_llm, expected_accuracy_threshold=expected_accuracy_threshold
        )
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)

        # Verify speculative config was auto-detected
        spec_config = spec_llm.llm_engine.vllm_config.speculative_config
        assert spec_config is not None, (
            f"Speculative config should be auto-detected for {model_path}"
        )
        assert spec_config.num_speculative_tokens > 0, (
            f"Expected positive speculative tokens, "
            f"got {spec_config.num_speculative_tokens}"
        )

        # Verify draft model is set to the speculator model
        assert spec_config.model == model_path, (
            f"Draft model should be {model_path}, got {spec_config.model}"
        )

        # Extract verifier model for reference run
        verifier_model = spec_llm.llm_engine.vllm_config.model_config.model
    finally:
        # Release the engine even if a check above fails.
        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    # Second run: Reference without speculative decoding
    ref_llm = LLM(model=verifier_model, max_model_len=4096)
    try:
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
    finally:
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    # Compare outputs
    matches = sum(
        1
        for ref, spec in zip(ref_outputs, spec_outputs)
        if ref.outputs[0].text == spec.outputs[0].text
    )

    # Heuristic: expect at least 66% of prompts to match exactly
    assert matches >= int(0.66 * len(ref_outputs)), (
        f"Only {matches}/{len(ref_outputs)} outputs matched. "
        f"Expected at least {int(0.66 * len(ref_outputs))} matches."
    )


322
@pytest.mark.parametrize(
    [
        "model_setup",
        "mm_enabled",
        "enable_chunked_prefill",
        "model_impl",
        "expected_accuracy_threshold",
    ],
    [
        (
            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
            False,
            False,
            "auto",
            0.8,  # ref: 90%
        ),
        (
            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
            False,
            False,
            "transformers",
            0.8,  # ref: 90%
        ),
        pytest.param(
            (
                "eagle3",
                "Qwen/Qwen3-VL-8B-Instruct",
                "taobao-mnn/Qwen3-VL-8B-Instruct-Eagle3",
                1,
            ),
            False,
            False,
            "auto",
            0.8,  # ref: 90%
            marks=pytest.mark.skip(
                reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
            ),
        ),
        pytest.param(
            (
                "eagle3",
                "Qwen/Qwen2.5-VL-7B-Instruct",
                "Rayzl/qwen2.5-vl-7b-eagle3-sgl",
                1,
            ),
            False,
            False,
            "auto",
            0.7,  # TODO, update this with a reference value when re-enabling this case
            marks=pytest.mark.skip(
                reason="Skipping due to its head_dim not being a a multiple of 32"
            ),
        ),
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-3.1-8B-Instruct",
                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
                1,
            ),
            False,
            True,
            "auto",
            0.7,  # ref: 75%-80%
            marks=large_gpu_mark(min_gb=40),
        ),  # works on 4x H100
        (
            (
                "eagle3",
                "meta-llama/Llama-3.1-8B-Instruct",
                "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
                1,
            ),
            False,
            False,
            "auto",
            0.7,  # ref: 75%-80%
        ),
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                4,
            ),
            False,
            False,
            "auto",
            0.8,  # ref: 90%
            # marks=large_gpu_mark(min_gb=80),
        ),  # works on 4x H100
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                4,
            ),
            True,
            True,
            "auto",
            0.8,  # ref: 90%
            marks=large_gpu_mark(min_gb=80),
        ),  # works on 4x H100
        (
            (
                "eagle",
                "eagle618/deepseek-v3-random",
                "eagle618/eagle-deepseek-v3-random",
                1,
            ),
            False,
            False,
            "auto",
            0.0,  # dummy model, skip gsm8k check
        ),
    ],
    ids=[
        "qwen3_eagle3",
        "qwen3_eagle3-transformers",
        "qwen3_vl_eagle3",
        "qwen2_5_vl_eagle3",
        "llama3_eagle",
        "llama3_eagle3",
        "llama4_eagle",
        "llama4_eagle_mm",
        "deepseek_eagle",
    ],
)
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    expected_accuracy_threshold: float,
    enable_chunked_prefill: bool,
    model_impl: str,
    attn_backend: str,
):
    """
    Compare the outputs of an original LLM and a speculative LLM, which should
    be the same when using eagle speculative decoding. Because of some variance
    in the engine, some outputs may differ. The test therefore expects more than
    60% of the prompts to produce exactly matching outputs, and GSM8k accuracy
    above a precomputed reference threshold for each model.

    ``model_setup`` is ``(method, target_model, draft_model, tp_size)``.
    """
    if attn_backend == "TREE_ATTN":
        # TODO: Fix this flaky test
        pytest.skip(
            "TREE_ATTN is flaky in the test disable for now until it can be "
            "resolved (see https://github.com/vllm-project/vllm/issues/22922)"
        )

    if model_impl == "transformers":
        import transformers
        from packaging.version import Version

        installed = Version(transformers.__version__)
        required = Version("5.0.0")
        if installed < required:
            pytest.skip(
                "Eagle3 with the Transformers modeling backend requires "
                f"transformers>={required}, but got {installed}"
            )

    # Generate test prompts inside the function instead of using fixture
    test_prompts = get_test_prompts(mm_enabled)

    # Determine attention config
    # Scout requires default backend selection because vision encoder has
    # head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
    # to Flex Attn
    if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
        if current_platform.is_rocm():
            # TODO: Enable Flex Attn for spec_decode on ROCm
            pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
        attention_config = None  # Let it fall back to default
    else:
        attention_config = {"backend": attn_backend}

    if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
        pytest.skip(
            "TRITON_ATTN does not support "
            "multi-token eagle spec decode on current platform"
        )

    with monkeypatch.context() as m:
        m.setenv("VLLM_MLA_DISABLE", "1")

        if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
            if "deepseek" in model_setup[1].lower():
                pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm platform")
            else:
                m.setenv("VLLM_ROCM_USE_AITER", "1")

        method, model_name, spec_model_name, tp_size = model_setup
        _skip_if_insufficient_gpus_for_tp(tp_size)

        max_model_len = 2048
        max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len

        ref_llm = LLM(
            model=model_name,
            max_model_len=max_model_len,
            tensor_parallel_size=tp_size,
            attention_config=attention_config,
        )
        try:
            evaluate_llm_for_gsm8k(
                ref_llm, expected_accuracy_threshold=expected_accuracy_threshold
            )
            ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        finally:
            # Free the reference engine even if the GSM8K check fails.
            del ref_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": max_model_len,
            },
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            enable_chunked_prefill=enable_chunked_prefill,
            model_impl=model_impl,
            attention_config=attention_config,
        )
        try:
            evaluate_llm_for_gsm8k(
                spec_llm, expected_accuracy_threshold=expected_accuracy_threshold
            )
            spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        finally:
            del spec_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        matches = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            ref_text = ref_output.outputs[0].text
            spec_text = spec_output.outputs[0].text
            if ref_text == spec_text:
                matches += 1
            else:
                print(f"ref_output: {ref_text}")
                print(f"spec_output: {spec_text}")

        # Heuristic: expect more than 60% of the prompts to match exactly.
        # Upon failure, inspect the printed outputs to check for inaccuracy.
        assert matches > int(0.6 * len(ref_outputs)), (
            f"Only {matches}/{len(ref_outputs)} outputs matched exactly."
        )
572
573


574
@pytest.mark.parametrize(
    ["model_setup", "mm_enabled", "expected_accuracy_threshold"],
    [
        (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False, 0.5),  # ref: 65%-70%
        (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False, 0.0),  # dummy model
    ],
    ids=["mimo", "deepseek"],
)
def test_mtp_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, int],
    mm_enabled: bool,
    expected_accuracy_threshold: float,
):
    """
    Compare the outputs of an original LLM and a speculative LLM, which should
    be the same when using MTP speculative decoding. Because of some variance
    in the engine, some outputs may differ. The test therefore expects more than
    MTP_SIMILARITY_RATE of the prompts to produce exactly matching outputs, and
    GSM8k accuracy above a precomputed reference threshold for each model.

    ``model_setup`` is ``(method, model_name, tp_size)``.
    """
    # Generate test prompts inside the function instead of using fixture
    test_prompts = get_test_prompts(mm_enabled)

    with monkeypatch.context() as m:
        m.setenv("VLLM_MLA_DISABLE", "1")

        method, model_name, tp_size = model_setup
        _skip_if_insufficient_gpus_for_tp(tp_size)

        ref_llm = LLM(
            model=model_name,
            max_model_len=2048,
            tensor_parallel_size=tp_size,
            trust_remote_code=True,
        )
        try:
            ref_outputs = ref_llm.chat(test_prompts, sampling_config)
            evaluate_llm_for_gsm8k(
                ref_llm, expected_accuracy_threshold=expected_accuracy_threshold
            )
        finally:
            # Free the reference engine even if the GSM8K check fails.
            del ref_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "num_speculative_tokens": 1,
                "max_model_len": 2048,
            },
            max_model_len=2048,
        )
        try:
            evaluate_llm_for_gsm8k(
                spec_llm, expected_accuracy_threshold=expected_accuracy_threshold
            )
            spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        finally:
            del spec_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        matches = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            ref_text = ref_output.outputs[0].text
            spec_text = spec_output.outputs[0].text
            if ref_text == spec_text:
                matches += 1
            else:
                print(f"ref_output: {ref_text}")
                print(f"spec_output: {spec_text}")

        # Heuristic: expect more than 80% of the prompts to match exactly.
        # Upon failure, inspect the printed outputs to check for inaccuracy.
        assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs)), (
            f"Only {matches}/{len(ref_outputs)} outputs matched exactly."
        )
649
650
651
652
653
654
655
656
657
658


@dataclass
class ArgsTest:
    """One draft-model spec-decode scenario, consumed by
    ``assert_draft_model_correctness``."""

    target_model: str
    draft_model: str
    sampling_config: SamplingParams
    num_speculative_tokens: int
    # Lower bounds asserted on the measured spec-decode metrics.
    expected_acceptance_rate: float
    expected_acceptance_len: float
    # Minimum GSM8K accuracy; values <= 0.0 skip the GSM8K evaluation.
    expected_gsm8k_accuracy: float = 0.0  # skip by default

    # Defaults
    enforce_eager: bool = True
    parallel_drafting: bool = False
    target_tensor_parallel_size: int = 1
    draft_tensor_parallel_size: int = 1
    max_model_len: int = 2048
    gpu_memory_utilization: float = 0.5
    # Prompt source name understood by get_messages().
    dataset: str = "test_prompts"
    num_prompts: int = 100


# Scenarios for test_draft_model_correctness.
cases = [
    # Same model for draft and target, greedy sampling.
    ArgsTest(
        target_model="Qwen/Qwen3-0.6B",
        draft_model="Qwen/Qwen3-0.6B",
        sampling_config=greedy_sampling(),
        num_speculative_tokens=3,  # K
        expected_acceptance_len=0.98 * (3 + 1),  # epsilon discount of K + 1
        expected_acceptance_rate=0.98,  # slight epsilon
        expected_gsm8k_accuracy=0.25,  # ref: 35-40%
    ),
    # Smaller draft model, stochastic sampling.
    ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        draft_model="Qwen/Qwen3-0.6B",
        sampling_config=stochastic_sampling(),
        num_speculative_tokens=3,
        expected_acceptance_len=3.4,  # ref: 3.7
        expected_acceptance_rate=0.80,  # ref: 0.90
        # Note gsm8k always runs greedy sampling
        expected_gsm8k_accuracy=0.5,  # ref: 60%
    ),
]


@pytest.mark.parametrize("args", cases)
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
    """Run every scenario in ``cases`` with enforce_eager both on and off."""
    # `args` is a shared module-level object reused by every enforce_eager
    # parametrization. Work on a copy instead of mutating it in place, which
    # would leak state between test instances.
    assert_draft_model_correctness(replace(args, enforce_eager=enforce_eager))
700
701
702
703
704
705
706
707
708


def test_draft_model_realistic_example():
    """Draft-model spec decode on InstructCoder prompts with enforce_eager off."""
    args = ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        draft_model="Qwen/Qwen3-0.6B",
        dataset="likaixin/InstructCoder",
        num_speculative_tokens=3,
        sampling_config=greedy_sampling(),
        enforce_eager=False,
        expected_acceptance_len=2.6,  # ref: 2.86
        expected_acceptance_rate=0.5,  # ref: 0.62
    )
    assert_draft_model_correctness(args)


def test_draft_model_parallel_drafting():
    """Draft-model spec decode with ``parallel_drafting`` enabled (PARD draft)."""
    args = ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        draft_model="amd/PARD-Qwen3-0.6B",
        dataset="likaixin/InstructCoder",
        num_speculative_tokens=3,
        sampling_config=greedy_sampling(),
        parallel_drafting=True,
        enforce_eager=False,
        expected_acceptance_len=2.3,  # ref: 2.52
        expected_acceptance_rate=0.4,  # ref: 0.51
    )
    assert_draft_model_correctness(args)
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746


@pytest.mark.parametrize(
    "models",
    [
        # target_model,         draft_model
        ("Qwen/Qwen3-1.7B-FP8", "Qwen/Qwen3-0.6B"),  # target quantized
        ("Qwen/Qwen3-1.7B", "Qwen/Qwen3-0.6B-FP8"),  # draft quantized
    ],
    ids=["target_quantized", "draft_quantized"],
)
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
    """Draft-model spec decode with either the target or the draft FP8-quantized."""
    tgt_model, draft_model = models
    sd_case = ArgsTest(
        target_model=tgt_model,
        draft_model=draft_model,
        **some_high_acceptance_metrics(),
        enforce_eager=enforce_eager,
    )
    assert_draft_model_correctness(sd_case)
750
751
752
753
754
755
756
757
758
759
760


def test_draft_model_tensor_parallelism():
    """Ensure spec decode works when running with TP > 1."""
    _skip_if_insufficient_gpus_for_tp(2)
    sd_case = ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        target_tensor_parallel_size=2,
        draft_model="Qwen/Qwen3-0.6B",
        draft_tensor_parallel_size=2,
        **some_high_acceptance_metrics(),
        enforce_eager=False,
        expected_gsm8k_accuracy=0.5,
    )
    assert_draft_model_correctness(sd_case)
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808


def test_draft_model_engine_args_tensor_parallelism():
    """Check that the draft model's vllm_config is derived independently of the
    target's: the target is FP8 with TP=2, the draft is unquantized with TP=1."""
    _skip_if_insufficient_gpus_for_tp(2)

    draft_spec = {
        "model": "Qwen/Qwen3-0.6B",  # <<< draft not quantized
        "method": "draft_model",
        "num_speculative_tokens": 3,
        "draft_tensor_parallel_size": 1,  # <<< valid arg name
    }
    engine_args = EngineArgs(
        model="Qwen/Qwen3-1.7B-FP8",  # <<< tgt quantized
        tensor_parallel_size=2,
        speculative_config=draft_spec,
    )

    target_cfg: VllmConfig = engine_args.create_engine_config()
    assert target_cfg.parallel_config.tensor_parallel_size == 2
    assert target_cfg.quant_config.get_name() == "fp8"

    draft_cfg: VllmConfig = create_vllm_config_for_draft_model(target_cfg)
    assert draft_cfg.parallel_config.tensor_parallel_size == 1
    assert draft_cfg.quant_config is None


def test_draft_model_engine_args_rejects_invalid_tp_argname():
    """A speculative_config must use "draft_tensor_parallel_size", not
    "tensor_parallel_size"; building the engine config must raise ValueError."""
    invalid_spec = {
        "model": "Qwen/Qwen3-0.6B",
        "method": "draft_model",
        "num_speculative_tokens": 3,
        "tensor_parallel_size": 1,  # <<< invalid arg name
    }
    engine_args = EngineArgs(
        model="Qwen/Qwen3-1.7B",
        tensor_parallel_size=1,
        speculative_config=invalid_spec,
    )

    with pytest.raises(ValueError):
        engine_args.create_engine_config()


809
def assert_draft_model_correctness(args: ArgsTest):
    """Run draft-model speculative decoding for ``args`` and check its metrics.

    Builds an LLM for ``args.target_model`` with ``args.draft_model`` as the
    draft (``method="draft_model"``) and runs the conversations from
    ``get_messages(args.dataset, args.num_prompts)``. It then asserts that the
    measured acceptance rate and acceptance length reach
    ``args.expected_acceptance_rate`` and ``args.expected_acceptance_len``.
    GSM8K accuracy is also checked when ``args.expected_gsm8k_accuracy > 0``.

    Generated text is NOT compared against a non-speculative reference; only
    the spec-decode metrics (and optionally GSM8K accuracy) are asserted.
    """
    test_prompts: list[Messages] = get_messages(
        dataset=args.dataset, n=args.num_prompts
    )

    spec_llm = LLM(
        model=args.target_model,
        speculative_config={
            "model": args.draft_model,
            "method": "draft_model",
            "num_speculative_tokens": args.num_speculative_tokens,
            "max_model_len": args.max_model_len,
            "enforce_eager": args.enforce_eager,
            "draft_tensor_parallel_size": args.draft_tensor_parallel_size,
            "parallel_drafting": args.parallel_drafting,
        },
        max_num_seqs=100,  # limit cudagraph capture runtime
        max_model_len=args.max_model_len,
        gpu_memory_utilization=args.gpu_memory_utilization,
        tensor_parallel_size=args.target_tensor_parallel_size,
        enforce_eager=args.enforce_eager,
        disable_log_stats=False,  # enables get_metrics()
    )
    try:
        # we don't check the outputs, only check the metrics
        spec_llm.chat(test_prompts, args.sampling_config)
        metrics = spec_llm.get_metrics()
        acceptance_rate: float = compute_acceptance_rate(metrics)
        acceptance_len: float = compute_acceptance_len(metrics)

        # Need to evaluate after getting metrics to avoid polluting the AR
        evaluate_llm_for_gsm8k(
            spec_llm, expected_accuracy_threshold=args.expected_gsm8k_accuracy
        )
    finally:
        # CLEANUP even when generation or the GSM8K check fails.
        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    print(
        f"spec-decode: target={args.target_model}, draft={args.draft_model}, "
        f"temperature={args.sampling_config.temperature:.2f}, "
        f"acceptance_rate={acceptance_rate:.2f}, "
        f"acceptance_len={acceptance_len:.2f}, "
    )

    # compute_acceptance_rate returns NaN when nothing was drafted; NaN also
    # fails this comparison.
    assert acceptance_rate >= args.expected_acceptance_rate, (
        f"acceptance_rate {acceptance_rate:.3f} below expected "
        f"{args.expected_acceptance_rate}"
    )
    assert acceptance_len >= args.expected_acceptance_len, (
        f"acceptance_len {acceptance_len:.3f} below expected "
        f"{args.expected_acceptance_len}"
    )


def get_messages(dataset: str, n: int) -> list[Messages]:
    """Return chat conversations from the named prompt source.

    Args:
        dataset: ``"test_prompts"`` for the synthetic text-only batch from
            ``get_test_prompts``, or ``"likaixin/InstructCoder"``.
        n: Number of conversations to request.

    Raises:
        NotImplementedError: for any other dataset name.
    """
    if dataset == "test_prompts":
        return get_test_prompts(mm_enabled=False, num_prompts=n)
    if dataset == "likaixin/InstructCoder":
        return get_instruct_coder_messages(n=n)
    raise NotImplementedError(f"Dataset '{dataset}' not implemented")


def some_high_acceptance_metrics() -> dict:
    """Return ArgsTest keyword arguments for greedy sampling with K=3 and high
    acceptance expectations.

    A fresh SamplingParams object is built on every call.
    """
    return {
        "sampling_config": greedy_sampling(),
        "num_speculative_tokens": 3,
        "expected_acceptance_len": 3.4,  # ref: 3.75
        "expected_acceptance_rate": 0.8,  # ref: 0.9
    }


def compute_acceptance_rate(metrics: list[Metric]) -> float:
    """Return accepted draft tokens divided by proposed draft tokens.

    Returns NaN when no tokens were drafted. Raises KeyError if a metric it
    needs to read is missing from ``metrics``.
    """
    by_name: dict[str, Metric] = {}
    for metric in metrics:
        by_name[metric.name] = metric

    drafted = by_name["vllm:spec_decode_num_draft_tokens"].value  # type: ignore
    if drafted == 0:
        return float("nan")

    accepted = by_name["vllm:spec_decode_num_accepted_tokens"].value  # type: ignore
    return accepted / drafted


def compute_acceptance_len(metrics: list[Metric]) -> float:
    """Return the mean acceptance length, ``1 + accepted_tokens / num_drafts``.

    The ``+ 1`` presumably counts the token the target model emits on each
    step, which gives a maximum of K + 1 for K speculative tokens (compare the
    ``0.98 * (3 + 1)`` expectation used for the K=3 case). Returns ``1.0`` when
    no drafts were recorded.

    Raises:
        KeyError: if the drafts or accepted-tokens metric is missing.
    """
    name2metric = {metric.name: metric for metric in metrics}
    n_drafts = name2metric["vllm:spec_decode_num_drafts"].value  # type: ignore
    n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value  # type: ignore
    if n_drafts == 0:
        # Return a float rather than int 1 to honor the declared return type.
        return 1.0
    return 1 + (n_accepted_toks / n_drafts)