conftest.py 11.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Sequence
5
from itertools import cycle
6
from typing import Optional, Union
7

8
import pytest
9
import torch
10

11
from vllm import LLM, SamplingParams
12
from vllm.distributed import cleanup_dist_env_and_memory
13
from vllm.model_executor.utils import set_random_seed
14
from vllm.sequence import PromptLogprobs, SampleLogprobs
15

16
17
18
from ...models.utils import (TokensTextLogprobs,
                             TokensTextLogprobsPromptLogprobs,
                             check_logprobs_close, check_outputs_equal)
19
from ...utils import RemoteOpenAIServer
20

21
22
23
24
25
26
27
28
29
30
# Prompt pool for the equality-correctness helpers in this file. The helpers
# cycle through this list to build a batch of the requested size, so a batch
# larger than the list repeats prompts.
PROMPTS = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
    "San Francisco is know for its",
    "Facebook was created in 2004 by",
    "Curious George is a",
    "Python 3.11 brings improvements to its",
]
31
32
33


@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    """Return a factory for a generator that yields a single test ``LLM``.

    The three kwargs dicts are merged with later ones overriding earlier
    ones: ``common_llm_kwargs``, then ``per_test_common_llm_kwargs``, then
    ``test_llm_kwargs``. The ``LLM`` is only constructed once the returned
    generator is iterated.

    Args:
        common_llm_kwargs: Keyword arguments shared by all tests.
        per_test_common_llm_kwargs: Keyword arguments shared within a test.
        test_llm_kwargs: Keyword arguments specific to the test LLM.
        seed: If not ``None``, passed to ``set_random_seed`` after the LLM
            has been constructed.

    Returns:
        A zero-argument callable producing a generator that yields the LLM
        once and tears it down afterwards.
    """

    def generate():
        kwargs = {
            **common_llm_kwargs,
            **per_test_common_llm_kwargs,
            **test_llm_kwargs,
        }

        llm = LLM(**kwargs)

        try:
            if seed is not None:
                set_random_seed(seed)

            yield llm
        finally:
            # Run the teardown even when the consumer raises or stops
            # iterating early (the generator is then closed at the yield),
            # so a failing test does not leave the LLM alive for the next.
            del llm
            cleanup_dist_env_and_memory()

    return generate
55
56


57
58
def maybe_assert_ngram_worker(llm):
    """Assert that the proposer worker is an ``NGramWorker`` for ngram runs.

    Does nothing when ``llm`` has no speculative config or when its method
    is not ``"ngram"``.

    Args:
        llm: Object exposing ``llm_engine.speculative_config`` and, for
            ngram runs, ``llm_engine.model_executor.driver_worker``.

    Raises:
        AssertionError: If ngram speculation is configured but the driver
            worker's ``proposer_worker`` is not an ``NGramWorker``.
    """
    # Verify the proposer worker is ngram if ngram is specified.
    if (llm.llm_engine.speculative_config is not None
            and llm.llm_engine.speculative_config.method == "ngram"):
        # Imported lazily so non-ngram callers never need this module.
        from vllm.spec_decode.ngram_worker import NGramWorker
        assert isinstance(
            llm.llm_engine.model_executor.driver_worker.proposer_worker,
            NGramWorker), "ngram speculation requires an NGramWorker proposer"


