"benchmarks/multi_turn/benchmark_serving_multi_turn.py" did not exist on "e789cad6b8b5d2a01aa6521b9208bb8d6501ee5b"
test_spec_decode.py 31.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import random
4
5
from collections.abc import Iterable
from dataclasses import dataclass
6
from typing import Any
7

8
import pytest
zhiweiz's avatar
zhiweiz committed
9
import torch
10

11
from tests.evals.gsm8k.gsm8k_eval import _build_gsm8k_prompts, evaluate_gsm8k_offline
12
13
14
15
16
17
18
from tests.utils import (
    get_attn_backend_list_based_on_platform,
    large_gpu_mark,
    multi_gpu_marks,
    multi_gpu_only,
    single_gpu_only,
)
19
from vllm import LLM, SamplingParams
20
21
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
22
from vllm.benchmarks.datasets import InstructCoderDataset
23
from vllm.config import VllmConfig
zhiweiz's avatar
zhiweiz committed
24
from vllm.distributed import cleanup_dist_env_and_memory
25
from vllm.engine.arg_utils import EngineArgs
26
from vllm.platforms import current_platform
27
from vllm.v1.metrics.reader import Metric
28
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model
29

30
31
# Fraction of prompts in test_mtp_correctness whose MTP speculative output must
# exactly match the non-speculative reference output; the test requires strictly
# more than int(MTP_SIMILARITY_RATE * num_prompts) exact matches.
MTP_SIMILARITY_RATE = 0.8

32

33
34
def _skip_if_insufficient_gpus_for_tp(tp_size: int):
    """Skip test if available GPUs < tp_size on ROCm."""
35
36
37
38
39
    available_gpus = torch.cuda.device_count()
    if available_gpus < tp_size:
        pytest.skip(
            f"Test requires {tp_size} GPUs, but only {available_gpus} available"
        )
40
41


42
43
44
# One chat conversation: a list of message dicts such as
# {"role": "user", "content": ...}, in the form passed to ``LLM.chat``.
Messages = list[dict[str, Any]]


45
46
def get_test_prompts(mm_enabled: bool, num_prompts: int = 100) -> list[Messages]:
    """Build a deterministic, mixed batch of single-turn chat prompts.

    The batch contains:
      * "repeat" prompts asking the model to repeat a word, whose continuations
        can easily be predicted by n-gram matching;
      * GSM8K 5-shot prompts, which likely cannot;
      * only when ``mm_enabled``: image + text prompts pointing at a stop-sign
        image under ``VLLM_S3_BUCKET_URL``.

    Prompts are split evenly across the enabled types. The remainder goes to
    GSM8K when images are disabled, and to the image prompts when they are
    enabled.

    Args:
        mm_enabled: whether to include multimodal (image) prompts.
        num_prompts: total number of conversations to build (assuming
            ``_build_gsm8k_prompts`` returns exactly the number of questions
            requested -- TODO confirm).

    Returns:
        Conversations ordered as repeat prompts, GSM8K prompts, image prompts.
    """
    prompt_types = ["repeat", "gsm8k"]
    if mm_enabled:
        prompt_types.append("mm")

    num_repeat_prompts = num_prompts // len(prompt_types)
    if mm_enabled:
        num_gsm8k_prompts = num_prompts // len(prompt_types)
        num_mm_prompts = num_prompts - num_repeat_prompts - num_gsm8k_prompts
    else:
        num_mm_prompts = 0
        num_gsm8k_prompts = num_prompts - num_repeat_prompts

    prompts: list[Messages] = []

    # Generate a mixed batch of prompts, some of which can be easily
    # predicted by n-gram matching and some which likely cannot.
    # A private RNG seeded with 0 yields the same word sequence as
    # ``random.seed(0)`` followed by ``random.choice`` would. Unlike that
    # approach, it does not reset the process-global random state that other
    # code in the test session may rely on.
    rng = random.Random(0)
    word_choices = ["test", "temp", "hello", "where"]
    for _ in range(num_repeat_prompts):
        word = rng.choice(word_choices)
        prompts.append(
            [
                {
                    "role": "user",
                    "content": f"""
        please repeat the word '{word}' 10 times.
        give no other output than the word at least ten times in a row,
        in lowercase with spaces between each word and without quotes.
        """,
                }
            ]
        )

    gsm8k_prompts = _build_gsm8k_prompts(
        num_questions=num_gsm8k_prompts, num_shots=5
    )[0]
    prompts.extend([{"role": "user", "content": prompt}] for prompt in gsm8k_prompts)

    if num_mm_prompts:
        image_url = f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
        for _ in range(num_mm_prompts):
            # Build fresh dicts per prompt so the conversations share no
            # mutable state.
            content = [
                {"type": "image_url", "image_url": {"url": image_url}},
                {"type": "text", "text": "The meaning of the image is"},
            ]
            prompts.append([{"role": "user", "content": content}])

    return prompts
