# test_spec_decode.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
from __future__ import annotations

import os
import random
from typing import Any, Union

import pytest
import torch

from tests.utils import get_attn_backend_list_based_on_platform
from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform

from ...utils import models_path_prefix
21
22


23
def get_test_prompts(mm_enabled: bool,
                     num_prompts: int = 100) -> list[list[dict[str, Any]]]:
    """Build a deterministic, mixed batch of chat prompts for spec-decode tests.

    The batch mixes prompts that are easy to predict by n-gram matching
    ("repeat"), prompts that likely are not ("sentence") and, when
    ``mm_enabled`` is true, image + text prompts ("mm").

    Args:
        mm_enabled: Whether "mm" (image + text) prompts may be drawn.
        num_prompts: How many conversations to generate (default 100).

    Returns:
        A list of ``num_prompts`` conversations, each a single-message list
        ``[{"role": "user", "content": prompt}]``.

    Raises:
        ValueError: If an unknown prompt type is drawn (not expected).
    """
    prompt_types = ["repeat", "sentence"]
    if mm_enabled:
        prompt_types.append("mm")

    word_choices = ["test", "temp", "hello", "where"]

    # A private seeded generator keeps the batch reproducible without
    # reseeding the global ``random`` state that other tests may rely on.
    # It produces the same sequence as ``random.seed(0)`` followed by the
    # module-level ``random.choices`` / ``random.choice`` calls.
    rng = random.Random(0)
    random_prompt_type_choices = rng.choices(prompt_types, k=num_prompts)
    print(f"Prompt types: {random_prompt_type_choices}")

    prompts = []
    # Generate a mixed batch of prompts, some of which can be easily
    # predicted by n-gram matching and some which likely cannot.
    for kind in random_prompt_type_choices:
        # Drawn for every prompt (even "mm") so the random sequence is stable.
        word = rng.choice(word_choices)
        prompt: Union[str, list[dict[str, Any]]] = ""
        if kind == "repeat":
            prompt = f"""
            please repeat the word '{word}' 10 times.
            give no other output than the word at least ten times in a row,
            in lowercase with spaces between each word and without quotes.
            """
        elif kind == "sentence":
            prompt = f"""
            please give a ten-word sentence that
            uses the word {word} at least once.
            give no other output than that simple sentence without quotes.
            """
        elif kind == "mm":
            placeholders = [{
                "type": "image_url",
                "image_url": {
                    "url":
                    f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
                },
            }]
            prompt = [
                *placeholders,
                {
                    "type": "text",
                    "text": "The meaning of the image is"
                },
            ]
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append([{"role": "user", "content": prompt}])

    return prompts
72
73
74
75


@pytest.fixture
def sampling_config():
    """Greedy sampling params shared by the spec-decode correctness tests.

    Temperature 0 with at most 10 new tokens; ``ignore_eos`` is left False.
    """
    greedy_params = SamplingParams(
        temperature=0,
        max_tokens=10,
        ignore_eos=False,
    )
    return greedy_params
77
78
79
80


@pytest.fixture
def model_name():
    """Target model path for the n-gram test, rooted at ``models_path_prefix``."""
    base_model = "meta-llama/Llama-3.1-8B-Instruct"
    return os.path.join(models_path_prefix, base_model)
83
84


