test_spec_decode.py 18.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import random
4
from typing import Any
5

6
import pytest
zhiweiz's avatar
zhiweiz committed
7
import torch
8

9
from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
10
from vllm import LLM, SamplingParams
11
12
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
zhiweiz's avatar
zhiweiz committed
13
from vllm.distributed import cleanup_dist_env_and_memory
14
from vllm.platforms import current_platform
15

16
17
# Exact-match threshold, as a fraction of the prompt batch, used by
# test_mtp_correctness when comparing MTP speculative outputs against the
# non-speculative reference outputs.
MTP_SIMILARITY_RATE: float = 0.8

18

19
20
21
22
23
24
25
26
27
28
def _skip_if_insufficient_gpus_for_tp(tp_size: int):
    """Skip the calling test on ROCm when fewer than ``tp_size`` GPUs are visible.

    On every non-ROCm platform this helper does nothing.
    """
    if not current_platform.is_rocm():
        return
    num_visible = torch.cuda.device_count()
    if num_visible >= tp_size:
        return
    pytest.skip(f"Test requires {tp_size} GPUs, but only {num_visible} available")


29
def get_test_prompts(mm_enabled: bool):
    """Build a deterministic batch of 100 single-turn chat conversations.

    Each prompt is randomly one of:

    * ``"repeat"``: ask the model to repeat a word ten times (easy for
      n-gram / suffix matching to predict),
    * ``"sentence"``: ask for a short free-form sentence (harder to predict),
    * ``"mm"`` (only when ``mm_enabled``): an image URL followed by text.

    A private ``random.Random(0)`` generator is used. It produces exactly the
    sequence that ``random.seed(0)`` + module-level calls would. It does so
    without reseeding, and thereby clobbering, the global ``random`` state
    that other tests may rely on.

    Args:
        mm_enabled: Whether to include multimodal (image) prompts in the mix.

    Returns:
        A list of conversations, each ``[{"role": "user", "content": ...}]``.
        ``content`` is a string, or a list of content-part dicts for
        multimodal prompts.

    Raises:
        ValueError: If an unknown prompt type is drawn.
    """
    prompt_types = ["repeat", "sentence"]
    if mm_enabled:
        prompt_types.append("mm")

    num_prompts = 100
    prompts = []

    # Dedicated generator: same sequence as seeding the global RNG with 0,
    # but leaves the module-level ``random`` state untouched.
    rng = random.Random(0)
    random_prompt_type_choices = rng.choices(prompt_types, k=num_prompts)
    print(f"Prompt types: {random_prompt_type_choices}")

    # Loop-invariant; a word is drawn for every prompt (including "mm") so
    # the RNG sequence matches the historical behavior.
    word_choices = ["test", "temp", "hello", "where"]

    # Generate a mixed batch of prompts, some of which can be easily
    # predicted by n-gram matching and some which likely cannot.
    for kind in random_prompt_type_choices:
        word = rng.choice(word_choices)
        prompt: str | list[dict[str, Any]] = ""
        if kind == "repeat":
            prompt = f"""
            please repeat the word '{word}' 10 times.
            give no other output than the word at least ten times in a row,
            in lowercase with spaces between each word and without quotes.
            """
        elif kind == "sentence":
            prompt = f"""
            please give a ten-word sentence that
            uses the word {word} at least once.
            give no other output than that simple sentence without quotes.
            """
        elif kind == "mm":
            placeholders = [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
                    },
                }
            ]
            prompt = [
                *placeholders,
                {"type": "text", "text": "The meaning of the image is"},
            ]
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append([{"role": "user", "content": prompt}])

    return prompts
76
77
78
79


@pytest.fixture
def sampling_config():
    """Shared sampling parameters for the correctness tests.

    temperature=0 selects greedy decoding so the reference and speculative
    runs can be compared text-for-text. Generation is capped at 10 tokens.
    """
    params = SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
    return params
81
82
83
84


@pytest.fixture
def model_name():
    """Target model used by the n-gram / suffix decoding tests."""
    name = "meta-llama/Llama-3.1-8B-Instruct"
    return name
86
87


88
89
90
91
92
93
94
95
@pytest.fixture(autouse=True)
def reset_torch_dynamo():
    """Reset the torch dynamo cache after each test.

    Autouse fixture: the test body runs at the ``yield``. The reset happens
    during teardown, so dynamo state produced by one test is cleared before
    the next one starts.
    """
    yield
    # Teardown: pytest resumes the generator after the test, pass or fail.
    torch._dynamo.reset()