99
100


101
102
103
104
105
106
107
108
def get_instruct_coder_messages(n: int) -> list[Messages]:
    """Wrap ``n`` InstructCoder (train split) prompts as single-turn user chats.

    The prompts come from ``InstructCoderDataset.sample_prompts(n=n)``.
    """
    source = InstructCoderDataset(
        dataset_path="likaixin/InstructCoder", dataset_split="train"
    )
    sampled: Iterable[str] = source.sample_prompts(n=n)
    conversations: list[Messages] = []
    for text in sampled:
        conversations.append([{"role": "user", "content": text}])
    return conversations


109
110
@pytest.fixture
def sampling_config():
    """Default sampling parameters for the correctness tests: greedy decoding."""
    params = greedy_sampling()
    return params


def greedy_sampling() -> SamplingParams:
    """Greedy (temperature 0) sampling, capped at 10 generated tokens, honoring EOS."""
    return SamplingParams(max_tokens=10, ignore_eos=False, temperature=0)
116
117


118
119
120
121
def stochastic_sampling() -> SamplingParams:
    """Random sampling at temperature 1.0, capped at 10 generated tokens, honoring EOS."""
    return SamplingParams(max_tokens=10, ignore_eos=False, temperature=1.0)


122
123
@pytest.fixture
def model_name():
    """Model ID of the default target model used by several tests below."""
    target_model = "meta-llama/Llama-3.1-8B-Instruct"
    return target_model
125
126


127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def evaluate_llm_for_gsm8k(llm: LLM, expected_accuracy_threshold: float = 0.70) -> None:
    """Run GSM8K on ``llm`` and assert its accuracy reaches a sanity threshold.

    The default of 0.70 assumes ``llm`` serves the same target model as the
    ``model_name`` fixture with max model len 4096. The precomputed greedy
    reference for that setup is 75%-80%, so 70% catches a broken model without
    being flaky.

    A threshold <= 0.0 skips the evaluation entirely. Tests use this, for
    example, for dummy/random-weight models.
    """
    if expected_accuracy_threshold <= 0.0:
        print("Skipping GSM8K evaluation")
        return

    accuracy = evaluate_gsm8k_offline(llm)["accuracy"]
    print(f"GSM8K accuracy: {accuracy:.3f}")
    meets_threshold = accuracy >= expected_accuracy_threshold
    assert meets_threshold, (
        f"Expected GSM8K accuracy >= {expected_accuracy_threshold}, got {accuracy:.3f}"
    )


146
147
148
149
150
151
152
153
@pytest.fixture(autouse=True)
def reset_torch_dynamo():
    """Reset the torch dynamo cache after each test.

    This autouse fixture yields control to the test immediately. The reset
    runs during teardown, once the test body has finished. The next test
    therefore starts with dynamo's compile state cleared.
    """
    yield
    # Teardown: executed after the test that requested this fixture.
    torch._dynamo.reset()


154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
@pytest.mark.parametrize(
    "speculative_config",
    [
        {
            "method": "ngram",
            "prompt_lookup_max": 5,
            "prompt_lookup_min": 3,
            "num_speculative_tokens": 3,
        },
        {
            "method": "suffix",
            "suffix_decoding_max_spec_factor": 2.0,
        },
    ],
)
@single_gpu_only
@large_gpu_mark(min_gb=20)
def test_ngram_and_suffix_correctness(
    speculative_config: dict,
    model_name: str,
):
    """GSM8K sanity check for the target model under n-gram or suffix spec decode."""
    llm_with_spec = LLM(
        model=model_name,
        speculative_config=speculative_config,
        max_model_len=4096,
    )
    evaluate_llm_for_gsm8k(llm_with_spec)

    # Release the engine and its GPU memory before the next parametrization.
    del llm_with_spec
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()


