# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import random
from typing import Any, Union

import pytest
import torch

from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform
17

18
19
MTP_SIMILARITY_RATE = 0.8

20

21
def get_test_prompts(
    mm_enabled: bool,
    *,
    num_prompts: int = 100,
    seed: int = 0,
) -> list[list[dict[str, Any]]]:
    """Build a reproducible, mixed batch of single-turn chat prompts.

    Each prompt is one of:

    * ``"repeat"``   - asks the model to repeat a word ten times;
    * ``"sentence"`` - asks for a ten-word sentence using a word;
    * ``"mm"``       - a stop-sign image plus a short text prompt; only
      drawn when ``mm_enabled`` is true.

    Args:
        mm_enabled: Whether multimodal ("mm") prompts may be drawn.
        num_prompts: Number of conversations to generate.
        seed: Seed for the private RNG that picks prompt kinds and words.

    Returns:
        A list of ``num_prompts`` conversations, each a one-message list
        ``[{"role": "user", "content": ...}]`` as passed to ``LLM.chat``.

    Raises:
        ValueError: If an unknown prompt kind is drawn (cannot happen with
            the kinds listed above).
    """
    prompt_types = ["repeat", "sentence"]
    if mm_enabled:
        prompt_types.append("mm")
    word_choices = ["test", "temp", "hello", "where"]

    # A private RNG produces exactly the same draws as ``random.seed(seed)``
    # followed by module-level ``random`` calls, but does not reset the
    # global RNG state that other code may depend on.
    rng = random.Random(seed)
    random_prompt_type_choices = rng.choices(prompt_types, k=num_prompts)
    print(f"Prompt types: {random_prompt_type_choices}")

    prompts: list[list[dict[str, Any]]] = []
    # Generate a mixed batch of prompts, some of which can be easily
    # predicted by n-gram matching and some which likely cannot.
    for kind in random_prompt_type_choices:
        # A word is drawn for every prompt (even "mm", which ignores it) so
        # the RNG sequence, and therefore the generated prompts, stay stable.
        word = rng.choice(word_choices)
        prompt: Union[str, list[dict[str, Any]]] = ""
        if kind == "repeat":
            prompt = f"""
            please repeat the word '{word}' 10 times.
            give no other output than the word at least ten times in a row,
            in lowercase with spaces between each word and without quotes.
            """
        elif kind == "sentence":
            prompt = f"""
            please give a ten-word sentence that
            uses the word {word} at least once.
            give no other output than that simple sentence without quotes.
            """
        elif kind == "mm":
            placeholders = [{
                "type": "image_url",
                "image_url": {
                    "url":
                    f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
                },
            }]
            prompt = [
                *placeholders,
                {
                    "type": "text",
                    "text": "The meaning of the image is"
                },
            ]
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append([{"role": "user", "content": prompt}])

    return prompts


@pytest.fixture
def sampling_config():
    """Sampling parameters shared by the correctness tests.

    Uses temperature=0, caps generation at 10 tokens, and does not ignore
    EOS.
    """
    params = SamplingParams(
        temperature=0,
        max_tokens=10,
        ignore_eos=False,
    )
    return params


@pytest.fixture
def model_name():
    """Model id of the target model used by the n-gram correctness test."""
    checkpoint = "meta-llama/Llama-3.1-8B-Instruct"
    return checkpoint


82
83
84
85
86
def test_ngram_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    """Compare the outputs of an original LLM and an n-gram speculative LLM.

    Both engines receive the same prompts; with ngram speculative decoding
    the generated texts should be the same. As a heuristic, at least
    ``int(0.66 * num_prompts)`` outputs must match exactly.

    ``monkeypatch`` is currently unused by this test.
    """
    test_prompts = get_test_prompts(mm_enabled=False)

    ref_llm = LLM(model=model_name, max_model_len=1024)
    try:
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
    finally:
        # Free the reference engine before building the speculative one,
        # even if generation raised.
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    spec_llm = LLM(
        model=model_name,
        speculative_config={
            "method": "ngram",
            "prompt_lookup_max": 5,
            "prompt_lookup_min": 3,
            "num_speculative_tokens": 3,
        },
        max_model_len=1024,
    )
    try:
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
    finally:
        # Release the engine before asserting, so a failed comparison does
        # not leave GPU memory / distributed state behind for later tests.
        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    matches = 0
    misses = 0
    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
        if ref_output.outputs[0].text == spec_output.outputs[0].text:
            matches += 1
        else:
            misses += 1
            print(f"ref_output: {ref_output.outputs[0].text}")
            print(f"spec_output: {spec_output.outputs[0].text}")

    # Heuristic: expect at least 66% of the prompts to match exactly.
    # Upon failure, inspect the printed outputs to check for inaccuracy.
    assert matches >= int(0.66 * len(ref_outputs)), (
        f"{matches} matches / {misses} misses "
        f"out of {len(ref_outputs)} prompts")


