test_spec_decode.py 11.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import random
4
from typing import Any
5

6
import pytest
zhiweiz's avatar
zhiweiz committed
7
import torch
8

9
from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
10
from vllm import LLM, SamplingParams
11
12
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
zhiweiz's avatar
zhiweiz committed
13
from vllm.distributed import cleanup_dist_env_and_memory
14
from vllm.platforms import current_platform
15

16
17
# Exact-match threshold used by test_mtp_correctness: the number of prompts
# whose speculative output equals the reference output must exceed this
# fraction of the total prompt count.
MTP_SIMILARITY_RATE = 0.8

18

19
def get_test_prompts(
    mm_enabled: bool, num_prompts: int = 100
) -> list[list[dict[str, Any]]]:
    """Build a reproducible, mixed batch of single-turn chat conversations.

    Each conversation is drawn at random from the available prompt types:
    "repeat" (ask for one word repeated), "sentence" (ask for a short
    sentence containing a word) and, optionally, "mm" (an image URL part
    followed by a text part).

    Args:
        mm_enabled: When true, "mm" prompts are added to the pool of
            prompt types that can be drawn.
        num_prompts: Number of conversations to generate.

    Returns:
        A list with ``num_prompts`` entries; each entry is a one-element
        list holding a ``{"role": "user", "content": ...}`` message whose
        content is a string (text prompts) or a list of parts ("mm").

    Raises:
        ValueError: If a prompt type outside the known set is drawn.
    """
    prompt_types = ["repeat", "sentence"]
    if mm_enabled:
        prompt_types.append("mm")
    word_choices = ["test", "temp", "hello", "where"]

    # Re-seed the module-level RNG so every call yields the same batch.
    random.seed(0)
    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
    print(f"Prompt types: {random_prompt_type_choices}")

    prompts = []
    # Generate a mixed batch of prompts, some of which can be easily
    # predicted by n-gram matching and some which likely cannot.
    for kind in random_prompt_type_choices:
        # A word is drawn for every prompt (even "mm", which ignores it) so
        # the sequence of RNG draws does not depend on the prompt type.
        word = random.choice(word_choices)
        prompt: str | list[dict[str, Any]] = ""
        if kind == "repeat":
            prompt = f"""
            please repeat the word '{word}' 10 times.
            give no other output than the word at least ten times in a row,
            in lowercase with spaces between each word and without quotes.
            """
        elif kind == "sentence":
            prompt = f"""
            please give a ten-word sentence that
            uses the word {word} at least once.
            give no other output than that simple sentence without quotes.
            """
        elif kind == "mm":
            placeholders = [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
                    },
                }
            ]
            prompt = [
                *placeholders,
                {"type": "text", "text": "The meaning of the image is"},
            ]
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append([{"role": "user", "content": prompt}])

    return prompts
66
67
68
69


@pytest.fixture
def sampling_config() -> SamplingParams:
    """Provide the sampling parameters shared by the tests in this file.

    Uses temperature 0, at most 10 generated tokens, and does not ignore
    EOS. NOTE(review): temperature 0 is presumably chosen so reference and
    speculative outputs can be compared by exact text -- confirm.
    """
    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
71
72
73
74


@pytest.fixture
def model_name() -> str:
    """Provide the model identifier used by test_ngram_correctness."""
    return "meta-llama/Llama-3.1-8B-Instruct"
76
77


78
79
80
81
82
def test_ngram_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    """
    Compare the outputs of an original LLM and a speculative LLM;
    they should be the same when using ngram speculative decoding.

    The reference model and the ngram-speculative model are run one after
    the other on the same text-only prompts, and the test passes when at
    least 66% of the generated texts are identical.

    monkeypatch is not used by this test; it is kept so the signature
    stays unchanged.
    """
    test_prompts = get_test_prompts(mm_enabled=False)

    ref_llm = LLM(model=model_name, max_model_len=1024)
    try:
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
    finally:
        # Always release the engine, even if generation fails, so the next
        # engine (or the next test) does not start with memory still held.
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    spec_llm = LLM(
        model=model_name,
        speculative_config={
            "method": "ngram",
            "prompt_lookup_max": 5,
            "prompt_lookup_min": 3,
            "num_speculative_tokens": 3,
        },
        max_model_len=1024,
    )
    try:
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
    finally:
        # Clean up before asserting so a failed assertion cannot skip it.
        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    matches = 0
    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
        ref_text = ref_output.outputs[0].text
        spec_text = spec_output.outputs[0].text
        if ref_text == spec_text:
            matches += 1
        else:
            print(f"ref_output: {ref_text}")
            print(f"spec_output: {spec_text}")
    misses = len(ref_outputs) - matches

    # Heuristic: expect at least 66% of the prompts to match exactly
    # Upon failure, inspect the outputs to check for inaccuracy.
    assert matches >= int(0.66 * len(ref_outputs)), (
        f"only {matches} of {len(ref_outputs)} outputs matched "
        f"({misses} mismatches)"
    )