186
187
@single_gpu_only
@large_gpu_mark(min_gb=20)
def test_suffix_decoding_acceptance(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    """
    Check that suffix decoding caching takes effect and improves acceptance
    lengths and acceptance rates over multiple runs of the same prompts.

    The spec-decode counters returned by ``LLM.get_metrics()`` are cumulative.
    The last run's numbers are therefore the difference between the last two
    snapshots.
    """
    test_prompts = get_test_prompts(mm_enabled=False)

    spec_llm = LLM(
        model=model_name,
        speculative_config={
            "method": "suffix",
            "suffix_decoding_max_spec_factor": 2.0,
            "suffix_decoding_max_cached_requests": 1000,
        },
        max_model_len=1024,
        disable_log_stats=False,
    )

    # Take exactly one cumulative snapshot per run.
    # Values of all metrics sharing a name are summed. Without this, a counter
    # reported more than once (NOTE(review): e.g. per engine -- confirm) would
    # misalign the per-run diffs below, and a missing counter would raise
    # IndexError.
    num_draft: list[float] = []
    num_accept: list[float] = []
    for _ in range(10):  # Run multiple times to warm up the cache.
        spec_llm.chat(test_prompts, sampling_config)
        metrics = spec_llm.get_metrics()
        num_draft.append(
            sum(
                metric.value
                for metric in metrics
                if metric.name == "vllm:spec_decode_num_draft_tokens"
            )
        )
        num_accept.append(
            sum(
                metric.value
                for metric in metrics
                if metric.name == "vllm:spec_decode_num_accepted_tokens"
            )
        )

    # Acceptance stats for the first run (the counters start at zero).
    first_accept_tokens = num_accept[0]
    first_draft_tokens = num_draft[0]
    assert first_draft_tokens > 0, "No draft tokens were recorded in the first run"
    first_accept_rate = first_accept_tokens / first_draft_tokens

    # Acceptance stats for the last run: diff the cumulative counters.
    last_accept_tokens = num_accept[-1] - num_accept[-2]
    last_draft_tokens = num_draft[-1] - num_draft[-2]
    assert last_draft_tokens > 0, "No draft tokens were recorded in the last run"
    last_accept_rate = last_accept_tokens / last_draft_tokens

    # Expect the acceptance length to improve.
    assert first_accept_tokens < last_accept_tokens, (
        f"Accepted tokens did not increase: first run {first_accept_tokens}, "
        f"last run {last_accept_tokens}"
    )

    # Expect the acceptance rate to improve.
    assert first_accept_rate < last_accept_rate, (
        f"Acceptance rate did not improve: first run {first_accept_rate:.3f}, "
        f"last run {last_accept_rate:.3f}"
    )

    # Heuristic: expect an acceptance rate above 80% at the end.
    assert last_accept_rate > 0.80, (
        f"Expected final acceptance rate > 0.80, got {last_accept_rate:.3f}"
    )

    del spec_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()


247
@pytest.mark.parametrize(
    ["model_path", "expected_accuracy_threshold"],
    [
        ("RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3", 0.7),  # ref: 75%-80%
        ("RedHatAI/Qwen3-8B-speculator.eagle3", 0.8),  # ref: 87%-92%
    ],
    ids=["llama3_eagle3_speculator", "qwen3_eagle3_speculator"],
)
@single_gpu_only
@large_gpu_mark(min_gb=24)
def test_speculators_model_integration(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_path: str,
    expected_accuracy_threshold: float,
):
    """
    Test that speculators models work with the simplified integration.

    This covers the `vllm serve <speculator-model>` use case. The speculative
    config is detected automatically from the model config, so no explicit
    --speculative-config argument is needed.

    Checks that:
    1. A speculative config is auto-detected for the speculator checkpoint.
    2. It proposes a positive number of speculative tokens.
    3. The draft model is the speculator checkpoint itself.
    4. The verifier model extracted from the speculator config can be loaded
       on its own as a non-speculative reference.
    5. GSM8k accuracy with speculative decoding passes the per-model threshold.
    6. At least 66% of outputs match the reference (non-speculative) run.
    """
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
    prompts = get_test_prompts(mm_enabled=False)

    # Run 1: load the speculator checkpoint directly (simplified integration).
    spec_llm = LLM(model=model_path, max_model_len=4096)
    evaluate_llm_for_gsm8k(
        spec_llm, expected_accuracy_threshold=expected_accuracy_threshold
    )
    spec_outputs = spec_llm.chat(prompts, sampling_config)

    spec_config = spec_llm.llm_engine.vllm_config.speculative_config
    assert spec_config is not None, (
        f"Speculative config should be auto-detected for {model_path}"
    )
    assert spec_config.num_speculative_tokens > 0, (
        f"Expected positive speculative tokens, "
        f"got {spec_config.num_speculative_tokens}"
    )
    assert spec_config.model == model_path, (
        f"Draft model should be {model_path}, got {spec_config.model}"
    )

    # The verifier model is needed for the reference run below.
    verifier_model = spec_llm.llm_engine.vllm_config.model_config.model

    del spec_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    # Run 2: the plain verifier model, without speculative decoding.
    ref_llm = LLM(model=verifier_model, max_model_len=4096)
    ref_outputs = ref_llm.chat(prompts, sampling_config)
    del ref_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    matches = sum(
        ref.outputs[0].text == spec.outputs[0].text
        for ref, spec in zip(ref_outputs, spec_outputs)
    )

    # Heuristic: at least 66% of prompts must produce identical text.
    assert matches >= int(0.66 * len(ref_outputs)), (
        f"Only {matches}/{len(ref_outputs)} outputs matched. "
        f"Expected at least {int(0.66 * len(ref_outputs))} matches."
    )


334
def _run_eagle_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    expected_accuracy_threshold: float,
    enable_chunked_prefill: bool,
    model_impl: str,
    attn_backend: str,
):
    """
    Compare the outputs of an original LLM and a speculative LLM
    which should be the same when using eagle speculative decoding.

    Args:
        monkeypatch: used to set environment variables for the duration of
            the run.
        sampling_config: sampling parameters for both chat runs.
        model_setup: ``(method, target model, draft model, tensor-parallel size)``.
        mm_enabled: whether to include image prompts in the test set.
        expected_accuracy_threshold: minimum GSM8K accuracy for both LLMs;
            values <= 0.0 skip the GSM8K check.
        enable_chunked_prefill: passed to the speculative LLM. When enabled,
            ``max_num_batched_tokens`` is set to 128 instead of the max model len.
        model_impl: model implementation for the speculative LLM
            (e.g. "auto" or "transformers").
        attn_backend: attention backend under test.

    The test is skipped for known-unsupported combinations:
      * TREE_ATTN;
      * TRITON_ATTN off ROCm;
      * ``model_impl="transformers"`` with transformers < 5.0.0;
      * fewer GPUs than the tensor-parallel size.
    """
    if attn_backend == "TREE_ATTN":
        pytest.skip(
            "TREE_ATTN is flaky in the test disable for now until it can be "
            "resolved (see https://github.com/vllm-project/vllm/issues/22922)"
        )

    if model_impl == "transformers":
        import transformers
        from packaging.version import Version

        installed = Version(transformers.__version__)
        required = Version("5.0.0")
        if installed < required:
            pytest.skip(
                "Eagle3 with the Transformers modeling backend requires "
                f"transformers>={required}, but got {installed}"
            )

    test_prompts = get_test_prompts(mm_enabled)
    method, model_name, spec_model_name, tp_size = model_setup

    if "Llama-4-Scout" in model_name and attn_backend == "FLASH_ATTN":
        if current_platform.is_rocm():
            print(
                "FLASH_ATTN for spec_decode not supported on "
                "ROCm currently. Changing to FLEX_ATTENTION backend."
            )
            attention_config = {"backend": "FLEX_ATTENTION"}
        else:
            # Leave the backend choice to the engine's default.
            attention_config = None
    else:
        attention_config = {"backend": attn_backend}

    if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
        pytest.skip(
            "TRITON_ATTN does not support "
            "multi-token eagle spec decode on current platform"
        )

    with monkeypatch.context() as m:
        m.setenv("VLLM_MLA_DISABLE", "1")

        if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
            m.setenv("VLLM_ROCM_USE_AITER", "1")
            if "deepseek" in model_name.lower():
                # DeepSeek models use MLA here, so re-enable it and pick an
                # MLA backend.
                m.delenv("VLLM_MLA_DISABLE", raising=False)
                attention_config = {"backend": "TRITON_MLA"}

        _skip_if_insufficient_gpus_for_tp(tp_size)

        max_model_len = 2048
        max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len

        # Reference run: target model without speculative decoding.
        ref_llm = LLM(
            model=model_name,
            max_model_len=max_model_len,
            tensor_parallel_size=tp_size,
            attention_config=attention_config,
        )
        evaluate_llm_for_gsm8k(
            ref_llm, expected_accuracy_threshold=expected_accuracy_threshold
        )
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

        # Speculative run: same target model with the EAGLE draft model.
        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": max_model_len,
            },
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            enable_chunked_prefill=enable_chunked_prefill,
            model_impl=model_impl,
            attention_config=attention_config,
        )
        evaluate_llm_for_gsm8k(
            spec_llm, expected_accuracy_threshold=expected_accuracy_threshold
        )
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)

        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: more than 60% of prompts must match exactly. Engine
        # variance can make a few outputs differ; inspect the printed pairs
        # above on failure.
        num_required = int(0.6 * len(ref_outputs))
        assert matches > num_required, (
            f"Only {matches}/{len(ref_outputs)} outputs matched "
            f"({misses} misses); expected more than {num_required}."
        )
        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()