@pytest.mark.parametrize(
    ["model_setup", "mm_enabled"],
    [
        (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
        pytest.param(("eagle3", "Qwen/Qwen2.5-VL-7B-Instruct",
                      "Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1),
                     False,
                     marks=pytest.mark.skip(
                         reason="Skipping due to its "
                         "head_dim not being a multiple of 32")),
        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
        pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                      "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
                     False,
                     marks=large_gpu_mark(min_gb=80)),  # works on 4x H100
        pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                      "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
                     True,
                     marks=large_gpu_mark(min_gb=80)),  # works on 4x H100
        (("eagle", "eagle618/deepseek-v3-random",
          "eagle618/eagle-deepseek-v3-random", 1), False),
    ],
    ids=[
        "qwen3_eagle3", "qwen2_5_vl_eagle3", "llama3_eagle", "llama3_eagle3",
        "llama4_eagle", "llama4_eagle_mm", "deepseek_eagle"
    ])
@pytest.mark.parametrize("attn_backend",
                         get_attn_backend_list_based_on_platform())
def test_eagle_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    attn_backend: str,
):
    """Compare the outputs of an original LLM and an EAGLE speculative LLM.

    The outputs should be the same when using eagle speculative decoding;
    as a heuristic the test requires strictly more than
    ``int(0.66 * num_prompts)`` exact matches.

    model_setup: (method, model_name, eagle_model_name, tp_size)

    TREE_ATTN is always skipped (flaky), and TRITON_ATTN is skipped unless
    the current platform is ROCm.
    """
    if attn_backend == "TREE_ATTN":
        # TODO: Fix this flaky test
        pytest.skip(
            "TREE_ATTN is flaky in the test disable for now until it can be "
            "resolved (see https://github.com/vllm-project/vllm/issues/22922)")
    if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
        # Checked before any setup so skipped cases do no work.
        pytest.skip("TRITON_ATTN does not support "
                    "multi-token eagle spec decode on current platform")

    method, model_name, spec_model_name, tp_size = model_setup
    # Generate test prompts inside the function instead of using a fixture.
    test_prompts = get_test_prompts(mm_enabled)

    with monkeypatch.context() as m:
        # Scout requires default backend selection because its vision
        # encoder has head_dim 88, which is incompatible with FLASH_ATTN and
        # needs to fall back to Flex Attn; leave the backend unset there.
        use_default_backend = ("Llama-4-Scout" in model_name
                               and attn_backend == "FLASH_ATTN")
        if not use_default_backend:
            m.setenv("VLLM_MLA_DISABLE", "1")
            m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

        if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
            m.setenv("VLLM_ROCM_USE_AITER", "1")

        ref_llm = LLM(model=model_name,
                      max_model_len=2048,
                      tensor_parallel_size=tp_size)
        try:
            ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        finally:
            # Free the reference engine before building the speculative one,
            # even if generation raised.
            del ref_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": 2048,
            },
            max_model_len=2048,
        )
        try:
            spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        finally:
            # Release the engine before asserting, so a mismatch does not
            # leak GPU memory / distributed state into later tests.
            del spec_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

    matches = 0
    misses = 0
    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
        if ref_output.outputs[0].text == spec_output.outputs[0].text:
            matches += 1
        else:
            misses += 1
            print(f"ref_output: {ref_output.outputs[0].text}")
            print(f"spec_output: {spec_output.outputs[0].text}")

    # Heuristic: require strictly more than int(0.66 * N) of the N prompts
    # to match exactly. Upon failure, inspect the printed outputs to check
    # for inaccuracy.
    assert matches > int(0.66 * len(ref_outputs)), (
        f"{matches} matches / {misses} misses "
        f"out of {len(ref_outputs)} prompts")


@pytest.mark.parametrize(
    ["model_setup", "mm_enabled"],
    [
        (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
        (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
    ],
    ids=["mimo", "deepseek"],
)
def test_mtp_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, int],
    mm_enabled: bool,
):
    """Compare the outputs of an original LLM and an MTP speculative LLM.

    The outputs should be the same when using MTP speculative decoding; as
    a heuristic the test requires strictly more than
    ``int(MTP_SIMILARITY_RATE * num_prompts)`` exact matches.

    model_setup: (method, model_name, tp_size)
    """
    method, model_name, tp_size = model_setup
    # Generate test prompts inside the function instead of using a fixture.
    test_prompts = get_test_prompts(mm_enabled)

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        m.setenv("VLLM_MLA_DISABLE", "1")

        ref_llm = LLM(model=model_name,
                      max_model_len=2048,
                      tensor_parallel_size=tp_size,
                      trust_remote_code=True)
        try:
            ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        finally:
            # Free the reference engine before building the speculative one,
            # even if generation raised.
            del ref_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "num_speculative_tokens": 1,
                "max_model_len": 2048,
            },
            max_model_len=2048,
        )
        try:
            spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        finally:
            # Release the engine before asserting, so a mismatch does not
            # leak GPU memory / distributed state into later tests.
            del spec_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

    matches = 0
    misses = 0
    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
        if ref_output.outputs[0].text == spec_output.outputs[0].text:
            matches += 1
        else:
            misses += 1
            print(f"ref_output: {ref_output.outputs[0].text}")
            print(f"spec_output: {spec_output.outputs[0].text}")

    # Heuristic: require strictly more than int(MTP_SIMILARITY_RATE * N) of
    # the N prompts to match exactly. Upon failure, inspect the printed
    # outputs to check for inaccuracy.
    assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs)), (
        f"{matches} matches / {misses} misses "
        f"out of {len(ref_outputs)} prompts")