# conftest.py 8.47 KB
# Newer Older
1
from itertools import cycle
2
from typing import List, Optional, Tuple
3

4
import pytest
5

6
from vllm import LLM, SamplingParams
7
from vllm.model_executor.utils import set_random_seed
8

9
from ...conftest import cleanup
10
11
from ...models.utils import check_logprobs_close, check_outputs_equal
from ...utils import RemoteOpenAIServer
12

13
14
15
16
17
18
19
20
21
22
# Prompt pool shared by the correctness helpers in this file; they cycle
# through it to build a batch of the requested size.
PROMPTS = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
    "San Francisco is know for its",
    "Facebook was created in 2004 by",
    "Curious George is a",
    "Python 3.11 brings improvements to its",
]
23
24
25


@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    """Return a generator function that yields a freshly built ``LLM``.

    The three kwargs mappings are merged in the order given, so on a key
    collision ``test_llm_kwargs`` overrides ``per_test_common_llm_kwargs``,
    which overrides ``common_llm_kwargs``.

    Args:
        common_llm_kwargs: Base keyword arguments for ``LLM``.
        per_test_common_llm_kwargs: Per-test keyword arguments.
        test_llm_kwargs: Keyword arguments specific to the LLM under test.
        seed: If not ``None``, passed to ``set_random_seed`` after the LLM
            has been constructed.

    Returns:
        A zero-argument generator function. Iterating its result yields a
        single ``LLM`` and then releases it and calls ``cleanup()``.
    """

    def generate():
        kwargs = {
            **common_llm_kwargs,
            **per_test_common_llm_kwargs,
            **test_llm_kwargs,
        }

        llm = LLM(**kwargs)

        # Run the teardown from ``finally`` so it also happens when the
        # consumer raises or the generator is closed before being exhausted.
        try:
            if seed is not None:
                set_random_seed(seed)

            yield llm
        finally:
            del llm
            cleanup()

    return generate
47
48


49
50
def maybe_assert_ngram_worker(llm):
    """Assert that the proposer worker is an ``NGramWorker`` when ngram is on.

    Does nothing when ``llm.llm_engine.speculative_config`` is ``None`` or
    when its ``ngram_prompt_lookup_max`` is unset (``None``) or not positive.

    Args:
        llm: Object exposing ``llm_engine.speculative_config`` and, when
            ngram is enabled,
            ``llm_engine.model_executor.driver_worker.proposer_worker``.

    Raises:
        AssertionError: If ngram is enabled but the proposer worker is not
            an ``NGramWorker``.
    """
    speculative_config = llm.llm_engine.speculative_config
    if speculative_config is None:
        return

    # Treat a missing (None) lookup window like 0 instead of failing on the
    # ``None > 0`` comparison.
    if (speculative_config.ngram_prompt_lookup_max or 0) <= 0:
        return

    # Imported lazily so the module is only needed when ngram is enabled.
    from vllm.spec_decode.ngram_worker import NGramWorker
    assert isinstance(
        llm.llm_engine.model_executor.driver_worker.proposer_worker,
        NGramWorker)


59
60
def get_output_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> Tuple[List[str], List[List[int]], float]:
    """Generate with every LLM produced by ``llm_generator`` and collect output.

    Args:
        llm_generator: Zero-argument callable returning an iterable of LLM
            objects.
        prompts: Prompts forwarded to ``llm.generate``.
        sampling_params: Sampling parameters forwarded to ``llm.generate``.

    Returns:
        A ``(tokens, token_ids, acceptance_rate)`` tuple holding the text and
        token ids of the first completion of each request. Values come from
        the last LLM iterated; earlier iterations are overwritten.
        ``acceptance_rate`` stays ``-1.0`` when the engine has no (or empty)
        ``stat_loggers``.
    """
    tokens: List[str] = []
    token_ids: List[List[int]] = []
    acceptance_rate: float = -1.0

    for llm in llm_generator():
        maybe_assert_ngram_worker(llm)

        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)

        # Only the first completion of each request is inspected.
        first_completions = [output.outputs[0] for output in outputs]
        token_ids = [completion.token_ids for completion in first_completions]
        tokens = [completion.text for completion in first_completions]

        # Fetch acceptance rate if logging is enabled.
        if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
            stat_logger = stat_loggers["prometheus"]
            acceptance_rate = (stat_logger.metrics.
                               gauge_spec_decode_draft_acceptance_rate.labels(
                                   **stat_logger.labels)._value.get())

        # Drop this function's reference before the generator is resumed.
        del llm

    return tokens, token_ids, acceptance_rate