450
451


452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
@single_gpu_only
@pytest.mark.parametrize(
    [
        "model_setup",
        "mm_enabled",
        "enable_chunked_prefill",
        "model_impl",
        "expected_accuracy_threshold",
    ],
    [
        (
            (
                "eagle",
                "eagle618/deepseek-v3-random",
                "eagle618/eagle-deepseek-v3-random",
                1,
            ),
            False,
            False,
            "auto",
            0.0,
        ),
    ],
    ids=["deepseek_eagle"],
)
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_correctness_light(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    expected_accuracy_threshold: float,
    enable_chunked_prefill: bool,
    model_impl: str,
    attn_backend: str,
):
    """EAGLE correctness on a DeepSeek-V3 checkpoint (random weights, judging by its name).

    The 0.0 accuracy threshold disables the GSM8K check for this case.
    """
    _run_eagle_correctness(
        monkeypatch=monkeypatch,
        sampling_config=sampling_config,
        model_setup=model_setup,
        mm_enabled=mm_enabled,
        expected_accuracy_threshold=expected_accuracy_threshold,
        enable_chunked_prefill=enable_chunked_prefill,
        model_impl=model_impl,
        attn_backend=attn_backend,
    )


@single_gpu_only
@large_gpu_mark(min_gb=24)
@pytest.mark.parametrize(
    [
        "model_setup",
        "mm_enabled",
        "enable_chunked_prefill",
        "model_impl",
        "expected_accuracy_threshold",
    ],
    [
        (
            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
            False,
            False,
            "auto",
            0.8,
        ),
        (
            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
            False,
            False,
            "transformers",
            0.8,
        ),
        pytest.param(
            (
                "eagle3",
                "Qwen/Qwen3-VL-8B-Instruct",
                "taobao-mnn/Qwen3-VL-8B-Instruct-Eagle3",
                1,
            ),
            False,
            False,
            "auto",
            0.8,
            marks=pytest.mark.skip(
                reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
            ),
        ),
        pytest.param(
            (
                "eagle3",
                "Qwen/Qwen2.5-VL-7B-Instruct",
                "Rayzl/qwen2.5-vl-7b-eagle3-sgl",
                1,
            ),
            False,
            False,
            "auto",
            0.7,
            marks=pytest.mark.skip(
                reason="Skipping due to its head_dim not being a multiple of 32"
            ),
        ),
        (
            (
                "eagle3",
                "meta-llama/Llama-3.1-8B-Instruct",
                "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
                1,
            ),
            False,
            False,
            "auto",
            0.7,
        ),
    ],
    ids=[
        "qwen3_eagle3",
        "qwen3_eagle3-transformers",
        "qwen3_vl_eagle3",
        "qwen2_5_vl_eagle3",
        "llama3_eagle3",
    ],
)
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_correctness_medium(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    expected_accuracy_threshold: float,
    enable_chunked_prefill: bool,
    model_impl: str,
    attn_backend: str,
):
    """EAGLE3 correctness for 7B-8B target models on a single GPU (>= 24 GB)."""
    _run_eagle_correctness(
        monkeypatch=monkeypatch,
        sampling_config=sampling_config,
        model_setup=model_setup,
        mm_enabled=mm_enabled,
        expected_accuracy_threshold=expected_accuracy_threshold,
        enable_chunked_prefill=enable_chunked_prefill,
        model_impl=model_impl,
        attn_backend=attn_backend,
    )


