test_logprobs.py 20.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

import itertools
5
from collections.abc import Generator
6

7
import os
8
9
10
11
import pytest
import torch

from tests.v1.sample.utils import (
12
    BatchLogprobsComposition, BatchLogprobsSpecType,
13
14
15
    assert_incr_detok_str_matches_non_incr_detok_str,
    compute_correct_cumulative_logprob, get_test_batch)
from vllm import SamplingParams
16
from vllm.config import LogprobsMode
17

18
from ...conftest import HfRunner, VllmRunner
19
from ...utils import models_path_prefix
20

21
MODEL = os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct")
DTYPE = "half"

# Short aliases for the batch logprobs compositions used to parametrize tests
NONE = BatchLogprobsComposition.NONE
SAMPLE = BatchLogprobsComposition.SAMPLE
PROMPT = BatchLogprobsComposition.PROMPT
SAMPLE_PROMPT = BatchLogprobsComposition.SAMPLE_PROMPT


@pytest.fixture(
    scope="module",
    # Parameterize APC
    params=[False, True])
def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
    """Module-scoped vLLM engine shared by the logprobs tests.

    Parametrized on `enable_prefix_caching` (automatic prefix caching,
    APC): each test using this fixture runs once with APC disabled and
    once with APC enabled.
    """
    with vllm_runner(
            MODEL,
            dtype=DTYPE,
            max_logprobs=7,
            # Very small number of batched tokens to ensure
            # that we test chunking.
            max_num_batched_tokens=16,
            max_num_seqs=16,
            max_model_len=128,
            enforce_eager=True,
            # Driven by the fixture params: False / True
            enable_prefix_caching=request.param,
            gpu_memory_utilization=0.4,  # up to 2 alive concurrently
    ) as vllm_model:
        yield vllm_model


@pytest.fixture(scope="module")
def hf_model(hf_runner) -> Generator[HfRunner, None, None]:
    """Module-scoped HuggingFace reference model for the logprobs tests."""
    with hf_runner(MODEL, dtype=DTYPE) as hf_model:
        yield hf_model


def _repeat_logprob_config(
    test_prompts,
61
62
    logprob_prompt_logprob_list: BatchLogprobsSpecType,
) -> BatchLogprobsSpecType:
63
    """Ensure each test prompt has a logprob config.
64

65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
    A logprob config specifies the optional (i.e.
    may-be-`None`) number of sample logprobs and
    the optional number of prompt logprobs.

    If more test prompts than logprob configs are
    provided, the provided logprob configs are
    tiled to match the number of test prompts.

    If fewer test prompts than logprob configs
    are provided, the list of logprob configs
    is truncated to match the number of test
    prompts.

    Otherwise, the list of logprob configs
    is returned as-is.

    Args:
      test_prompts: list of prompts under test
      logprob_prompt_logprob_list: list of
                            (optional num sample logprob,
                             optional num prompt logprob)
                             tuples
87

88
    Returns:
89
      list of
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
      (optional num sample logprob,optional num prompt logprob)
      tuples which is either identical to
      `logprob_prompt_logprob_list`, or else repeats
      `logprob_prompt_logprob_list` enough times to match the
      number of `test_prompts`, or else is truncated to match
      the number of `test_prompts`
    """
    num_test_prompts = len(test_prompts)
    # Make sure there is a logprobs configuration for each test prompt
    logprob_prompt_logprob_list = list(
        itertools.islice(itertools.cycle(logprob_prompt_logprob_list),
                         num_test_prompts))
    # Now the number of prompts should match the number of sample params combos
    assert num_test_prompts == len(logprob_prompt_logprob_list)
    return logprob_prompt_logprob_list


