"vllm/vscode:/vscode.git/clone" did not exist on "8001970ce77dffa9ee73abae520b91d479f7cd17"
test_logprobs.py 19 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

import itertools
5
from collections.abc import Generator
6
from typing import get_args
7
8
9
10
11

import pytest
import torch

from tests.v1.sample.utils import (
12
13
    BatchLogprobsComposition,
    BatchLogprobsSpecType,
14
    assert_incr_detok_str_matches_non_incr_detok_str,
15
16
17
    compute_correct_cumulative_logprob,
    get_test_batch,
)
18
from vllm import SamplingParams
19
from vllm.config import LogprobsMode
20

21
from ...conftest import HfRunner, VllmRunner
22

23
# Model under test and the dtype used for it.
MODEL = "meta-llama/Llama-3.2-1B-Instruct"
DTYPE = "half"

# Short aliases for the batch logprobs compositions used in this module.
NONE, SAMPLE, PROMPT, SAMPLE_PROMPT = (
    BatchLogprobsComposition.NONE,
    BatchLogprobsComposition.SAMPLE,
    BatchLogprobsComposition.PROMPT,
    BatchLogprobsComposition.SAMPLE_PROMPT,
)

31
32
33
34

@pytest.fixture(
    scope="module",
    # Parameterize automatic prefix caching (APC): disabled, then enabled.
    params=[False, True],
)
def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
    """Module-scoped vLLM runner for `MODEL`, parameterized over APC.

    `request.param` is forwarded as `enable_prefix_caching`, so every test
    using this fixture runs once with APC disabled and once with it enabled.

    Args:
      vllm_runner: fixture returning a context-manager runner factory
      request: pytest request carrying the APC flag in `request.param`

    Yields:
      the runner entered via `vllm_runner(...)`
    """
    with vllm_runner(
        MODEL,
        dtype=DTYPE,
        max_logprobs=7,
        # Very small number of batched tokens to ensure
        # that we test chunking.
        max_num_batched_tokens=16,
        max_num_seqs=16,
        max_model_len=128,
        enforce_eager=True,
        enable_prefix_caching=request.param,
        gpu_memory_utilization=0.4,  # up to 2 alive concurrently
    ) as vllm_model:
        yield vllm_model


@pytest.fixture(scope="module")
def hf_model(hf_runner) -> Generator[HfRunner, None, None]:
    """Module-scoped HuggingFace reference runner for `MODEL` in `DTYPE`.

    Yields:
      the runner entered via `hf_runner(MODEL, dtype=DTYPE)`
    """
    runner_ctx = hf_runner(MODEL, dtype=DTYPE)
    with runner_ctx as reference_model:
        yield reference_model


def _repeat_logprob_config(
    test_prompts,
64
65
    logprob_prompt_logprob_list: BatchLogprobsSpecType,
) -> BatchLogprobsSpecType:
66
    """Ensure each test prompt has a logprob config.
67

68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
    A logprob config specifies the optional (i.e.
    may-be-`None`) number of sample logprobs and
    the optional number of prompt logprobs.

    If more test prompts than logprob configs are
    provided, the provided logprob configs are
    tiled to match the number of test prompts.

    If fewer test prompts than logprob configs
    are provided, the list of logprob configs
    is truncated to match the number of test
    prompts.

    Otherwise, the list of logprob configs
    is returned as-is.

    Args:
      test_prompts: list of prompts under test
      logprob_prompt_logprob_list: list of
                            (optional num sample logprob,
                             optional num prompt logprob)
                             tuples
90

91
    Returns:
92
      list of
93
94
95
96
97
98
99
100
101
102
      (optional num sample logprob,optional num prompt logprob)
      tuples which is either identical to
      `logprob_prompt_logprob_list`, or else repeats
      `logprob_prompt_logprob_list` enough times to match the
      number of `test_prompts`, or else is truncated to match
      the number of `test_prompts`
    """
    num_test_prompts = len(test_prompts)
    # Make sure there is a logprobs configuration for each test prompt
    logprob_prompt_logprob_list = list(
103
104
        itertools.islice(itertools.cycle(logprob_prompt_logprob_list), num_test_prompts)
    )