@pytest.mark.parametrize(
    [
        "model_setup",
        "mm_enabled",
        "enable_chunked_prefill",
        "model_impl",
        "expected_accuracy_threshold",
    ],
    [
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-3.1-8B-Instruct",
                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
                1,
            ),
            False,
            True,
            "auto",
            0.7,
            marks=large_gpu_mark(min_gb=40),
            id="llama3_eagle",
        ),
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                4,
            ),
            False,
            False,
            "auto",
            0.8,
            marks=multi_gpu_marks(num_gpus=4),
            id="llama4_eagle",
        ),
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                4,
            ),
            True,
            True,
            "auto",
            0.8,
            marks=[*multi_gpu_marks(num_gpus=4), large_gpu_mark(min_gb=80)],
            id="llama4_eagle_mm",
        ),
    ],
)
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_correctness_heavy(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    expected_accuracy_threshold: float,
    enable_chunked_prefill: bool,
    model_impl: str,
    attn_backend: str,
):
    """EAGLE correctness for the heavier setups.

    Covers Llama-3.1-8B with chunked prefill, and Llama-4-Scout at TP=4, both
    text-only and with image prompts.
    """
    _run_eagle_correctness(
        monkeypatch=monkeypatch,
        sampling_config=sampling_config,
        model_setup=model_setup,
        mm_enabled=mm_enabled,
        expected_accuracy_threshold=expected_accuracy_threshold,
        enable_chunked_prefill=enable_chunked_prefill,
        model_impl=model_impl,
        attn_backend=attn_backend,
    )