@pytest.mark.parametrize(
    ["model_setup", "mm_enabled"],
    [
        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
        pytest.param(
            (
                "eagle3",
                "Qwen/Qwen2.5-VL-7B-Instruct",
                "Rayzl/qwen2.5-vl-7b-eagle3-sgl",
                1,
            ),
            False,
            marks=pytest.mark.skip(
                reason="Skipping due to its head_dim not being a a multiple of 32"
            ),
        ),
        (
            (
                "eagle",
                "meta-llama/Llama-3.1-8B-Instruct",
                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
                1,
            ),
            False,
        ),
        (
            (
                "eagle3",
                "meta-llama/Llama-3.1-8B-Instruct",
                "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
                1,
            ),
            False,
        ),
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                4,
            ),
            False,
            marks=large_gpu_mark(min_gb=80),
        ),  # works on 4x H100
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                4,
            ),
            True,
            marks=large_gpu_mark(min_gb=80),
        ),  # works on 4x H100
        (
            (
                "eagle",
                "eagle618/deepseek-v3-random",
                "eagle618/eagle-deepseek-v3-random",
                1,
            ),
            False,
        ),
    ],
    ids=[
        "qwen3_eagle3",
        "qwen2_5_vl_eagle3",
        "llama3_eagle",
        "llama3_eagle3",
        "llama4_eagle",
        "llama4_eagle_mm",
        "deepseek_eagle",
    ],
)
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    attn_backend: str,
):
    """
    Compare the outputs of an original LLM and a speculative LLM;
    they should be the same when using eagle speculative decoding.
    model_setup: (method, model_name, eagle_model_name, tp_size)

    The reference model and the speculative model are run one after the
    other on the same prompts (multimodal prompts included when mm_enabled
    is true), and the test passes when more than 66% of the generated
    texts are identical.
    """
    if attn_backend == "TREE_ATTN":
        # TODO: Fix this flaky test
        pytest.skip(
            "TREE_ATTN is flaky in the test disable for now until it can be "
            "resolved (see https://github.com/vllm-project/vllm/issues/22922)"
        )

    # Decide on skipping before any prompts are built or engines are created.
    if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
        pytest.skip(
            "TRITON_ATTN does not support "
            "multi-token eagle spec decode on current platform"
        )

    method, model_name, spec_model_name, tp_size = model_setup

    # Generate test prompts inside the function instead of using fixture
    test_prompts = get_test_prompts(mm_enabled)

    with monkeypatch.context() as m:
        # Scout requires default backend selection
        # because vision encoder has head_dim 88 being incompatible
        #  with FLASH_ATTN and needs to fall back to Flex Attn,
        # so the backend is only forced for every other combination.
        if not ("Llama-4-Scout" in model_name and attn_backend == "FLASH_ATTN"):
            m.setenv("VLLM_MLA_DISABLE", "1")
            m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

        if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
            m.setenv("VLLM_ROCM_USE_AITER", "1")

        ref_llm = LLM(
            model=model_name, max_model_len=2048, tensor_parallel_size=tp_size
        )
        try:
            ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        finally:
            # Always release the engine, even if generation fails, so the
            # speculative engine does not start with memory still held.
            del ref_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": 2048,
            },
            max_model_len=2048,
        )
        try:
            spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        finally:
            # Clean up before asserting so a failed assertion cannot skip it.
            del spec_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        matches = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            ref_text = ref_output.outputs[0].text
            spec_text = spec_output.outputs[0].text
            if ref_text == spec_text:
                matches += 1
            else:
                print(f"ref_output: {ref_text}")
                print(f"spec_output: {spec_text}")
        misses = len(ref_outputs) - matches

        # Heuristic: expect more than 66% of the prompts to match exactly
        # Upon failure, inspect the outputs to check for inaccuracy.
        assert matches > int(0.66 * len(ref_outputs)), (
            f"only {matches} of {len(ref_outputs)} outputs matched "
            f"({misses} mismatches)"
        )
278
279


280
281
282
283
284
285
286
287
@pytest.mark.parametrize(
    ["model_setup", "mm_enabled"],
    [
        (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
        (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
    ],
    ids=["mimo", "deepseek"],
)
def test_mtp_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, int],
    mm_enabled: bool,
):
    """
    Compare the outputs of an original LLM and a speculative LLM;
    they should be the same when using MTP speculative decoding.
    model_setup: (method, model_name, tp_size)

    The reference model and the speculative model are run one after the
    other on the same prompts, and the test passes when the share of
    identical generated texts exceeds MTP_SIMILARITY_RATE.
    """
    method, model_name, tp_size = model_setup

    # Generate test prompts inside the function instead of using fixture
    test_prompts = get_test_prompts(mm_enabled)

    with monkeypatch.context() as m:
        m.setenv("VLLM_MLA_DISABLE", "1")

        ref_llm = LLM(
            model=model_name,
            max_model_len=2048,
            tensor_parallel_size=tp_size,
            trust_remote_code=True,
        )
        try:
            ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        finally:
            # Always release the engine, even if generation fails, so the
            # speculative engine does not start with memory still held.
            del ref_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "num_speculative_tokens": 1,
                "max_model_len": 2048,
            },
            max_model_len=2048,
        )
        try:
            spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        finally:
            # Clean up before asserting so a failed assertion cannot skip it.
            del spec_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        matches = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            ref_text = ref_output.outputs[0].text
            spec_text = spec_output.outputs[0].text
            if ref_text == spec_text:
                matches += 1
            else:
                print(f"ref_output: {ref_text}")
                print(f"spec_output: {spec_text}")
        misses = len(ref_outputs) - matches

        # Heuristic: expect more than MTP_SIMILARITY_RATE of the prompts to
        # match exactly.
        # Upon failure, inspect the outputs to check for inaccuracy.
        assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs)), (
            f"only {matches} of {len(ref_outputs)} outputs matched "
            f"({misses} mismatches)"
        )