105
106
107
108
109
    # Now the number of prompts should match the number of sample params combos
    assert num_test_prompts == len(logprob_prompt_logprob_list)
    return logprob_prompt_logprob_list


110
111
112
113
114
115
116
def _run_and_validate(
    vllm_model: VllmRunner,
    test_prompts: list[str],
    vllm_sampling_params: SamplingParams,
    hf_logprobs: list[list[torch.Tensor]],
    hf_outputs: list[tuple[list[int], str]],
    logprob_prompt_logprob_list: BatchLogprobsSpecType,
117
    temperature: float,
118
119
    max_tokens: int,
    do_apc: bool,
120
) -> None:
121
    vllm_results = vllm_model.llm.generate(
122
123
        test_prompts, sampling_params=vllm_sampling_params
    )
124
125

    for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip(
126
127
        vllm_results, hf_logprobs, hf_outputs, logprob_prompt_logprob_list
    ):
128
129
130
131
132
133
        # Extract request-level (prompt)logprobs config
        num_top_logprobs, num_top_prompt_logprobs = logprob_prompt_logprob

        # Test whether sampled token output is consistent between vLLM and HF
        # vLLM prompt+completion should match HF output
        if temperature == 0.0:
134
135
136
137
            assert (
                vllm_result.prompt_token_ids + vllm_result.outputs[0].token_ids
                == hf_output[0]
            )
138
139
        else:
            # Sampled tokens won't match if not greedy
140
141
142
143
            assert (
                vllm_result.prompt_token_ids
                == hf_output[0][: len(vllm_result.prompt_token_ids)]
            )
144
145
146
147
148
149
150
151

        # Validate sample logprobs
        if num_top_logprobs is not None:
            assert num_top_logprobs is not None
            # Confirm that the structure of the sample logprobs in the result is
            # correct
            assert vllm_result.outputs[0].logprobs is not None
            assert len(vllm_result.outputs[0].logprobs) == max_tokens
152
153
154
            for logprobs, token_id in zip(
                vllm_result.outputs[0].logprobs, vllm_result.outputs[0].token_ids
            ):
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
                assert logprobs is not None

                # Confirm that the output token appears among the logprobs
                assert token_id in logprobs
                token_in_topk = logprobs[token_id].rank <= num_top_logprobs

                # If the output token is not included in the top K
                # logprob, it can return 1 more data
                if token_in_topk and num_top_logprobs != 0:
                    assert len(logprobs) == num_top_logprobs
                else:
                    assert len(logprobs) == num_top_logprobs + 1

                if num_top_logprobs > 0:
                    # We should have an entry for each of the topk ranks
                    all_ranks = {lp.rank for lp in logprobs.values()}
171
                    assert all(r in all_ranks for r in range(1, num_top_logprobs + 1))
172
173

            output_text = vllm_result.outputs[0].text
174
            output_string_from_most_likely_tokens_lst: list[str] = []
175
176
177
            for top_logprobs in vllm_result.outputs[0].logprobs:
                top_logprob = next(iter(top_logprobs.values()))
                output_string_from_most_likely_tokens_lst.append(
178
179
                    top_logprob.decoded_token
                )
180
181

            output_string_from_most_likely_tokens = "".join(
182
183
                output_string_from_most_likely_tokens_lst
            )
184
            assert_incr_detok_str_matches_non_incr_detok_str(
185
186
                output_text,
                output_string_from_most_likely_tokens,
187
188
                "The output text from the top logprob for each token "
                "position should be the same as the output text in the "
189
190
                "result.",
            )
191
192
193
194
195
196
197
198
199
200
201

            # Compare vLLM sample logprobs to HF
            vllm_sample_logprobs = vllm_result.outputs[0].logprobs
            for i, top_logprobs in enumerate(vllm_sample_logprobs):
                for token_id, sample_logprob in top_logprobs.items():
                    if temperature == 0.0 or i == 0:
                        logprob = sample_logprob.logprob
                        torch.testing.assert_close(
                            logprob,
                            hf_logprob[i][-1][token_id].item(),
                            atol=1e-2,
202
203
204
205
206
207
                            rtol=1e-2,
                        )
                    assert isinstance(sample_logprob.decoded_token, str), (
                        "The token should be decoded by the time it is"
                        " returned to the user."
                    )