96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
@pytest.mark.parametrize(
    "speculative_config",
    [
        {
            "method": "ngram",
            "prompt_lookup_max": 5,
            "prompt_lookup_min": 3,
            "num_speculative_tokens": 3,
        },
        {
            "method": "suffix",
            "suffix_decoding_max_spec_factor": 2.0,
        },
    ],
)
def test_ngram_and_suffix_correctness(
    speculative_config: dict,
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    """
    Compare the outputs of an original LLM and a speculative LLM; they
    should be the same when using ngram or suffix speculative decoding.

    Each engine is torn down in a ``finally`` block, so a failure during
    generation does not leave that engine alive for subsequent tests.
    """
    test_prompts = get_test_prompts(mm_enabled=False)

    ref_llm = LLM(model=model_name, max_model_len=1024)
    try:
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
    finally:
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    spec_llm = LLM(
        model=model_name,
        speculative_config=speculative_config,
        max_model_len=1024,
    )
    try:
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
    finally:
        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    matches = 0
    misses = 0
    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
        if ref_output.outputs[0].text == spec_output.outputs[0].text:
            matches += 1
        else:
            misses += 1
            print(f"ref_output: {ref_output.outputs[0].text}")
            print(f"spec_output: {spec_output.outputs[0].text}")

    # Heuristic: expect at least 66% of the prompts to match exactly
    # Upon failure, inspect the outputs to check for inaccuracy.
    min_matches = int(0.66 * len(ref_outputs))
    assert matches >= min_matches, (
        f"Only {matches}/{len(ref_outputs)} outputs matched ({misses} misses); "
        f"expected at least {min_matches}."
    )


153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
def test_suffix_decoding_acceptance(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    """
    Check that suffix decoding caching takes effect and improves acceptance
    lengths and acceptance rates over multiple runs of the same prompts.

    The draft/accepted token metrics are cumulative across runs, so the last
    run's stats are the difference between the final two snapshots. The
    engine is torn down in a ``finally`` block so a failure while generating
    does not leave it alive for subsequent tests.
    """
    test_prompts = get_test_prompts(mm_enabled=False)
    num_runs = 10  # Run multiple times to warm up the cache.

    spec_llm = LLM(
        model=model_name,
        speculative_config={
            "method": "suffix",
            "suffix_decoding_max_spec_factor": 2.0,
            "suffix_decoding_max_cached_requests": 1000,
        },
        max_model_len=1024,
        disable_log_stats=False,
    )
    try:
        # Run several times and record the cumulative draft/accept counters.
        num_draft = []
        num_accept = []
        for _ in range(num_runs):
            spec_llm.chat(test_prompts, sampling_config)
            # Collect draft and acceptance stats.
            for metric in spec_llm.get_metrics():
                if metric.name == "vllm:spec_decode_num_draft_tokens":
                    num_draft.append(metric.value)
                if metric.name == "vllm:spec_decode_num_accepted_tokens":
                    num_accept.append(metric.value)
    finally:
        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    # Fail with a clear message instead of an IndexError if metrics are absent.
    assert len(num_draft) >= 2 and len(num_accept) >= 2, (
        "Expected at least two snapshots of the spec decode metrics, got "
        f"{len(num_draft)} draft and {len(num_accept)} accepted snapshots"
    )

    # Calculate the acceptance rates for the first and last runs.
    first_accept_tokens = num_accept[0]
    first_draft_tokens = num_draft[0]
    assert first_draft_tokens > 0, "No draft tokens were proposed in the first run"
    first_accept_rate = first_accept_tokens / first_draft_tokens

    # Take the diff since the stats are cumulative.
    last_accept_tokens = num_accept[-1] - num_accept[-2]
    last_draft_tokens = num_draft[-1] - num_draft[-2]
    assert last_draft_tokens > 0, "No draft tokens were proposed in the last run"
    last_accept_rate = last_accept_tokens / last_draft_tokens

    # Expect the acceptance length to improve.
    assert first_accept_tokens < last_accept_tokens

    # Expect the acceptance rate to improve.
    assert first_accept_rate < last_accept_rate

    # Heuristic: expect at least 80.0% acceptance rate at the end.
    assert last_accept_rate > 0.80


212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
@pytest.mark.parametrize(
    "model_path",
    [
        "RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
        "RedHatAI/Qwen3-8B-speculator.eagle3",
    ],
    ids=["llama3_eagle3_speculator", "qwen3_eagle3_speculator"],
)
def test_speculators_model_integration(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_path: str,
):
    """
    Test that speculators models work with the simplified integration.

    This verifies the `vllm serve <speculator-model>` use case where
    speculative config is automatically detected from the model config
    without requiring explicit --speculative-config argument.

    Tests:
    1. Speculator model is correctly detected
    2. Verifier model is extracted from speculator config
    3. Speculative decoding is automatically enabled
    4. Text generation works correctly
    5. Output matches reference (non-speculative) generation

    Each engine is torn down in a ``finally`` block, so a failed config check
    does not leave the speculative engine alive for subsequent tests.
    """
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    # Generate test prompts
    test_prompts = get_test_prompts(mm_enabled=False)

    # First run: Direct speculator model (simplified integration)
    spec_llm = LLM(model=model_path, max_model_len=1024)
    try:
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)

        # Verify speculative config was auto-detected
        spec_config = spec_llm.llm_engine.vllm_config.speculative_config
        assert spec_config is not None, (
            f"Speculative config should be auto-detected for {model_path}"
        )
        assert spec_config.num_speculative_tokens > 0, (
            f"Expected positive speculative tokens, "
            f"got {spec_config.num_speculative_tokens}"
        )

        # Verify draft model is set to the speculator model
        assert spec_config.model == model_path, (
            f"Draft model should be {model_path}, got {spec_config.model}"
        )

        # Extract verifier model for reference run
        verifier_model = spec_llm.llm_engine.vllm_config.model_config.model
    finally:
        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    # Second run: Reference without speculative decoding
    ref_llm = LLM(model=verifier_model, max_model_len=1024)
    try:
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
    finally:
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    # Compare outputs
    matches = sum(
        1
        for ref, spec in zip(ref_outputs, spec_outputs)
        if ref.outputs[0].text == spec.outputs[0].text
    )

    # Heuristic: expect at least 66% of prompts to match exactly
    min_matches = int(0.66 * len(ref_outputs))
    assert matches >= min_matches, (
        f"Only {matches}/{len(ref_outputs)} outputs matched. "
        f"Expected at least {min_matches} matches."
    )


292
@pytest.mark.parametrize(
    ["model_setup", "mm_enabled", "enable_chunked_prefill", "model_impl"],
    [
        (
            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
            False,
            False,
            "auto",
        ),
        (
            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
            False,
            False,
            "transformers",
        ),
        pytest.param(
            (
                "eagle3",
                "Qwen/Qwen3-VL-8B-Instruct",
                "taobao-mnn/Qwen3-VL-8B-Instruct-Eagle3",
                1,
            ),
            False,
            False,
            "auto",
            marks=pytest.mark.skip(
                reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
            ),
        ),
        pytest.param(
            (
                "eagle3",
                "Qwen/Qwen2.5-VL-7B-Instruct",
                "Rayzl/qwen2.5-vl-7b-eagle3-sgl",
                1,
            ),
            False,
            False,
            "auto",
            marks=pytest.mark.skip(
                reason="Skipping due to its head_dim not being a a multiple of 32"
            ),
        ),
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-3.1-8B-Instruct",
                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
                1,
            ),
            False,
            True,
            "auto",
            marks=large_gpu_mark(min_gb=40),
        ),  # works on 4x H100
        (
            (
                "eagle3",
                "meta-llama/Llama-3.1-8B-Instruct",
                "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
                1,
            ),
            False,
            False,
            "auto",
        ),
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                4,
            ),
            False,
            False,
            "auto",
            marks=large_gpu_mark(min_gb=80),
        ),  # works on 4x H100
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                4,
            ),
            True,
            True,
            "auto",
            marks=large_gpu_mark(min_gb=80),
        ),  # works on 4x H100
        (
            (
                "eagle",
                "eagle618/deepseek-v3-random",
                "eagle618/eagle-deepseek-v3-random",
                1,
            ),
            False,
            False,
            "auto",
        ),
    ],
    ids=[
        "qwen3_eagle3",
        "qwen3_eagle3-transformers",
        "qwen3_vl_eagle3",
        "qwen2_5_vl_eagle3",
        "llama3_eagle",
        "llama3_eagle3",
        "llama4_eagle",
        "llama4_eagle_mm",
        "deepseek_eagle",
    ],
)
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    enable_chunked_prefill: bool,
    model_impl: str,
    attn_backend: str,
):
    """
    Compare the outputs of an original LLM and a speculative LLM; they
    should be the same when using eagle speculative decoding.
    model_setup: (method, model_name, eagle_model_name, tp_size)

    Each engine is torn down in a ``finally`` block, so a failure during
    generation does not leave that engine alive for subsequent tests.
    """
    if attn_backend == "TREE_ATTN":
        # TODO: Fix this flaky test
        pytest.skip(
            "TREE_ATTN is flaky in the test disable for now until it can be "
            "resolved (see https://github.com/vllm-project/vllm/issues/22922)"
        )

    if model_impl == "transformers":
        import transformers
        from packaging.version import Version

        installed = Version(transformers.__version__)
        required = Version("5.0.0.dev")
        if installed < required:
            pytest.skip(
                "Eagle3 with the Transformers modeling backend requires "
                f"transformers>={required}, but got {installed}"
            )

    # Generate test prompts inside the function instead of using fixture
    test_prompts = get_test_prompts(mm_enabled)

    with monkeypatch.context() as m:
        if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
            # Scout requires default backend selection
            # because vision encoder has head_dim 88 being incompatible
            #  with FLASH_ATTN and needs to fall back to Flex Attn

            # pass if not ROCm
            if current_platform.is_rocm():
                # TODO: Enable Flex Attn for spec_decode on ROCm
                pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
        else:
            m.setenv("VLLM_MLA_DISABLE", "1")
            m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

        if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
            pytest.skip(
                "TRITON_ATTN does not support "
                "multi-token eagle spec decode on current platform"
            )

        if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
            if "deepseek" in model_setup[1].lower():
                pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm platform")
            else:
                m.setenv("VLLM_ROCM_USE_AITER", "1")

        method, model_name, spec_model_name, tp_size = model_setup
        _skip_if_insufficient_gpus_for_tp(tp_size)

        max_model_len = 2048
        max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len

        ref_llm = LLM(
            model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
        )
        try:
            ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        finally:
            del ref_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": max_model_len,
            },
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            enable_chunked_prefill=enable_chunked_prefill,
            model_impl=model_impl,
        )
        try:
            spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        finally:
            del spec_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect strictly more than 60% of the prompts to match
        # exactly. Upon failure, inspect the outputs to check for inaccuracy.
        min_matches = int(0.6 * len(ref_outputs))
        assert matches > min_matches, (
            f"Only {matches}/{len(ref_outputs)} outputs matched ({misses} misses); "
            f"expected more than {min_matches}."
        )