82
83


84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
def run_logprob_correctness_test(vllm_runner,
                                 common_llm_kwargs,
                                 per_test_common_llm_kwargs,
                                 baseline_llm_kwargs,
                                 test_llm_kwargs,
                                 batch_size: int,
                                 max_output_len: int,
                                 seed: Optional[int] = 0,
                                 temperature: float = 0.0,
                                 logprobs: int = 1):
    """Check that baseline and test LLMs produce close logprobs.

    Both models are run on the same ``batch_size`` prompts (cycled from
    ``PROMPTS``) with identical sampling parameters, and their outputs are
    compared with ``check_logprobs_close``.

    Args:
        vllm_runner: Callable that accepts the merged kwargs and returns a
            context manager whose value exposes ``generate_w_logprobs``.
        common_llm_kwargs: Kwargs shared by both runs.
        per_test_common_llm_kwargs: Per-test kwargs shared by both runs.
        baseline_llm_kwargs: Kwargs applied only to the baseline ("org") run.
        test_llm_kwargs: Kwargs applied only to the test ("sd") run.
        batch_size: Number of prompts to generate for.
        max_output_len: Passed as ``max_tokens`` to ``SamplingParams``.
        seed: Sampling seed passed to ``SamplingParams``.
        temperature: Sampling temperature.
        logprobs: Passed as ``logprobs`` to ``SamplingParams``.
    """
    # Later mappings win on key collisions, so the run-specific kwargs
    # override the common ones.
    org_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **baseline_llm_kwargs,
    }

    sd_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **test_llm_kwargs,
    }

    # Repeat PROMPTS as often as needed to reach batch_size entries.
    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]

    sampling_params = SamplingParams(temperature=temperature,
                                     max_tokens=max_output_len,
                                     seed=seed,
                                     logprobs=logprobs)

    with vllm_runner(**org_args) as vllm_model:
        org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

    with vllm_runner(**sd_args) as vllm_model:
        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

    check_logprobs_close(outputs_0_lst=org_outputs,
                         outputs_1_lst=sd_outputs,
                         name_0="org",
                         name_1="sd")
123
124


125
def run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size: int,
        max_output_len: int,
        seed: Optional[int] = 0,
        temperature: float = 0.0,
        disable_seed: bool = False,
        ignore_eos: bool = True,
        ensure_all_accepted: bool = False,
        expected_acceptance_rate: Optional[float] = None):
    """Check that baseline and test LLMs generate equal outputs.

    Both models are run on the same ``batch_size`` prompts (cycled from
    ``PROMPTS``) with identical sampling parameters, and their outputs are
    compared with ``check_outputs_equal``.

    Args:
        vllm_runner: Callable that accepts the merged kwargs and returns a
            context manager whose value exposes ``generate`` and
            ``model.llm_engine``.
        common_llm_kwargs: Kwargs shared by both runs.
        per_test_common_llm_kwargs: Per-test kwargs shared by both runs.
        baseline_llm_kwargs: Kwargs applied only to the baseline ("org") run.
        test_llm_kwargs: Kwargs applied only to the test ("sd") run.
        batch_size: Number of prompts to generate for.
        max_output_len: Passed as ``max_tokens`` to ``SamplingParams``.
        seed: Sampling seed; ignored when ``disable_seed`` is true.
        temperature: Sampling temperature.
        disable_seed: If true, ``seed=None`` is passed to ``SamplingParams``.
        ignore_eos: Passed as ``ignore_eos`` to ``SamplingParams``.
        ensure_all_accepted: Enables acceptance-rate collection. The strict
            ``== 1.0`` check is currently disabled (see FIXME below).
        expected_acceptance_rate: If given, the measured draft acceptance
            rate must be at least this value minus a 1e-2 tolerance.

    Raises:
        AssertionError: If the acceptance rate is below the expected value.
    """
    # Later mappings win on key collisions, so the run-specific kwargs
    # override the common ones.
    org_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **baseline_llm_kwargs,
    }

    sd_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **test_llm_kwargs,
    }

    # Repeat PROMPTS as often as needed to reach batch_size entries.
    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]

    if disable_seed:
        seed = None

    sampling_params = SamplingParams(temperature=temperature,
                                     max_tokens=max_output_len,
                                     seed=seed,
                                     ignore_eos=ignore_eos)

    # Acceptance metrics are only read when one of the two checks is asked
    # for.
    check_acceptance = (ensure_all_accepted
                        or expected_acceptance_rate is not None)

    with vllm_runner(**org_args) as vllm_model:
        org_outputs = vllm_model.generate(prompts, sampling_params)

    with vllm_runner(**sd_args) as vllm_model:
        if check_acceptance:
            # Use a negative log interval so no metrics are missed
            # (presumably this makes the logger record on every step;
            # confirm against the stat logger implementation).
            stat_logger = vllm_model.model.llm_engine.stat_loggers[
                'prometheus']
            stat_logger.local_interval = -100

        sd_outputs = vllm_model.generate(prompts, sampling_params)

        if check_acceptance:
            acceptance_rate = (stat_logger.metrics.
                               gauge_spec_decode_draft_acceptance_rate.labels(
                                   **stat_logger.labels)._value.get())

            # FIXME: ci fails to log acceptance rate. It works locally.
            # Until that is resolved, ensure_all_accepted does not assert
            # that acceptance_rate == 1.0.

            if expected_acceptance_rate is not None:
                assert acceptance_rate >= expected_acceptance_rate - 1e-2

    check_outputs_equal(outputs_0_lst=org_outputs,
                        outputs_1_lst=sd_outputs,
                        name_0="org",
                        name_1="sd")