208
209
210
211
212
213
214
215
216

            # At this point we know the sample logprobs are correct for this
            # request. Validate that cumulative_logprob is actually the sum.
            # For each request, assert that the returned cumulative logprob
            # matches the correct value, which is computed below.
            torch.testing.assert_close(
                vllm_result.outputs[0].cumulative_logprob,
                compute_correct_cumulative_logprob(vllm_result.outputs[0]),
                atol=1e-6,
217
218
                rtol=1e-6,
            )
219
220
221
222
223
224
225
226
227
228
229
230
        else:
            # Logprobs disabled for this request; should be None
            assert vllm_result.outputs[0].logprobs is None

        # Validate prompt logprobs
        if num_top_prompt_logprobs is not None:
            # Confirm that structure of prompt logprobs in result is correct
            assert vllm_result.prompt_logprobs is not None
            # - The first prompt logprob is always None
            assert vllm_result.prompt_logprobs[0] is None
            # - Prompt logprobs are returned for all indices in
            #   the prompt
231
            assert len(vllm_result.prompt_logprobs) == len(vllm_result.prompt_token_ids)
232
            for prompt_logprobs, prompt_token_id in zip(
233
234
                vllm_result.prompt_logprobs[1:], vllm_result.prompt_token_ids[1:]
            ):
235
236
237
238
                assert prompt_logprobs is not None

                # Confirm that the prompt token appears among the logprobs
                assert prompt_token_id in prompt_logprobs
239
240
241
                token_in_topk = (
                    prompt_logprobs[prompt_token_id].rank <= num_top_prompt_logprobs
                )
242
243
244
245
246
247
248
249
250
251
252

                # If the prompt token is not included in the top K
                # logprob, it can return 1 more data
                if token_in_topk and num_top_prompt_logprobs != 0:
                    assert len(prompt_logprobs) == num_top_prompt_logprobs
                else:
                    assert len(prompt_logprobs) == num_top_prompt_logprobs + 1

                if num_top_prompt_logprobs > 0:
                    # We should have an entry for each of the topk ranks
                    all_ranks = {lp.rank for lp in prompt_logprobs.values()}
253
254
255
                    assert all(
                        r in all_ranks for r in range(1, num_top_prompt_logprobs + 1)
                    )
256
257
258
259
260
261
262
263
264
265
266

            # Compare prompt logprobs to HF
            # The first prompt logprob is always None, so we compare it from
            # 1:.
            vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
            for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
                for token_id, logprob in vllm_prompt_logprob_dict.items():
                    torch.testing.assert_close(
                        logprob.logprob,
                        hf_logprob[0][i][token_id].item(),
                        atol=2e-2,
267
268
                        rtol=2e-2,
                    )
269
270
271
272
        else:
            assert vllm_result.prompt_logprobs is None