675
@pytest.mark.parametrize(
    ["model_setup", "mm_enabled", "expected_accuracy_threshold"],
    [
        (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False, 0.5),  # ref: 65%-70%
        (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False, 0.0),  # dummy model
    ],
    ids=["mimo", "deepseek"],
)
@single_gpu_only
@large_gpu_mark(min_gb=20)
def test_mtp_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, int],
    mm_enabled: bool,
    expected_accuracy_threshold: float,
):
    """
    Compare the outputs of an original LLM and a speculative LLM, which should
    be the same when using MTP speculative decoding.

    Engine variance means some outputs may differ. The test therefore only
    requires that more than ``MTP_SIMILARITY_RATE`` of the prompts produce
    exactly the same text.

    Both LLMs must also reach the per-model GSM8K accuracy threshold. A
    threshold of 0.0, used for the dummy model, skips that check.

    ``model_setup`` is ``(method, model name, tensor-parallel size)``.
    """
    test_prompts = get_test_prompts(mm_enabled)

    with monkeypatch.context() as m:
        m.setenv("VLLM_MLA_DISABLE", "1")

        method, model_name, tp_size = model_setup
        _skip_if_insufficient_gpus_for_tp(tp_size)

        # Reference run: no speculative decoding.
        ref_llm = LLM(
            model=model_name,
            max_model_len=2048,
            tensor_parallel_size=tp_size,
            trust_remote_code=True,
        )
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        evaluate_llm_for_gsm8k(
            ref_llm, expected_accuracy_threshold=expected_accuracy_threshold
        )
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

        # Speculative run: same model with its built-in MTP head(s).
        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "num_speculative_tokens": 1,
                "max_model_len": 2048,
            },
            max_model_len=2048,
        )
        evaluate_llm_for_gsm8k(
            spec_llm, expected_accuracy_threshold=expected_accuracy_threshold
        )
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)

        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: more than MTP_SIMILARITY_RATE of the prompts must match
        # exactly. Upon failure, inspect the printed outputs for inaccuracy.
        num_required = int(MTP_SIMILARITY_RATE * len(ref_outputs))
        assert matches > num_required, (
            f"Only {matches}/{len(ref_outputs)} outputs matched "
            f"({misses} misses); expected more than {num_required}."
        )
        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()
752
753
754
755
756
757
758
759
760
761


@dataclass
class ArgsTest:
    """One draft-model speculative-decoding test case.

    Instances are consumed by ``assert_draft_model_correctness``. Fields
    without defaults must be supplied; the remaining fields default to a
    single-GPU, eager-mode run over the built-in test prompts.
    """

    # Target (verifier) model and the draft model that proposes tokens.
    target_model: str
    draft_model: str
    # Sampling parameters for the correctness run.
    sampling_config: SamplingParams
    # Tokens proposed by the draft model per step (often called K).
    num_speculative_tokens: int
    # Expected acceptance metrics (NOTE(review): presumably lower bounds
    # checked by assert_draft_model_correctness -- confirm).
    expected_acceptance_rate: float
    expected_acceptance_len: float
    # Minimum GSM8K accuracy; the default of 0.0 skips the GSM8K check.
    expected_gsm8k_accuracy: float = 0.0  # skip by default

    # ---- Optional knobs ----
    enforce_eager: bool = True
    parallel_drafting: bool = False
    # Tensor-parallel sizes are configured independently for each model.
    target_tensor_parallel_size: int = 1
    draft_tensor_parallel_size: int = 1
    max_model_len: int = 2048
    gpu_memory_utilization: float = 0.5
    # Prompt source handed to get_messages: "test_prompts" or a dataset name
    # such as "likaixin/InstructCoder".
    dataset: str = "test_prompts"
    num_prompts: int = 100


cases = [
    # Identical draft and target model with greedy sampling: nearly every
    # proposed token should be accepted.
    ArgsTest(
        target_model="Qwen/Qwen3-0.6B",
        draft_model="Qwen/Qwen3-0.6B",
        num_speculative_tokens=3,  # K
        sampling_config=greedy_sampling(),
        expected_acceptance_rate=0.98,  # slight epsilon
        expected_acceptance_len=0.98 * (3 + 1),  # epsilon discount of K + 1
        expected_gsm8k_accuracy=0.25,  # ref: 35-40%
    ),
    # Draft model smaller than the target, with stochastic sampling.
    ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        draft_model="Qwen/Qwen3-0.6B",
        num_speculative_tokens=3,
        sampling_config=stochastic_sampling(),
        expected_acceptance_rate=0.80,  # ref: 0.90
        expected_acceptance_len=3.4,  # ref: 3.7
        # ref: 60%. Note that the GSM8K evaluation always uses greedy sampling.
        expected_gsm8k_accuracy=0.5,
    ),
]


@pytest.mark.parametrize("args", cases)
@pytest.mark.parametrize("enforce_eager", [True, False])
@single_gpu_only
def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
    """Run every entry of ``cases`` with and without eager mode.

    ``args`` is the same ``ArgsTest`` object for both ``enforce_eager``
    parametrizations, because it comes from the module-level ``cases`` list.
    A modified copy is therefore built instead of mutating the shared object
    in place.
    """
    from dataclasses import replace

    assert_draft_model_correctness(replace(args, enforce_eager=enforce_eager))
804
805


806
@single_gpu_only
def test_draft_model_realistic_example():
    """Qwen3-1.7B target with a Qwen3-0.6B draft, using the InstructCoder dataset."""
    realistic_case = ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        draft_model="Qwen/Qwen3-0.6B",
        dataset="likaixin/InstructCoder",
        sampling_config=greedy_sampling(),
        num_speculative_tokens=3,
        enforce_eager=False,
        expected_acceptance_rate=0.5,  # ref: 0.62
        expected_acceptance_len=2.6,  # ref: 2.86
    )
    assert_draft_model_correctness(realistic_case)