def run_equality_correctness_test_tp(model,
                                     common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs,
                                     test_llm_kwargs,
                                     batch_size: int,
                                     max_output_len: int,
                                     seed: int = 0,
                                     temperature: float = 0.0):
    """Helper method that compares the outputs of both the baseline LLM and
    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
    the same when temperature is zero.

    Each configuration is served through a ``RemoteOpenAIServer`` and queried
    via its completions client; the collected text, finish reasons and usage
    must be equal between the two runs.

    Args:
        model: Model name passed to the server and to the completion request.
        common_llm_kwargs: Arguments shared by both runs. Unlike the name
            suggests, these are concatenated with ``+``, so they must be
            sequences (presumably lists of server CLI arguments).
        per_test_common_llm_kwargs: Per-test arguments shared by both runs.
        baseline_llm_kwargs: Arguments applied only to the baseline run.
        test_llm_kwargs: Arguments applied only to the test run.
        batch_size: Number of prompts to send (cycled from ``PROMPTS``).
        max_output_len: Passed as ``max_tokens`` to the completion request.
        seed: Seed passed to the completion request.
        temperature: Sampling temperature.

    Raises:
        AssertionError: If the two runs return different results.
    """
    arg1 = common_llm_kwargs + per_test_common_llm_kwargs + baseline_llm_kwargs
    arg2 = common_llm_kwargs + per_test_common_llm_kwargs + test_llm_kwargs
    env1 = env2 = None

    max_wait_seconds = 240
    results = []

    # Repeat PROMPTS as often as needed to reach batch_size entries.
    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]

    for args, env in ((arg1, env1), (arg2, env2)):
        with RemoteOpenAIServer(model,
                                args,
                                env_dict=env,
                                max_wait_seconds=max_wait_seconds) as server:
            client = server.get_client()

            completion = client.completions.create(model=model,
                                                   prompt=prompts,
                                                   max_tokens=max_output_len,
                                                   seed=seed,
                                                   temperature=temperature)

            results.append({
                "test":
                "seeded_sampling",
                "text": [choice.text for choice in completion.choices],
                "finish_reason":
                [choice.finish_reason for choice in completion.choices],
                "usage":
                completion.usage,
            })

    # The first half of the results belongs to arg1, the second half to arg2.
    n = len(results) // 2
    arg1_results = results[:n]
    arg2_results = results[n:]
    for arg1_result, arg2_result in zip(arg1_results, arg2_results):
        assert arg1_result == arg2_result, (
            f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
            f"{arg1_result=} != {arg2_result=}")