67
68
def get_output_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> tuple[list[str], list[list[int]], float]:
    """Generate completions with every LLM yielded by ``llm_generator``.

    Args:
        llm_generator: Zero-argument callable returning an iterable of LLMs.
        prompts: Prompts passed to ``llm.generate``.
        sampling_params: Sampling parameters passed to ``llm.generate``.

    Returns:
        A ``(tokens, token_ids, acceptance_rate)`` tuple, where ``tokens``
        holds the text and ``token_ids`` the token ids of the first output
        of each request. ``acceptance_rate`` is read from the "prometheus"
        stat logger when the engine exposes a non-empty ``stat_loggers``,
        and stays ``-1.0`` otherwise. If more than one LLM is yielded, the
        values from the last one are returned.
    """
    tokens: list[str] = []
    token_ids: list[list[int]] = []
    acceptance_rate: float = -1.0
    for llm in llm_generator():
        maybe_assert_ngram_worker(llm)

        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)

        token_ids = [output.outputs[0].token_ids for output in outputs]
        tokens = [output.outputs[0].text for output in outputs]

        # Fetch acceptance rate if logging is enabled.
        if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
            stat_logger = stat_loggers["prometheus"]
            acceptance_rate = (stat_logger.metrics.
                               gauge_spec_decode_draft_acceptance_rate.labels(
                                   **stat_logger.labels)._value.get())
        # Drop this frame's reference before the generator resumes and
        # runs its own teardown.
        del llm

    return tokens, token_ids, acceptance_rate
90
91


92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
def check_logprobs_correctness(
    spec_outputs: Sequence[Union[TokensTextLogprobs,
                                 TokensTextLogprobsPromptLogprobs]],
    baseline_outputs: Sequence[Union[TokensTextLogprobs,
                                     TokensTextLogprobsPromptLogprobs]],
    disable_logprobs: bool = False,
):
    """Compare sampled and prompt logprobs between baseline and spec decoding.

    Args:
        spec_outputs: Outputs produced with speculative decoding. Index 2 of
            each entry holds the sample logprobs and, for 4-element entries,
            index 3 holds the prompt logprobs.
        baseline_outputs: Outputs of the baseline run, in the same layout.
        disable_logprobs: When ``False``, the comparison is delegated to
            ``check_logprobs_close``. When ``True``, each speculative
            logprob is only checked against the baseline via
            ``_check_logprobs_when_output_disabled``.

    Returns:
        Whatever ``check_logprobs_close`` returns when ``disable_logprobs``
        is ``False``; ``None`` otherwise.

    Raises:
        AssertionError: If the outputs do not match.
    """
    if not disable_logprobs:
        return check_logprobs_close(
            outputs_0_lst=baseline_outputs,
            outputs_1_lst=spec_outputs,
            name_0="org",
            name_1="sd",
        )

    # Check correctness when disable_logprobs == True.
    # zip() stops at the shorter input, so without this check a run that
    # produced fewer outputs would pass with the extras never compared.
    assert len(spec_outputs) == len(baseline_outputs), (
        f"Output count mismatch: {len(spec_outputs)} speculative vs "
        f"{len(baseline_outputs)} baseline")

    for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
        # Check generated token logprobs.
        spec_logprobs = spec_output[2]
        baseline_logprobs = baseline_output[2]
        _check_logprobs_when_output_disabled(spec_logprobs,
                                             baseline_logprobs,
                                             is_prompt_logprobs=False)

        # Check prompt logprobs too, if they exist
        if len(baseline_output) == 4:
            assert len(spec_output) == 4
            spec_prompt_logprobs = spec_output[3]
            baseline_prompt_logprobs = baseline_output[3]
            _check_logprobs_when_output_disabled(spec_prompt_logprobs,
                                                 baseline_prompt_logprobs,
                                                 is_prompt_logprobs=True)