821
@single_gpu_only
def test_draft_model_parallel_drafting():
    """PARD draft model (amd/PARD-Qwen3-0.6B) with ``parallel_drafting`` enabled."""
    assert_draft_model_correctness(
        ArgsTest(
            target_model="Qwen/Qwen3-1.7B",
            draft_model="amd/PARD-Qwen3-0.6B",
            parallel_drafting=True,
            dataset="likaixin/InstructCoder",
            sampling_config=greedy_sampling(),
            num_speculative_tokens=3,
            enforce_eager=False,
            expected_acceptance_rate=0.4,  # ref: 0.51
            expected_acceptance_len=2.3,  # ref: 2.52
        )
    )
835
836
837
838
839
840
841
842
843
844
845
846


@pytest.mark.parametrize(
    "models",
    [
        # target_model,         draft_model
        ("Qwen/Qwen3-1.7B-FP8", "Qwen/Qwen3-0.6B"),  # target quantized
        ("Qwen/Qwen3-1.7B", "Qwen/Qwen3-0.6B-FP8"),  # draft quantized
    ],
    ids=["target_quantized", "draft_quantized"],
)
@pytest.mark.parametrize("enforce_eager", [True, False])
@single_gpu_only
def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
    """Draft-model spec decode where exactly one of target/draft is FP8-quantized."""
    target, draft = models
    assert_draft_model_correctness(
        ArgsTest(
            target_model=target,
            draft_model=draft,
            enforce_eager=enforce_eager,
            **some_high_acceptance_metrics(),
        )
    )
857
858


859
@multi_gpu_only(num_gpus=2)
def test_draft_model_tensor_parallelism():
    """Ensure spec decode works when running with TP > 1."""
    tp_size = 2
    _skip_if_insufficient_gpus_for_tp(tp_size)
    # Both the target and the draft model are sharded across the same 2 GPUs.
    tp_case = ArgsTest(
        target_model="Qwen/Qwen3-1.7B",
        draft_model="Qwen/Qwen3-0.6B",
        target_tensor_parallel_size=tp_size,
        draft_tensor_parallel_size=tp_size,
        **some_high_acceptance_metrics(),
        enforce_eager=False,
        expected_gsm8k_accuracy=0.5,
    )
    assert_draft_model_correctness(tp_case)
873
874


875
@multi_gpu_only(num_gpus=2)
def test_draft_model_engine_args_tensor_parallelism():
    """Check that the draft model's vllm_config is derived independently of the
    target's: the target is FP8 with TP=2, the draft is unquantized with TP=1."""
    _skip_if_insufficient_gpus_for_tp(2)

    draft_spec = {
        "model": "Qwen/Qwen3-0.6B",  # draft is NOT quantized
        "method": "draft_model",
        "num_speculative_tokens": 3,
        "draft_tensor_parallel_size": 1,  # the valid argument name
    }
    engine_args = EngineArgs(
        model="Qwen/Qwen3-1.7B-FP8",  # target IS quantized
        tensor_parallel_size=2,
        speculative_config=draft_spec,
    )

    target_config: VllmConfig = engine_args.create_engine_config()
    assert target_config.parallel_config.tensor_parallel_size == 2
    assert target_config.quant_config.get_name() == "fp8"

    draft_config: VllmConfig = create_vllm_config_for_draft_model(target_config)
    assert draft_config.parallel_config.tensor_parallel_size == 1
    assert draft_config.quant_config is None


def test_draft_model_engine_args_rejects_invalid_tp_argname():
    """Building the engine config must raise ``ValueError`` when the draft
    model's TP size is given as ``tensor_parallel_size`` inside
    ``speculative_config``; the accepted key is ``draft_tensor_parallel_size``.
    """
    bad_spec_config = {
        "model": "Qwen/Qwen3-0.6B",
        "method": "draft_model",
        "num_speculative_tokens": 3,
        # Deliberately wrong key: should be "draft_tensor_parallel_size".
        "tensor_parallel_size": 1,
    }
    args = EngineArgs(
        model="Qwen/Qwen3-1.7B",
        tensor_parallel_size=1,
        speculative_config=bad_spec_config,
    )

    with pytest.raises(ValueError):
        args.create_engine_config()