107
108
109
110
111
112
113
def _assert_topk_logprobs_structure(
    position_logprobs,
    token_id: int,
    num_top: int,
) -> None:
    """Check the structure of the logprobs dict returned for one position.

    Sample logprobs and prompt logprobs obey identical rules, so both
    validators in `_run_and_validate` share this check.

    Args:
      position_logprobs: mapping of token id -> logprob entry (each with a
        `.rank` attribute) for a single position
      token_id: the token actually generated (or present in the prompt)
        at this position
      num_top: number of top logprobs requested
    """
    assert position_logprobs is not None

    # Confirm that the actual token appears among the logprobs
    assert token_id in position_logprobs
    token_in_topk = position_logprobs[token_id].rank <= num_top

    # If the actual token is not included in the top K logprobs, it is
    # returned as one extra entry
    if token_in_topk and num_top != 0:
        assert len(position_logprobs) == num_top
    else:
        assert len(position_logprobs) == num_top + 1

    if num_top > 0:
        # We should have an entry for each of the topk ranks
        all_ranks = {lp.rank for lp in position_logprobs.values()}
        assert all(r in all_ranks for r in range(1, num_top + 1))


def _validate_sample_logprobs(vllm_result, hf_logprob, num_top_logprobs: int,
                              temperature: float, max_tokens: int) -> None:
    """Validate the sample logprobs of one request that requested them.

    NOTE(review): `hf_logprob[i][-1]` is assumed to hold the reference
    logprobs for the i-th generated token -- confirm against
    `HfRunner.generate_greedy_logprobs`.
    """
    completion = vllm_result.outputs[0]

    # Confirm that the structure of the sample logprobs is correct
    assert completion.logprobs is not None
    assert len(completion.logprobs) == max_tokens
    for logprobs, token_id in zip(completion.logprobs, completion.token_ids):
        _assert_topk_logprobs_structure(logprobs, token_id, num_top_logprobs)

    # Joining the first entry of each position's dict must reproduce the
    # output text
    output_string_from_most_likely_tokens = "".join(
        next(iter(top_logprobs.values())).decoded_token
        for top_logprobs in completion.logprobs)
    assert_incr_detok_str_matches_non_incr_detok_str(
        completion.text, output_string_from_most_likely_tokens,
        "The output text from the top logprob for each token "
        "position should be the same as the output text in the "
        "result.")

    # Compare vLLM sample logprobs to HF. When not greedy, the sampled
    # history diverges from HF's greedy one, so only position 0 is compared.
    for i, top_logprobs in enumerate(completion.logprobs):
        for token_id, sample_logprob in top_logprobs.items():
            if temperature == 0.0 or i == 0:
                torch.testing.assert_close(
                    sample_logprob.logprob,
                    hf_logprob[i][-1][token_id].item(),
                    atol=1e-2,
                    rtol=1e-2)
            assert isinstance(
                sample_logprob.decoded_token,
                str), ("The token should be decoded by the time it is"
                       " returned to the user.")

    # The sample logprobs are now known to be correct; check that
    # cumulative_logprob matches the independently computed sum.
    torch.testing.assert_close(
        completion.cumulative_logprob,
        compute_correct_cumulative_logprob(completion),
        atol=1e-6,
        rtol=1e-6)


def _validate_prompt_logprobs(vllm_result, hf_logprob,
                              num_top_prompt_logprobs: int) -> None:
    """Validate the prompt logprobs of one request that requested them.

    NOTE(review): `hf_logprob[0][i]` is assumed to hold the reference
    logprobs that score prompt position i + 1 -- confirm against
    `HfRunner.generate_greedy_logprobs`.
    """
    prompt_logprobs = vllm_result.prompt_logprobs
    prompt_token_ids = vllm_result.prompt_token_ids

    # Confirm that structure of prompt logprobs in result is correct
    assert prompt_logprobs is not None
    # - The first prompt logprob is always None
    assert prompt_logprobs[0] is None
    # - Prompt logprobs are returned for all indices in the prompt
    assert len(prompt_logprobs) == len(prompt_token_ids)
    for position_logprobs, prompt_token_id in zip(prompt_logprobs[1:],
                                                  prompt_token_ids[1:]):
        _assert_topk_logprobs_structure(position_logprobs, prompt_token_id,
                                        num_top_prompt_logprobs)

    # Compare prompt logprobs to HF; position 0 is always None, so start
    # from 1.
    for i, position_logprobs in enumerate(prompt_logprobs[1:]):
        for token_id, logprob in position_logprobs.items():
            torch.testing.assert_close(
                logprob.logprob,
                hf_logprob[0][i][token_id].item(),
                atol=2e-2,
                rtol=2e-2)