273
274
275
@pytest.mark.parametrize(
    "batch_logprobs_composition", [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT]
)
@pytest.mark.parametrize("temperature", [0.0, 2.0])
def test_get_logprobs_and_prompt_logprobs(
    hf_model,
    vllm_model,
    batch_logprobs_composition: BatchLogprobsComposition,
    temperature: float,
    example_prompts: list[str],
) -> None:
    """Test V1 Engine logprobs & prompt logprobs

    Exercise a variety of combinations of `logprobs` and `prompt_logprobs`
    settings and validate that
    * The generated logprobs and prompt logprobs are consistent with the
      configuration settings, in terms of whether or not the logprobs
      (of either type) were requested and how many were requested
    * The generated logprobs are consistent with the generated tokens
    * The generated (prompt)logprobs are consistent with HuggingFace
      (prompt)logprobs, as a reference

    batch_logprobs_composition controls the logprobs configurations for
    requests in the batch under test.

    APC tests run two test iterations so that cache hits occur.

    To save time, only test one APC-enabled scenario
    (sample & prompt logprobs enabled, temperature>0.0).

    Args:
      hf_model: HuggingFace reference model fixture
      vllm_model: vLLM model fixture
      batch_logprobs_composition: logprobs configuration for test batch
      temperature: "temperature" sampling parameter
      example_prompts: example prompt fixture
    """
    do_apc = vllm_model.llm.llm_engine.cache_config.enable_prefix_caching
    if do_apc and (temperature < 2.0 or batch_logprobs_composition != SAMPLE_PROMPT):
        # Skip some test-cases to save time; say why in the pytest report.
        pytest.skip(
            "APC is only exercised with SAMPLE_PROMPT at temperature 2.0 "
            "to save time"
        )
    test_prompts = example_prompts

    max_tokens = 5
    hf_outputs = hf_model.generate_greedy(
        test_prompts,
        max_tokens=max_tokens,
    )
    hf_logprobs = hf_model.generate_greedy_logprobs(
        test_prompts,
        max_tokens=max_tokens,
    )

    # Batch has mixed sample params
    # (different logprobs/prompt logprobs combos)
    logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)

    # Ensure that each test prompt has a logprob config for testing
    logprob_prompt_logprob_list = _repeat_logprob_config(
        test_prompts, logprob_prompt_logprob_list
    )
    # Generate SamplingParams
    vllm_sampling_params = [
        SamplingParams(
            max_tokens=max_tokens,
            logprobs=num_lp,
            prompt_logprobs=num_plp,
            temperature=temperature,
            seed=1984,
        )
        for num_lp, num_plp in logprob_prompt_logprob_list
    ]
    # With APC, the second pass over the same prompts hits the prefix cache.
    for _ in range(2 if do_apc else 1):
        _run_and_validate(
            vllm_model=vllm_model,
            test_prompts=test_prompts,
            vllm_sampling_params=vllm_sampling_params,
            hf_logprobs=hf_logprobs,
            hf_outputs=hf_outputs,
            logprob_prompt_logprob_list=logprob_prompt_logprob_list,
            temperature=temperature,
            max_tokens=max_tokens,
            do_apc=do_apc,
        )


359
def test_max_logprobs():
    """vLLM v1 engine should fail a request with `logprobs > max_logprobs`
    Should also fail for `prompt_logprobs > max_logprobs`
    APC should not matter as this test checks basic request validation.
    """
    runner = VllmRunner(
        "facebook/opt-125m",
        max_logprobs=1,
        enable_prefix_caching=False,
        # 2 other llms alive during whole session
        gpu_memory_utilization=0.15,
        max_model_len=256,
    )
    vllm_sampling_params = SamplingParams(logprobs=1)
    # should pass
    runner.generate(["Hello world"], sampling_params=vllm_sampling_params)

    # Sample logprobs above max_logprobs must be rejected
    bad_sampling_params = SamplingParams(logprobs=2)
    with pytest.raises(ValueError):
        runner.generate(["Hello world"], sampling_params=bad_sampling_params)

    # Prompt logprobs above max_logprobs must be rejected as well
    bad_prompt_sampling_params = SamplingParams(prompt_logprobs=2)
    with pytest.raises(ValueError):
        runner.generate(["Hello world"], sampling_params=bad_prompt_sampling_params)
379
380


381
def test_none_logprobs(vllm_model, example_prompts):
    """Engine should return `logprobs` and `prompt_logprobs` as `None`

    `cumulative_logprob` is expected to be `None` as well when no sample
    logprobs are requested.

    Args:
      vllm_model: vLLM model fixture
      example_prompts: list of example prompts (test fixture)
    """
    max_tokens = 5

    sampling_params_logprobs_none = SamplingParams(
        max_tokens=max_tokens,
        logprobs=None,
        prompt_logprobs=None,
        temperature=0.0,
    )
    results_logprobs_none = vllm_model.llm.generate(
        example_prompts,
        sampling_params=sampling_params_logprobs_none,
    )

    for result in results_logprobs_none:
        # Check sample logprobs are None
        assert result.outputs[0].logprobs is None
        assert result.outputs[0].cumulative_logprob is None
        # Check prompt logprobs are None
        assert result.prompt_logprobs is None
407
408