def assert_draft_model_correctness(args: ArgsTest):
    """Run draft-model speculative decoding and check its quality.

    Builds an ``LLM`` for ``args.target_model`` that uses
    ``args.draft_model`` as its draft model and sends the chat prompts from
    ``args.dataset`` through it. Generated texts are NOT compared against a
    non-speculative baseline. Instead this asserts that:

    * the spec-decode acceptance rate is at least
      ``args.expected_acceptance_rate``;
    * the mean acceptance length is at least
      ``args.expected_acceptance_len``;
    * GSM8K accuracy meets ``args.expected_gsm8k_accuracy``, which is
      passed as the threshold to ``evaluate_llm_for_gsm8k``.

    The engine and GPU memory are released before the metric assertions
    run, and also when generation or the GSM8K evaluation raises.
    """
    test_prompts: list[Messages] = get_messages(
        dataset=args.dataset, n=args.num_prompts
    )

    spec_llm = LLM(
        model=args.target_model,
        speculative_config={
            "model": args.draft_model,
            "method": "draft_model",
            "num_speculative_tokens": args.num_speculative_tokens,
            "max_model_len": args.max_model_len,
            "enforce_eager": args.enforce_eager,
            "draft_tensor_parallel_size": args.draft_tensor_parallel_size,
            "parallel_drafting": args.parallel_drafting,
        },
        max_num_seqs=100,  # limit cudagraph capture runtime
        max_model_len=args.max_model_len,
        gpu_memory_utilization=args.gpu_memory_utilization,
        tensor_parallel_size=args.target_tensor_parallel_size,
        enforce_eager=args.enforce_eager,
        disable_log_stats=False,  # enables get_metrics()
    )
    try:
        # we don't check the outputs, only check the metrics
        spec_llm.chat(test_prompts, args.sampling_config)
        metrics = spec_llm.get_metrics()
        acceptance_rate: float = compute_acceptance_rate(metrics)
        acceptance_len: float = compute_acceptance_len(metrics)

        # Need to evaluate after getting metrics to avoid polluting the AR
        evaluate_llm_for_gsm8k(
            spec_llm, expected_accuracy_threshold=args.expected_gsm8k_accuracy
        )
    finally:
        # Release the engine even if generation or the GSM8K check fails, so
        # later tests in the same session do not inherit its GPU memory.
        del spec_llm  # CLEANUP
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    print(
        f"spec-decode: target={args.target_model}, draft={args.draft_model}, "
        f"temperature={args.sampling_config.temperature:.2f}, "
        f"acceptance_rate={acceptance_rate:.2f}, "
        f"acceptance_len={acceptance_len:.2f}, "
    )

    assert acceptance_rate >= args.expected_acceptance_rate, (
        f"acceptance_rate {acceptance_rate:.3f} < "
        f"expected {args.expected_acceptance_rate}"
    )
    assert acceptance_len >= args.expected_acceptance_len, (
        f"acceptance_len {acceptance_len:.3f} < "
        f"expected {args.expected_acceptance_len}"
    )


def get_messages(dataset: str, n: int) -> list[Messages]:
    """Return chat conversations from ``dataset``, requesting ``n`` of them.

    Supported datasets:

    * ``"test_prompts"``: the synthetic prompts from ``get_test_prompts``,
      without multimodal inputs.
    * ``"likaixin/InstructCoder"``: loaded through
      ``get_instruct_coder_messages``.

    Raises:
        NotImplementedError: if ``dataset`` is not one of the above.
    """
    if dataset == "test_prompts":
        return get_test_prompts(mm_enabled=False, num_prompts=n)
    elif dataset == "likaixin/InstructCoder":
        return get_instruct_coder_messages(n=n)
    else:
        raise NotImplementedError(f"Dataset '{dataset}' not implemented")


def some_high_acceptance_metrics() -> dict[str, Any]:
    """Return ``ArgsTest`` keyword overrides for draft/target pairs that are
    expected to reach high acceptance.

    Uses greedy sampling with 3 speculative tokens and sets the minimum
    acceptance length and rate a run must reach. The ``ref:`` comments
    appear to record previously observed values, with the thresholds set
    somewhat below them.
    """
    return {
        "sampling_config": greedy_sampling(),
        "num_speculative_tokens": 3,
        "expected_acceptance_len": 3.4,  # ref: 3.75
        "expected_acceptance_rate": 0.8,  # ref: 0.9
    }


def compute_acceptance_rate(metrics: list[Metric]) -> float:
    """Return the fraction of proposed draft tokens that were accepted.

    Returns NaN when no draft tokens were proposed. The accepted-token
    metric is only looked up when at least one draft token was proposed.
    """
    lookup = {}
    for metric in metrics:
        lookup[metric.name] = metric  # later duplicates overwrite earlier ones

    proposed = lookup["vllm:spec_decode_num_draft_tokens"].value  # type: ignore
    if proposed == 0:
        return float("nan")

    accepted = lookup["vllm:spec_decode_num_accepted_tokens"].value  # type: ignore
    return accepted / proposed


def compute_acceptance_len(metrics: list[Metric]) -> float:
    """Return the mean acceptance length: 1 + accepted draft tokens per draft.

    Returns 1.0 when no drafts were recorded. In that case the accepted-token
    metric is not looked up, mirroring ``compute_acceptance_rate``.
    """
    name2metric = {metric.name: metric for metric in metrics}
    n_drafts = name2metric["vllm:spec_decode_num_drafts"].value  # type: ignore
    if n_drafts == 0:
        # Return a float to honor the annotated return type.
        return 1.0
    n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value  # type: ignore
    return 1 + (n_accepted_toks / n_drafts)