85
86
87
88
89
def test_ngram_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    """Check that n-gram speculative decoding reproduces a plain LLM's output.

    The same prompts are run through a reference LLM and through an LLM
    configured with n-gram speculative decoding. The test requires more than
    ``int(0.7 * N)`` of the N outputs to match exactly.

    Each engine is deleted and distributed/GPU state is cleaned up in a
    ``finally`` block, so a failure does not leak memory into later tests.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        test_prompts = get_test_prompts(mm_enabled=False)

        ref_llm = LLM(model=model_name, max_model_len=1024)
        try:
            ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        finally:
            # Release the reference engine before building the speculative
            # one, even if generation raised.
            del ref_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        spec_llm = LLM(
            model=model_name,
            speculative_config={
                "method": "ngram",
                "prompt_lookup_max": 5,
                "prompt_lookup_min": 3,
                "num_speculative_tokens": 3,
            },
            max_model_len=1024,
        )
        try:
            spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        finally:
            del spec_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        # zip() would silently drop unmatched outputs; require equal counts.
        assert len(spec_outputs) == len(ref_outputs)

        matches = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: require strictly more than int(70%) of the prompts to
        # match exactly. Upon failure, inspect the printed outputs to check
        # for inaccuracy.
        assert matches > int(0.7 * len(ref_outputs)), (
            f"only {matches}/{len(ref_outputs)} outputs matched exactly")


133
134
135
136
137
@pytest.mark.parametrize(
    ["model_setup", "mm_enabled"],
    [
        # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611  # noqa: E501
        # (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
        # NOTE(review): unlike the llama4 entries, the llama3 draft models are
        # not prefixed with models_path_prefix; confirm this is intentional.
        (("eagle",
          os.path.join(models_path_prefix, "meta-llama/Llama-3.1-8B-Instruct"),
          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
        (("eagle3",
          os.path.join(models_path_prefix, "meta-llama/Llama-3.1-8B-Instruct"),
          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
        pytest.param(
            ("eagle",
             os.path.join(models_path_prefix,
                          "meta-llama/Llama-4-Scout-17B-16E-Instruct"),
             os.path.join(models_path_prefix,
                          "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct"),
             4),
            False,
            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
        pytest.param(
            ("eagle",
             os.path.join(models_path_prefix,
                          "meta-llama/Llama-4-Scout-17B-16E-Instruct"),
             os.path.join(models_path_prefix,
                          "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct"),
             4),
            True,
            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
    ],
    ids=[
        # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611  # noqa: E501
        # "qwen3_eagle3",
        "llama3_eagle",
        "llama3_eagle3",
        "llama4_eagle",
        "llama4_eagle_mm",
    ])
@pytest.mark.parametrize("attn_backend",
                         get_attn_backend_list_based_on_platform())
def test_eagle_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    attn_backend: str,
):
    """Check that EAGLE/EAGLE3 speculative decoding reproduces a plain LLM.

    The same prompts are run through a reference LLM and through an LLM that
    uses the draft model from ``model_setup`` with the given attention
    backend. The test requires more than ``int(0.66 * N)`` of the N outputs
    to match exactly.

    model_setup: (method, model_name, eagle_model_name, tp_size)

    Each engine is deleted and distributed/GPU state is cleaned up in a
    ``finally`` block, so a failure does not leak memory into later tests.
    """
    # Skip checks first, so skipped cases do no prompt generation.
    if attn_backend == "TREE_ATTN":
        # TODO: Fix this flaky test
        pytest.skip(
            "TREE_ATTN is flaky in the test disable for now until it can be "
            "resolved (see https://github.com/vllm-project/vllm/issues/22922)")

    if (attn_backend == "TRITON_ATTN_VLLM_V1"
            and not current_platform.is_rocm()):
        pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
                    "multi-token eagle spec decode on current platform")

    method, model_name, spec_model_name, tp_size = model_setup
    test_prompts = get_test_prompts(mm_enabled)

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
        if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
            m.setenv("VLLM_ROCM_USE_AITER", "1")

        ref_llm = LLM(model=model_name,
                      max_model_len=2048,
                      tensor_parallel_size=tp_size)
        try:
            ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        finally:
            # Release the reference engine before building the speculative
            # one, even if generation raised.
            del ref_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": 2048,
            },
            max_model_len=2048,
        )
        try:
            spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        finally:
            del spec_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        # zip() would silently drop unmatched outputs; require equal counts.
        assert len(spec_outputs) == len(ref_outputs)

        matches = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: require strictly more than int(66%) of the prompts to
        # match exactly. Upon failure, inspect the printed outputs to check
        # for inaccuracy.
        assert matches > int(0.66 * len(ref_outputs)), (
            f"only {matches}/{len(ref_outputs)} outputs matched exactly")