409
def test_zero_logprobs(vllm_model, example_prompts):
    """Engine should return sampled token and prompt token logprobs

    With `logprobs=0` / `prompt_logprobs=0` no extra top-k candidates are
    requested, so each position's dict should hold only the logprob of the
    token that was sampled (or that appears in the prompt).

    Args:
      vllm_model: vLLM model fixture
      example_prompts: list of example prompts (test fixture)
    """
    max_tokens = 5

    sampling_params_logprobs_zero = SamplingParams(
        max_tokens=max_tokens, logprobs=0, prompt_logprobs=0, temperature=0.0
    )
    results_logprobs_zero = vllm_model.llm.generate(
        example_prompts, sampling_params=sampling_params_logprobs_zero
    )

    for result in results_logprobs_zero:
        logprobs = result.outputs[0].logprobs
        prompt_logprobs = result.prompt_logprobs
        sampled_token_ids = result.outputs[0].token_ids
        prompt_token_ids = result.prompt_token_ids

        # Check that there is one sample logprob dict for each
        # sample token, holding exactly that token's logprob
        assert logprobs is not None
        assert len(sampled_token_ids) == len(logprobs)
        for position_logprobs, token_id in zip(logprobs, sampled_token_ids):
            assert token_id in position_logprobs
            assert len(position_logprobs) == 1
        assert result.outputs[0].cumulative_logprob is not None

        # Check that there is one prompt logprob dict for each
        # prompt token; the first prompt position is always None
        assert prompt_logprobs is not None
        assert len(prompt_token_ids) == len(prompt_logprobs)
        assert prompt_logprobs[0] is None
        for position_logprobs, token_id in zip(
            prompt_logprobs[1:], prompt_token_ids[1:]
        ):
            assert token_id in position_logprobs
            assert len(position_logprobs) == 1


def test_all_logprobs(example_prompts):
    """Engine should return all vocabulary logprobs and prompt logprobs

    With `max_logprobs=-1`, requesting `-1` for both `logprobs` and
    `prompt_logprobs` should yield a full-vocabulary dict at every sampled
    position and at every prompt position except the first (always None).

    Args:
      example_prompts: list of example prompts (test fixture)
    """
    runner = VllmRunner(
        "facebook/opt-125m",
        max_logprobs=-1,
        enable_prefix_caching=False,
        # 2 other llms alive during whole session
        gpu_memory_utilization=0.15,
        max_model_len=256,
    )

    sampling_params_logprobs_all = SamplingParams(
        max_tokens=5, logprobs=-1, prompt_logprobs=-1
    )
    results_logprobs_all = runner.llm.generate(
        example_prompts, sampling_params=sampling_params_logprobs_all
    )
    vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()

    for result in results_logprobs_all:
        logprobs = result.outputs[0].logprobs
        prompt_logprobs = result.prompt_logprobs
        assert logprobs is not None
        for logprob in logprobs:
            assert len(logprob) == vocab_size
        assert prompt_logprobs is not None
        assert prompt_logprobs[0] is None
        for prompt_logprob in prompt_logprobs[1:]:
            assert len(prompt_logprob) == vocab_size
474
475


476
@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
def test_logprobs_mode(logprobs_mode: LogprobsMode):
    """Test the LLM engine with each `logprobs_mode`.

    In the logprob modes every returned value should be non-positive.
    In the logit modes at least one returned value should be positive.
    """
    from vllm import LLM

    llm = LLM(
        "facebook/opt-125m",
        max_logprobs=5,
        enable_prefix_caching=False,
        # 2 other llms alive during whole session
        gpu_memory_utilization=0.05,
        max_model_len=16,
        logprobs_mode=logprobs_mode,
    )
    vllm_sampling_params = SamplingParams(logprobs=1)
    results = llm.generate(["Hello world"], sampling_params=vllm_sampling_params)

    total_token_with_logprobs = 0
    positive_values = 0
    for output in results[0].outputs:
        for logprobs in output.logprobs:
            # Only the Logprob entries are needed, not their token ids
            for logprob in logprobs.values():
                if logprobs_mode in ("raw_logprobs", "processed_logprobs"):
                    assert logprob.logprob <= 0
                if logprob.logprob > 0:
                    positive_values += 1
                total_token_with_logprobs += 1
    assert total_token_with_logprobs >= len(results[0].outputs)
    if logprobs_mode in ("raw_logits", "processed_logits"):
        assert positive_values > 0
    del llm