513
514


515
516
517
518
519
520
521
522
@pytest.mark.parametrize(
    ["model_setup", "mm_enabled"],
    [
        (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
        (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
    ],
    ids=["mimo", "deepseek"],
)
def test_mtp_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, int],
    mm_enabled: bool,
):
    """
    Compare the outputs of an original LLM and a speculative LLM; they
    should be the same when using MTP speculative decoding.
    model_setup: (method, model_name, tp_size)

    Each engine is torn down in a ``finally`` block, so a failure during
    generation does not leave that engine alive for subsequent tests.
    """
    # Generate test prompts inside the function instead of using fixture
    test_prompts = get_test_prompts(mm_enabled)

    with monkeypatch.context() as m:
        m.setenv("VLLM_MLA_DISABLE", "1")

        method, model_name, tp_size = model_setup
        _skip_if_insufficient_gpus_for_tp(tp_size)

        ref_llm = LLM(
            model=model_name,
            max_model_len=2048,
            tensor_parallel_size=tp_size,
            trust_remote_code=True,
        )
        try:
            ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        finally:
            del ref_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "num_speculative_tokens": 1,
                "max_model_len": 2048,
            },
            max_model_len=2048,
        )
        try:
            spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        finally:
            del spec_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect strictly more than MTP_SIMILARITY_RATE of the
        # prompts to match exactly. Upon failure, inspect the outputs to check
        # for inaccuracy.
        min_matches = int(MTP_SIMILARITY_RATE * len(ref_outputs))
        assert matches > min_matches, (
            f"Only {matches}/{len(ref_outputs)} outputs matched ({misses} misses); "
            f"expected more than {min_matches}."
        )