def _check_logprobs_when_output_disabled(
    spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
    baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
    is_prompt_logprobs: bool = False,
):
    """Validate speculative logprobs produced with logprob output disabled.

    Each speculative position must hold exactly one entry whose rank is
    ``-1`` and whose logprob is ``0.0``, and whose token id appears among
    the baseline logprobs for the same position.

    Args:
        spec_logprobs: Per-position logprobs from the speculative run.
        baseline_logprobs: Per-position logprobs from the baseline run.
        is_prompt_logprobs: Whether the inputs are prompt logprobs. Only
            then may the whole value, or position 0, be ``None``.

    Raises:
        AssertionError: If any of the above expectations is violated.
    """
    # Prompt logprobs are optional
    if is_prompt_logprobs and baseline_logprobs is None:
        assert spec_logprobs is None
        return

    assert spec_logprobs is not None
    assert baseline_logprobs is not None
    assert len(spec_logprobs) == len(baseline_logprobs)

    # For each generated position of the sequence.
    for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
            zip(spec_logprobs, baseline_logprobs)):

        # First prompt logprob is expected to be None
        if is_prompt_logprobs and baseline_pos_logprobs is None:
            assert spec_pos_logprobs is None
            assert pos == 0
            continue

        assert spec_pos_logprobs is not None
        assert baseline_pos_logprobs is not None

        # When disabled, the 1 logprob is returned with dummy values for the
        # score and rank, but the token id should match the baseline model
        assert len(spec_pos_logprobs) == 1
        (spec_pos_logprob_token_id,
         spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
        assert spec_pos_logprob.rank == -1
        assert spec_pos_logprob.logprob == 0.0
        # Convert a tensor key to a plain Python value before the
        # membership test against the baseline keys.
        if isinstance(spec_pos_logprob_token_id, torch.Tensor):
            spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
        assert spec_pos_logprob_token_id in baseline_pos_logprobs
165
166


167
def run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size: int,
        max_output_len: int,
        seed: Optional[int] = 0,
        temperature: float = 0.0,
        disable_seed: bool = False,
        ignore_eos: bool = True,
        ensure_all_accepted: bool = False,
        expected_acceptance_rate: Optional[float] = None,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        disable_logprobs: bool = False):
    """Assert that a baseline LLM and a test LLM produce equal outputs.

    Both configurations are run on the same prompts with the same sampling
    parameters. The first two entries of every output are compared with
    ``check_outputs_equal``; logprobs are compared separately, and only
    when ``logprobs`` or ``prompt_logprobs`` is requested.

    Args:
        vllm_runner: Callable returning a context manager whose value offers
            ``generate_w_logprobs``.
        common_llm_kwargs: Kwargs shared by both configurations.
        per_test_common_llm_kwargs: Per-test kwargs shared by both.
        baseline_llm_kwargs: Kwargs applied only to the baseline run.
        test_llm_kwargs: Kwargs applied only to the test run.
        batch_size: Number of prompts, taken by cycling through ``PROMPTS``.
        max_output_len: ``max_tokens`` for sampling.
        seed: Sampling seed; ignored when ``disable_seed`` is set.
        temperature: Sampling temperature.
        disable_seed: If true, sample with ``seed=None``.
        ignore_eos: Forwarded to ``SamplingParams``.
        ensure_all_accepted: If true, the acceptance rate is fetched from
            the stat logger; the equality assertion on it is currently
            disabled (see the FIXME in the body).
        expected_acceptance_rate: If given, the measured acceptance rate
            must be at least this value minus ``1e-2``.
        logprobs: Number of sample logprobs to request, if any.
        prompt_logprobs: Number of prompt logprobs to request, if any.
        disable_logprobs: Forwarded to ``check_logprobs_correctness``.

    Raises:
        AssertionError: If outputs, logprobs or acceptance rate do not meet
            the expectations above.
    """
    org_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **baseline_llm_kwargs,
    }

    sd_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **test_llm_kwargs,
    }

    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]

    if disable_seed:
        seed = None

    sampling_params = SamplingParams(temperature=temperature,
                                     max_tokens=max_output_len,
                                     seed=seed,
                                     ignore_eos=ignore_eos,
                                     logprobs=logprobs,
                                     prompt_logprobs=prompt_logprobs)

    with vllm_runner(**org_args) as vllm_model:
        org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

    # The acceptance-rate metric is only read when a caller asks about it.
    needs_acceptance_rate = (ensure_all_accepted
                             or expected_acceptance_rate is not None)

    with vllm_runner(**sd_args) as vllm_model:
        if needs_acceptance_rate:
            # Force the log interval negative to catch all metrics.
            stat_logger = vllm_model.model.llm_engine.stat_loggers[
                'prometheus']
            stat_logger.local_interval = -100

        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

        if needs_acceptance_rate:
            acceptance_rate = (stat_logger.metrics.
                               gauge_spec_decode_draft_acceptance_rate.labels(
                                   **stat_logger.labels)._value.get())

            # FIXME: ci fails to log acceptance rate (it works locally), so
            # `ensure_all_accepted` cannot yet assert acceptance_rate == 1.0.

            if expected_acceptance_rate is not None:
                assert acceptance_rate >= expected_acceptance_rate - 1e-2

    # Only pass token entries, not the logprobs
    check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
                        outputs_1_lst=[out[0:2] for out in sd_outputs],
                        name_0="org",
                        name_1="sd")

    # Check logprobs if requested
    if logprobs is not None or prompt_logprobs is not None:
        check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)

