test_spec_decode.py 6.87 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
from __future__ import annotations

5
import random
6
from typing import Any, Union
7

8
import pytest
zhiweiz's avatar
zhiweiz committed
9
import torch
10
11

from vllm import LLM, SamplingParams
12
13
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
zhiweiz's avatar
zhiweiz committed
14
from vllm.distributed import cleanup_dist_env_and_memory
15
16


def get_test_prompts(mm_enabled: bool, *, num_prompts: int = 100):
    """Build a reproducible, mixed batch of chat prompts.

    Each prompt kind is drawn with a fixed seed from:
      * "repeat"   -- asks the model to repeat one word, which n-gram
                      matching can easily predict;
      * "sentence" -- asks for a short free-form sentence, which it likely
                      cannot;
      * "mm"       -- (only when ``mm_enabled``) an image URL followed by a
                      short text prompt.

    Args:
        mm_enabled: Whether "mm" prompts may be drawn.
        num_prompts: Number of conversations to build (default 100).

    Returns:
        A list of ``num_prompts`` conversations, each a single-message list
        ``[{"role": "user", "content": prompt}]`` as passed to ``LLM.chat``.

    Raises:
        ValueError: If an unknown prompt kind is drawn (defensive only).
    """
    prompt_types = ["repeat", "sentence"]
    if mm_enabled:
        prompt_types.append("mm")

    # A private RNG seeded with 0 produces exactly the sequence that
    # ``random.seed(0)`` did. Unlike reseeding, it leaves the process-global
    # ``random`` state untouched for any other test.
    rng = random.Random(0)
    random_prompt_type_choices = rng.choices(prompt_types, k=num_prompts)
    print(f"Prompt types: {random_prompt_type_choices}")

    word_choices = ["test", "temp", "hello", "where"]
    prompts = []
    # Generate a mixed batch of prompts, some of which can be easily
    # predicted by n-gram matching and some which likely cannot.
    for kind in random_prompt_type_choices:
        # A word is drawn for every prompt, even "mm" prompts that ignore it,
        # so that the RNG sequence (and thus every prompt) stays fixed.
        word = rng.choice(word_choices)
        prompt: Union[str, list[dict[str, Any]]] = ""
        if kind == "repeat":
            prompt = f"""
            please repeat the word '{word}' 10 times.
            give no other output than the word at least ten times in a row,
            in lowercase with spaces between each word and without quotes.
            """
        elif kind == "sentence":
            prompt = f"""
            please give a ten-word sentence that
            uses the word {word} at least once.
            give no other output than that simple sentence without quotes.
            """
        elif kind == "mm":
            placeholders = [{
                "type": "image_url",
                "image_url": {
                    "url":
                    f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
                },
            }]
            prompt = [
                *placeholders,
                {
                    "type": "text",
                    "text": "The meaning of the image is"
                },
            ]
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append([{"role": "user", "content": prompt}])

    return prompts


@pytest.fixture
def sampling_config():
    """Sampling parameters shared by the spec-decode correctness tests.

    Uses temperature=0, caps generation at 10 tokens and does not ignore
    EOS, so the reference and speculative runs can be compared
    output-for-output.
    """
    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)


@pytest.fixture
def model_name():
    """Target model id consumed by ``test_ngram_correctness``."""
    return "meta-llama/Llama-3.1-8B-Instruct"


def test_ngram_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    '''
    Compare the outputs of an original LLM and a speculative LLM; they
    should be the same when using ngram speculative decoding.
    '''
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        test_prompts = get_test_prompts(mm_enabled=False)

        ref_llm = LLM(model=model_name, max_model_len=1024)
        try:
            ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        finally:
            # Tear the reference engine down before the speculative one is
            # built; doing it in `finally` also attempts cleanup when
            # generation raises.
            del ref_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        spec_llm = LLM(
            model=model_name,
            speculative_config={
                "method": "ngram",
                "prompt_lookup_max": 5,
                "prompt_lookup_min": 3,
                "num_speculative_tokens": 3,
            },
            max_model_len=1024,
        )
        try:
            spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        finally:
            del spec_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        # Compare only after both engines have been released. That way a
        # failing assertion below can no longer skip the cleanup and leave
        # the speculative engine alive for the tests that follow.
        assert len(spec_outputs) == len(ref_outputs)
        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect more than 70% of the prompts to match exactly.
        # Upon failure, inspect the printed outputs to check for inaccuracy.
        assert matches > int(0.7 * len(ref_outputs)), (
            f"only {matches}/{len(ref_outputs)} outputs matched "
            f"({misses} misses)")


@pytest.mark.parametrize(
    ["model_setup", "mm_enabled"], [
        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
        pytest.param(
            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
            False,
            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
        pytest.param(
            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
            True,
            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
    ],
    ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
def test_eagle_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
):
    '''
    Compare the outputs of an original LLM and a speculative LLM; they
    should be the same when using eagle speculative decoding.

    model_setup: (method, model_name, eagle_model_name, tp_size)
    mm_enabled: whether multimodal (image) prompts are included.
    '''
    # Prompts are built here, from the parametrized mm_enabled flag, rather
    # than supplied by a fixture.
    test_prompts = get_test_prompts(mm_enabled)

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        method, model_name, spec_model_name, tp_size = model_setup

        ref_llm = LLM(model=model_name,
                      max_model_len=2048,
                      tensor_parallel_size=tp_size)
        try:
            ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        finally:
            # Tear the reference engine down before the speculative one is
            # built; doing it in `finally` also attempts cleanup when
            # generation raises.
            del ref_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": 2048,
            },
            max_model_len=2048,
        )
        try:
            spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        finally:
            del spec_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        # Compare only after both engines have been released. That way a
        # failing assertion below can no longer skip the cleanup and leave
        # the speculative engine alive for the tests that follow.
        assert len(spec_outputs) == len(ref_outputs)
        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect more than 66% of the prompts to match exactly.
        # Upon failure, inspect the printed outputs to check for inaccuracy.
        assert matches > int(0.66 * len(ref_outputs)), (
            f"only {matches}/{len(ref_outputs)} outputs matched "
            f"({misses} misses)")