def _run_and_validate(
    vllm_model: VllmRunner,
    test_prompts: list[str],
    vllm_sampling_params: list[SamplingParams],
    hf_logprobs: list[list[torch.Tensor]],
    hf_outputs: list[tuple[list[int], str]],
    logprob_prompt_logprob_list: BatchLogprobsSpecType,
    temperature: float,
    max_tokens: int,
    do_apc: bool,
) -> None:
    """Generate with vLLM and validate (prompt) logprobs against HF.

    For each request this checks that:
      * the prompt token ids (and, when greedy, the completion token ids)
        match the HuggingFace reference output;
      * sample logprobs are returned exactly when requested, with the
        expected top-k structure, consistent with the output text and
        with HF, and summing to `cumulative_logprob`;
      * prompt logprobs are returned exactly when requested, with the
        expected top-k structure, consistent with HF.

    Args:
      vllm_model: vLLM model fixture
      test_prompts: prompts to generate from
      vllm_sampling_params: one `SamplingParams` per prompt
      hf_logprobs: HF reference logprobs, one entry per prompt
      hf_outputs: HF reference greedy outputs; element 0 of each entry is
        compared with the prompt+completion token ids
      logprob_prompt_logprob_list: per-prompt
        (num sample logprobs, num prompt logprobs) config, where `None`
        means "not requested"
      temperature: sampling temperature used for all requests
      max_tokens: number of tokens each request is expected to generate
      do_apc: whether prefix caching is enabled; not used by the checks,
        kept for call-site compatibility
    """
    vllm_results = vllm_model.llm.generate(
        test_prompts, sampling_params=vllm_sampling_params)

    for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip(
            vllm_results, hf_logprobs, hf_outputs,
            logprob_prompt_logprob_list):

        # Extract request-level (prompt)logprobs config
        num_top_logprobs, num_top_prompt_logprobs = logprob_prompt_logprob

        # Test whether sampled token output is consistent between vLLM and HF
        # vLLM prompt+completion should match HF output
        if temperature == 0.0:
            assert (vllm_result.prompt_token_ids +
                    vllm_result.outputs[0].token_ids == hf_output[0])
        else:
            # Sampled tokens won't match if not greedy
            assert (vllm_result.prompt_token_ids == hf_output[0]
                    [:len(vllm_result.prompt_token_ids)])

        if num_top_logprobs is not None:
            _validate_sample_logprobs(vllm_result, hf_logprob,
                                      num_top_logprobs, temperature,
                                      max_tokens)
        else:
            # Logprobs disabled for this request; should be None
            assert vllm_result.outputs[0].logprobs is None

        if num_top_prompt_logprobs is not None:
            _validate_prompt_logprobs(vllm_result, hf_logprob,
                                      num_top_prompt_logprobs)
        else:
            assert vllm_result.prompt_logprobs is None


@pytest.mark.parametrize("batch_logprobs_composition",
                         [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT])