245
246
247
248
249
250
251
252
253

def run_equality_correctness_test_tp(model,
                                     common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs,
                                     test_llm_kwargs,
                                     batch_size: int,
                                     max_output_len: int,
                                     seed: int = 0,
                                     temperature: float = 0.0,
                                     logprobs: Optional[int] = None):
    """Helper method that compares the outputs of both the baseline LLM and
    the test LLM. It asserts greedy equality, i.e. that the outputs are exactly
    the same when temperature is zero.

    Each configuration is served through a ``RemoteOpenAIServer`` and queried
    via the completions endpoint of its client.

    Args:
        model: Model name passed to the server and to the completion request.
        common_llm_kwargs: Server arguments shared by both configurations.
            The ``*_llm_kwargs`` values are joined with ``+``, so they are
            presumably lists of command-line arguments.
        per_test_common_llm_kwargs: Per-test server arguments shared by both.
        baseline_llm_kwargs: Server arguments for the baseline only.
        test_llm_kwargs: Server arguments for the test configuration only.
        batch_size: Number of prompts, taken by cycling through ``PROMPTS``.
        max_output_len: ``max_tokens`` of the completion request.
        seed: Seed of the completion request.
        temperature: Sampling temperature of the completion request.
        logprobs: Number of logprobs to request, or ``None`` for none.

    Raises:
        AssertionError: If text, finish reasons or usage differ, or if
            logprobs were requested and the logprob tokens differ.
    """
    arg1 = common_llm_kwargs + per_test_common_llm_kwargs + baseline_llm_kwargs
    arg2 = common_llm_kwargs + per_test_common_llm_kwargs + test_llm_kwargs
    env1 = env2 = None

    max_wait_seconds = 240
    results = []

    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
    for args, env in ((arg1, env1), (arg2, env2)):
        with RemoteOpenAIServer(model,
                                args,
                                env_dict=env,
                                max_wait_seconds=max_wait_seconds) as server:
            client = server.get_client()

            completion = client.completions.create(model=model,
                                                   prompt=prompts,
                                                   max_tokens=max_output_len,
                                                   seed=seed,
                                                   temperature=temperature,
                                                   logprobs=logprobs)

            results.append({
                "test":
                "seeded_sampling",
                "text": [choice.text for choice in completion.choices],
                "logprobs": [choice.logprobs for choice in completion.choices],
                "finish_reason":
                [choice.finish_reason for choice in completion.choices],
                "usage":
                completion.usage,
            })

    # The first half of `results` is the baseline, the second half the test.
    n = len(results) // 2
    arg1_results = results[:n]
    arg2_results = results[n:]

    # Separate logprobs to avoid asserting exact equality.
    arg1_logprobs = [r.pop("logprobs") for r in arg1_results]
    arg2_logprobs = [r.pop("logprobs") for r in arg2_results]

    for arg1_result, arg2_result in zip(arg1_results, arg2_results):
        assert arg1_result == arg2_result, (
            f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
            f"{arg1_result=} != {arg2_result=}")

    # Test against None rather than truthiness so that an explicit
    # `logprobs=0` request is still verified.
    if logprobs is not None:
        for logs1, logs2 in zip(arg1_logprobs, arg2_logprobs):
            for l1, l2 in zip(logs1, logs2):
                assert l1.tokens == l2.tokens