test_spec_decode.py 5.84 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
from __future__ import annotations

5
import random
6
from typing import Any
7

8
import os
9
import pytest
zhiweiz's avatar
zhiweiz committed
10
import torch
11
12

from vllm import LLM, SamplingParams
13
from ...utils import models_path_prefix
zhiweiz's avatar
zhiweiz committed
14
from vllm.distributed import cleanup_dist_env_and_memory
15
16
17
18


@pytest.fixture
def test_prompts():
    """Build a deterministic, mixed batch of 100 single-turn chat prompts.

    Each prompt is, at random, either a "repeat" request (repeat one word
    ten times), whose output n-gram matching can usually predict, or a
    "sentence" request (write a short sentence using the word), whose
    output likely cannot be predicted that way. A private RNG seeded with 0
    makes the batch reproducible without touching the global ``random``
    state.

    Returns:
        A list of conversations, each holding one user message
        (``{"role": "user", "content": ...}``), as passed to ``LLM.chat``.
    """
    prompt_types = ["repeat", "sentence"]
    word_choices = ["test", "temp", "hello", "where"]
    num_prompts = 100
    prompts = []

    # A local Random(0) yields exactly the same sequence that
    # random.seed(0) followed by module-level random calls would, but it
    # does not reseed the process-wide generator other code relies on.
    rng = random.Random(0)
    random_prompt_type_choices = rng.choices(prompt_types, k=num_prompts)

    # Generate a mixed batch of prompts, some of which can be easily
    # predicted by n-gram matching and some which likely cannot.
    for kind in random_prompt_type_choices:
        word = rng.choice(word_choices)
        if kind == "repeat":
            prompt = f"""
            please repeat the word '{word}' 10 times.
            give no other output than the word at least ten times in a row,
            in lowercase with spaces between each word and without quotes.
            """
        elif kind == "sentence":
            prompt = f"""
            please give a ten-word sentence that
            uses the word {word} at least once.
            give no other output than that simple sentence without quotes.
            """
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append([{"role": "user", "content": prompt}])

    return prompts


@pytest.fixture
def sampling_config():
    """Sampling parameters shared by every reference/speculative comparison.

    ``temperature=0`` requests greedy decoding, so the reference and
    speculative engines are expected to emit the same text. ``max_tokens``
    keeps each generation short.
    """
    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)


@pytest.fixture
def model_name():
    """Target model identifier used by the n-gram speculative decoding test."""
    # NOTE: to resolve the model under ``models_path_prefix`` (imported from
    # ``...utils``) instead of by its hub ID, return
    # ``os.path.join(models_path_prefix, <this ID>)`` here.
    return "meta-llama/Llama-3.1-8B-Instruct"
def test_ngram_correctness(
    monkeypatch: pytest.MonkeyPatch,
    test_prompts: list[list[dict[str, Any]]],
    sampling_config: SamplingParams,
    model_name: str,
):
    """Check that n-gram speculative decoding preserves the model's outputs.

    The same prompts are generated by a plain LLM and by an LLM using n-gram
    (prompt-lookup) speculative decoding. With the shared greedy sampling
    config, the two outputs should be the same for the large majority of
    prompts.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        ref_llm = LLM(model=model_name, max_model_len=1024)
        try:
            ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        finally:
            # Tear the reference engine down before building the speculative
            # one, even if generation raised, so its resources are released.
            del ref_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        spec_llm = LLM(
            model=model_name,
            speculative_config={
                "method": "ngram",
                "prompt_lookup_max": 5,
                "prompt_lookup_min": 3,
                "num_speculative_tokens": 3,
            },
            max_model_len=1024,
        )
        try:
            spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        finally:
            # Previously teardown was skipped whenever the assertion below
            # failed; run it unconditionally and compare afterwards.
            del spec_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        # zip() would silently truncate if the batches differed in size.
        assert len(spec_outputs) == len(ref_outputs)

        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect more than 70% of the prompts to match exactly.
        # Upon failure, inspect the printed outputs to check for inaccuracy.
        assert matches > int(0.7 * len(ref_outputs)), (
            f"only {matches}/{len(ref_outputs)} outputs matched "
            f"({misses} misses)")


@pytest.mark.parametrize(
    "model_setup",
    [
        ("eagle", "meta-llama/Llama-3.1-8B-Instruct",
         "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
        ("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
         "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
        pytest.param(
            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
    ],
    ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"],
)
def test_eagle_correctness(
    monkeypatch: pytest.MonkeyPatch,
    test_prompts: list[list[dict[str, Any]]],
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
):
    """Check that EAGLE speculative decoding preserves the model's outputs.

    The same prompts are generated by a plain LLM and by an LLM using EAGLE
    speculative decoding. With the shared greedy sampling config, the two
    outputs should be the same for the large majority of prompts.

    Args:
        model_setup: ``(method, model_name, eagle_model_name, tp_size)``.
            These are the value for ``speculative_config["method"]``, the
            target model, the speculative model passed as
            ``speculative_config["model"]``, and the tensor-parallel size
            used for both engines.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        method, model_name, spec_model_name, tp_size = model_setup

        ref_llm = LLM(model=model_name,
                      max_model_len=2048,
                      tensor_parallel_size=tp_size)
        try:
            ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        finally:
            # Tear the reference engine down before building the speculative
            # one, even if generation raised, so its resources are released.
            del ref_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": 2048,
            },
            max_model_len=2048,
        )
        try:
            spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        finally:
            # Previously teardown was skipped whenever the assertion below
            # failed; run it unconditionally and compare afterwards.
            del spec_llm
            torch.cuda.empty_cache()
            cleanup_dist_env_and_memory()

        # zip() would silently truncate if the batches differed in size.
        assert len(spec_outputs) == len(ref_outputs)

        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect more than 66% of the prompts to match exactly.
        # Upon failure, inspect the printed outputs to check for inaccuracy.
        assert matches > int(0.66 * len(ref_outputs)), (
            f"only {matches}/{len(ref_outputs)} outputs matched "
            f"({misses} misses)")