@pytest.mark.parametrize("temperature", [0.0, 2.0])
def test_get_logprobs_and_prompt_logprobs(
        hf_model, vllm_model,
        batch_logprobs_composition: BatchLogprobsComposition,
        temperature: float, example_prompts: list[str],
        monkeypatch: pytest.MonkeyPatch) -> None:
    """Test V1 Engine logprobs & prompt logprobs

    Exercise a variety of combinations of `logprobs` and `prompt_logprobs`
    settings and validate that
    * The generated logprobs and prompt logprobs are consistent with the
      configuration settings, in terms of whether or not the logprobs
      (of either type) were requested and how many were requested
    * The generated logprobs are consistent with the generated tokens
    * The generated (prompt)logprobs are consistent with HuggingFace
      (prompt)logprobs, as a reference

    batch_logprobs_composition controls the logprobs configurations for
    requests in the batch under test.

    APC tests run two test iterations so that cache hits occur.

    To save time, only test one APC-enabled scenario
    (sample & prompt logprobs enabled, temperature>0.0).

    Args:
      hf_model: HuggingFace reference model fixture
      vllm_model: vLLM model fixture
      batch_logprobs_composition: logprobs configuration for test batch
      temperature: "temperature" sampling parameter
      example_prompts: example prompt fixture
      monkeypatch: pytest fixture used to set `VLLM_USE_V1`
    """
    with monkeypatch.context() as m:
        # NOTE(review): `vllm_model` is module-scoped and is constructed
        # before this env var is set, so it presumably has no effect on
        # that engine -- confirm whether it is still needed.
        m.setenv("VLLM_USE_V1", "1")
        do_apc = vllm_model.llm.llm_engine.cache_config.enable_prefix_caching
        if do_apc and (temperature < 2.0
                       or batch_logprobs_composition != SAMPLE_PROMPT):
            # Skip some test-cases to save time.
            pytest.skip("With APC, only SAMPLE_PROMPT at temperature 2.0 "
                        "is tested, to save time")
        test_prompts = example_prompts

        max_tokens = 5
        hf_outputs = hf_model.generate_greedy(
            test_prompts,
            max_tokens=max_tokens,
        )
        hf_logprobs = hf_model.generate_greedy_logprobs(
            test_prompts,
            max_tokens=max_tokens,
        )

        # Batch has mixed sample params
        # (different logprobs/prompt logprobs combos)
        logprob_prompt_logprob_list = get_test_batch(
            batch_logprobs_composition)

        # Ensure that each test prompt has a logprob config for testing
        logprob_prompt_logprob_list = _repeat_logprob_config(
            test_prompts, logprob_prompt_logprob_list)
        # Generate SamplingParams
        vllm_sampling_params = [
            SamplingParams(max_tokens=max_tokens,
                           logprobs=num_lp,
                           prompt_logprobs=num_plp,
                           temperature=temperature,
                           seed=1984)
            for num_lp, num_plp in logprob_prompt_logprob_list
        ]
        # With APC, the second pass re-submits the same prompts so that
        # prefix-cache hits occur.
        for _ in range(2 if do_apc else 1):
            _run_and_validate(
                vllm_model=vllm_model,
                test_prompts=test_prompts,
                vllm_sampling_params=vllm_sampling_params,
                hf_logprobs=hf_logprobs,
                hf_outputs=hf_outputs,
                logprob_prompt_logprob_list=logprob_prompt_logprob_list,
                temperature=temperature,
                max_tokens=max_tokens,
                do_apc=do_apc)


def test_max_logprobs(monkeypatch: pytest.MonkeyPatch):
    """vLLM v1 engine should fail a request with `logprobs > max_logprobs`

    Should also fail for `prompt_logprobs > max_logprobs`.
    APC should not matter as this test checks basic request validation.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        # Use the runner as a context manager so the engine is released
        # when the test ends, pass or fail.
        with VllmRunner(
                "facebook/opt-125m",
                max_logprobs=1,
                enable_prefix_caching=False,
                # 2 other llms alive during whole session
                gpu_memory_utilization=0.15,
                max_model_len=256) as runner:
            vllm_sampling_params = SamplingParams(logprobs=1)
            # should pass
            runner.generate(["Hello world"],
                            sampling_params=vllm_sampling_params)

            bad_sampling_params = SamplingParams(logprobs=2)
            with pytest.raises(ValueError):
                runner.generate(["Hello world"],
                                sampling_params=bad_sampling_params)

            # Prompt logprobs are subject to the same limit
            bad_prompt_sampling_params = SamplingParams(prompt_logprobs=2)
            with pytest.raises(ValueError):
                runner.generate(["Hello world"],
                                sampling_params=bad_prompt_sampling_params)


def test_none_logprobs(vllm_model, example_prompts,
                       monkeypatch: pytest.MonkeyPatch):
    """Engine should return `logprobs` and `prompt_logprobs` as `None`

    when neither is requested. `cumulative_logprob` should then also be
    `None`.

    Args:
      vllm_model: vLLM model fixture
      example_prompts: list of example prompts (test fixture)
      monkeypatch: pytest fixture used to set `VLLM_USE_V1`
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        max_tokens = 5

        sampling_params_logprobs_none = SamplingParams(
            max_tokens=max_tokens,
            logprobs=None,
            prompt_logprobs=None,
            temperature=0.0,
        )
        results_logprobs_none = vllm_model.llm.generate(
            example_prompts,
            sampling_params=sampling_params_logprobs_none,
        )

        for result in results_logprobs_none:
            completion = result.outputs[0]
            # Check sample logprobs are None
            assert completion.logprobs is None
            assert completion.cumulative_logprob is None
            # Check prompt logprobs are None
            assert result.prompt_logprobs is None


def test_zero_logprobs(vllm_model, example_prompts,
                       monkeypatch: pytest.MonkeyPatch):
    """Engine should return sampled token and prompt token logprobs

    when `logprobs=0` and `prompt_logprobs=0` are requested: one logprob
    dict per sampled token, one per prompt token, and a non-`None`
    `cumulative_logprob`.

    Args:
      vllm_model: vLLM model fixture
      example_prompts: list of example prompts (test fixture)
      monkeypatch: pytest fixture used to set `VLLM_USE_V1`
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        max_tokens = 5

        sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens,
                                                       logprobs=0,
                                                       prompt_logprobs=0,
                                                       temperature=0.0)
        results_logprobs_zero = vllm_model.llm.generate(
            example_prompts, sampling_params=sampling_params_logprobs_zero)

        for result in results_logprobs_zero:
            completion = result.outputs[0]
            # Check that there is one sample logprob dict for each
            # sample token
            assert completion.logprobs is not None
            assert len(completion.token_ids) == len(completion.logprobs)
            assert completion.cumulative_logprob is not None
            # Check that there is one prompt logprob dict for each
            # prompt token
            assert result.prompt_logprobs is not None
            assert len(result.prompt_token_ids) == len(result.prompt_logprobs)


def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
    """Engine should return all vocabulary logprobs

    With `max_logprobs=-1`, a request with `logprobs=-1` should get, for
    every generated position, one entry per vocabulary token.

    Args:
      example_prompts: list of example prompts (test fixture)
      monkeypatch: pytest fixture used to set `VLLM_USE_V1`
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        # Use the runner as a context manager so the engine is released
        # when the test ends, pass or fail.
        with VllmRunner(
                "facebook/opt-125m",
                max_logprobs=-1,
                enable_prefix_caching=False,
                # 2 other llms alive during whole session
                gpu_memory_utilization=0.15,
                max_model_len=256) as runner:
            sampling_params_logprobs_all = SamplingParams(max_tokens=5,
                                                          logprobs=-1)
            results_logprobs_all = runner.llm.generate(
                example_prompts, sampling_params=sampling_params_logprobs_all)
            vocab_size = (
                runner.llm.llm_engine.get_model_config().get_vocab_size())
            for result in results_logprobs_all:
                logprobs = result.outputs[0].logprobs
                assert logprobs is not None
                for position_logprobs in logprobs:
                    assert len(position_logprobs) == vocab_size


@pytest.mark.parametrize("logprobs_mode", list(LogprobsMode))
def test_logprobs_mode(logprobs_mode: LogprobsMode,
                       monkeypatch: pytest.MonkeyPatch):
    """Test the LLM engine with each `logprobs_mode`.

    For the logprobs modes, every returned value should be non-positive.
    For the logits modes, at least one returned value should be positive.
    """
    from vllm import LLM
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        llm = LLM(
            "facebook/opt-125m",
            max_logprobs=5,
            enable_prefix_caching=False,
            # 2 other llms alive during whole session
            gpu_memory_utilization=0.05,
            max_model_len=16,
            logprobs_mode=logprobs_mode)
        try:
            vllm_sampling_params = SamplingParams(logprobs=1)
            results = llm.generate(["Hello world"],
                                   sampling_params=vllm_sampling_params)

            total_token_with_logprobs = 0
            positive_values = 0
            for output in results[0].outputs:
                for logprobs in output.logprobs:
                    for logprob in logprobs.values():
                        if logprobs_mode in (LogprobsMode.RAW_LOGPROBS,
                                             LogprobsMode.PROCESSED_LOGPROBS):
                            assert logprob.logprob <= 0
                        if logprob.logprob > 0:
                            positive_values += 1
                        total_token_with_logprobs += 1
            assert total_token_with_logprobs >= len(results[0].outputs)
            if logprobs_mode in (LogprobsMode.RAW_LOGITS,
                                 LogprobsMode.PROCESSED_LOGITS):
                assert positive_values > 0
        finally:
            # Drop the engine reference even when an assertion fails
            del llm