# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools
import math
from collections.abc import Generator
from typing import get_args

import pytest
import torch

from tests.utils import large_gpu_mark
from tests.v1.sample.utils import (
    BatchLogprobsComposition,
    BatchLogprobsSpecType,
    assert_incr_detok_str_matches_non_incr_detok_str,
    compute_correct_cumulative_logprob,
    get_test_batch,
)
from vllm import SamplingParams
from vllm.config.model import LogprobsMode
from vllm.distributed import cleanup_dist_env_and_memory

from ...conftest import HfRunner, VllmRunner

# Model loaded by both the vLLM and the HuggingFace fixture in this module.
MODEL = "meta-llama/Llama-3.2-1B-Instruct"
# dtype string handed to both runners so they load the same precision.
DTYPE = "half"

# Short aliases for the batch compositions used when parametrizing the tests.
NONE = BatchLogprobsComposition.NONE
SAMPLE = BatchLogprobsComposition.SAMPLE
PROMPT = BatchLogprobsComposition.PROMPT
SAMPLE_PROMPT = BatchLogprobsComposition.SAMPLE_PROMPT


@pytest.fixture(
    scope="module",
    # Parameterize APC
    params=[False, True],
)
def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
    """Yield a module-scoped vLLM runner for `MODEL`.

    The fixture is parametrized over `enable_prefix_caching` (`False`,
    then `True`), so every test that requests it runs once per APC setting.

    Args:
      vllm_runner: runner factory fixture; its result is used as a
                   context manager
      request: pytest request object; `request.param` carries the
               APC flag for this parametrization

    Yields:
      The runner, which stays open until the `with` block exits.
    """
    with vllm_runner(
        MODEL,
        dtype=DTYPE,
        max_logprobs=7,
        # Very small number of batched tokens to ensure
        # that we test chunking.
        max_num_batched_tokens=16,
        max_num_seqs=16,
        max_model_len=128,
        enforce_eager=True,
        # APC on/off is driven by the fixture parametrization above.
        enable_prefix_caching=request.param,
        gpu_memory_utilization=0.4,  # up to 2 alive concurrently
    ) as runner:
        yield runner


@pytest.fixture(scope="module")
def hf_model(hf_runner) -> Generator[HfRunner, None, None]:
    """Yield a module-scoped HuggingFace reference runner for `MODEL`.

    Args:
      hf_runner: runner factory fixture; its result is used as a
                 context manager

    Yields:
      The HuggingFace runner, loaded with the same `DTYPE` as the vLLM
      fixture so that logprobs are comparable.
    """
    with hf_runner(MODEL, dtype=DTYPE) as runner:
        yield runner


def _repeat_logprob_config(
    test_prompts,
67
68
    logprob_prompt_logprob_list: BatchLogprobsSpecType,
) -> BatchLogprobsSpecType:
69
    """Ensure each test prompt has a logprob config.
70

71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
    A logprob config specifies the optional (i.e.
    may-be-`None`) number of sample logprobs and
    the optional number of prompt logprobs.

    If more test prompts than logprob configs are
    provided, the provided logprob configs are
    tiled to match the number of test prompts.

    If fewer test prompts than logprob configs
    are provided, the list of logprob configs
    is truncated to match the number of test
    prompts.

    Otherwise, the list of logprob configs
    is returned as-is.

    Args:
      test_prompts: list of prompts under test
      logprob_prompt_logprob_list: list of
                            (optional num sample logprob,
                             optional num prompt logprob)
                             tuples
93

94
    Returns:
95
      list of
96
97
98
99
100
101
102
103
104
105
      (optional num sample logprob,optional num prompt logprob)
      tuples which is either identical to
      `logprob_prompt_logprob_list`, or else repeats
      `logprob_prompt_logprob_list` enough times to match the
      number of `test_prompts`, or else is truncated to match
      the number of `test_prompts`
    """
    num_test_prompts = len(test_prompts)
    # Make sure there is a logprobs configuration for each test prompt
    logprob_prompt_logprob_list = list(
106
107
        itertools.islice(itertools.cycle(logprob_prompt_logprob_list), num_test_prompts)
    )
108
109
110
111
112
    # Now the number of prompts should match the number of sample params combos
    assert num_test_prompts == len(logprob_prompt_logprob_list)
    return logprob_prompt_logprob_list


def _run_and_validate(
    vllm_model: VllmRunner,
    test_prompts: list[str],
    vllm_sampling_params: list[SamplingParams],
    hf_logprobs: list[list[torch.Tensor]],
    hf_outputs: list[tuple[list[int], str]],
    logprob_prompt_logprob_list: BatchLogprobsSpecType,
    temperature: float,
    max_tokens: int,
    do_apc: bool,
) -> None:
    """Generate with vLLM and validate (prompt) logprobs for every request.

    For each request this checks that:
    * the prompt tokens (and, when greedy, the sampled tokens) match the
      HuggingFace output
    * sample logprobs / prompt logprobs are present exactly when they
      were requested, and have the expected number of entries and ranks
    * the logprob values are close to the HuggingFace reference
    * `cumulative_logprob` matches the value recomputed by
      `compute_correct_cumulative_logprob`

    Args:
      vllm_model: vLLM runner under test
      test_prompts: prompts to generate from
      vllm_sampling_params: one `SamplingParams` per prompt
      hf_logprobs: HuggingFace reference logprobs, one entry per prompt
      hf_outputs: HuggingFace reference (token ids, text), one per prompt
      logprob_prompt_logprob_list: per-prompt
                            (optional num sample logprob,
                             optional num prompt logprob)
                             tuples
      temperature: sampling temperature used for the vLLM requests
      max_tokens: number of tokens generated per request
      do_apc: whether prefix caching is enabled; not read by this
              function, kept so existing call sites keep working
    """
    vllm_results = vllm_model.llm.generate(
        test_prompts, sampling_params=vllm_sampling_params
    )

    # zip() stops at the shortest input, which would silently skip requests;
    # fail loudly instead if the per-prompt lists are out of step.
    assert (
        len(vllm_results)
        == len(hf_logprobs)
        == len(hf_outputs)
        == len(logprob_prompt_logprob_list)
    ), "expected one vLLM result, HF reference and logprob config per prompt"

    for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip(
        vllm_results, hf_logprobs, hf_outputs, logprob_prompt_logprob_list
    ):
        # Extract request-level (prompt)logprobs config
        num_top_logprobs, num_top_prompt_logprobs = logprob_prompt_logprob

        # Test whether sampled token output is consistent between vLLM and HF
        # vLLM prompt+completion should match HF output
        if temperature == 0.0:
            assert (
                vllm_result.prompt_token_ids + vllm_result.outputs[0].token_ids
                == hf_output[0]
            )
        else:
            # Sampled tokens won't match if not greedy
            assert (
                vllm_result.prompt_token_ids
                == hf_output[0][: len(vllm_result.prompt_token_ids)]
            )

        # Validate sample logprobs
        if num_top_logprobs is not None:
            # Confirm that the structure of the sample logprobs in the result is
            # correct
            assert vllm_result.outputs[0].logprobs is not None
            assert len(vllm_result.outputs[0].logprobs) == max_tokens
            for logprobs, token_id in zip(
                vllm_result.outputs[0].logprobs, vllm_result.outputs[0].token_ids
            ):
                assert logprobs is not None

                # Confirm that the output token appears among the logprobs
                assert token_id in logprobs
                token_in_topk = logprobs[token_id].rank <= num_top_logprobs

                # If the output token is not included in the top K
                # logprob, it can return 1 more data
                if token_in_topk and num_top_logprobs != 0:
                    assert len(logprobs) == num_top_logprobs
                else:
                    assert len(logprobs) == num_top_logprobs + 1

                if num_top_logprobs > 0:
                    # We should have an entry for each of the topk ranks
                    all_ranks = {lp.rank for lp in logprobs.values()}
                    assert all(r in all_ranks for r in range(1, num_top_logprobs + 1))

            output_text = vllm_result.outputs[0].text
            output_string_from_most_likely_tokens_lst: list[str] = []
            for top_logprobs in vllm_result.outputs[0].logprobs:
                # The first entry of each per-position dict is used to rebuild
                # the output text from the decoded tokens.
                top_logprob = next(iter(top_logprobs.values()))
                output_string_from_most_likely_tokens_lst.append(
                    top_logprob.decoded_token
                )

            output_string_from_most_likely_tokens = "".join(
                output_string_from_most_likely_tokens_lst
            )
            assert_incr_detok_str_matches_non_incr_detok_str(
                output_text,
                output_string_from_most_likely_tokens,
                "The output text from the top logprob for each token "
                "position should be the same as the output text in the "
                "result.",
            )

            # Compare vLLM sample logprobs to HF
            vllm_sample_logprobs = vllm_result.outputs[0].logprobs
            for i, top_logprobs in enumerate(vllm_sample_logprobs):
                for token_id, sample_logprob in top_logprobs.items():
                    # The HF reference is greedy, so with temperature > 0
                    # only the first generated position shares the same
                    # context and can be compared.
                    if temperature == 0.0 or i == 0:
                        logprob = sample_logprob.logprob
                        torch.testing.assert_close(
                            logprob,
                            hf_logprob[i][-1][token_id].item(),
                            atol=1e-2,
                            rtol=1e-2,
                        )
                    assert isinstance(sample_logprob.decoded_token, str), (
                        "The token should be decoded by the time it is"
                        " returned to the user."
                    )

            # At this point we know the sample logprobs are correct for this
            # request. Validate that cumulative_logprob is actually the sum.
            # For each request, assert that the returned cumulative logprob
            # matches the correct value, which is computed below.
            torch.testing.assert_close(
                vllm_result.outputs[0].cumulative_logprob,
                compute_correct_cumulative_logprob(vllm_result.outputs[0]),
                atol=1e-6,
                rtol=1e-6,
            )
        else:
            # Logprobs disabled for this request; should be None
            assert vllm_result.outputs[0].logprobs is None

        # Validate prompt logprobs
        if num_top_prompt_logprobs is not None:
            # Confirm that structure of prompt logprobs in result is correct
            assert vllm_result.prompt_logprobs is not None
            # - The first prompt logprob is always None
            assert vllm_result.prompt_logprobs[0] is None
            # - Prompt logprobs are returned for all indices in
            #   the prompt
            assert len(vllm_result.prompt_logprobs) == len(vllm_result.prompt_token_ids)
            for prompt_logprobs, prompt_token_id in zip(
                vllm_result.prompt_logprobs[1:], vllm_result.prompt_token_ids[1:]
            ):
                assert prompt_logprobs is not None

                # Confirm that the prompt token appears among the logprobs
                assert prompt_token_id in prompt_logprobs
                token_in_topk = (
                    prompt_logprobs[prompt_token_id].rank <= num_top_prompt_logprobs
                )

                # If the prompt token is not included in the top K
                # logprob, it can return 1 more data
                if token_in_topk and num_top_prompt_logprobs != 0:
                    assert len(prompt_logprobs) == num_top_prompt_logprobs
                else:
                    assert len(prompt_logprobs) == num_top_prompt_logprobs + 1

                if num_top_prompt_logprobs > 0:
                    # We should have an entry for each of the topk ranks
                    all_ranks = {lp.rank for lp in prompt_logprobs.values()}
                    assert all(
                        r in all_ranks for r in range(1, num_top_prompt_logprobs + 1)
                    )

            # Compare prompt logprobs to HF
            # The first prompt logprob is always None, so we compare it from
            # 1:.
            vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
            for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
                for token_id, logprob in vllm_prompt_logprob_dict.items():
                    torch.testing.assert_close(
                        logprob.logprob,
                        hf_logprob[0][i][token_id].item(),
                        atol=2e-2,
                        rtol=2e-2,
                    )
        else:
            assert vllm_result.prompt_logprobs is None


@pytest.mark.parametrize(
    "batch_logprobs_composition", [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT]
)
@pytest.mark.parametrize("temperature", [0.0, 2.0])
def test_get_logprobs_and_prompt_logprobs(
    hf_model,
    vllm_model,
    batch_logprobs_composition: BatchLogprobsComposition,
    temperature: float,
    example_prompts: list[str],
) -> None:
    """Test V1 Engine logprobs & prompt logprobs

    Exercise a variety of combinations of `logprobs` and `prompt_logprobs`
    settings and validate that
    * The generated logprobs and prompt logprobs are consistent with the
      configuration settings, in terms of whether or not the logprobs
      (of either type) were requested and how many were requested
    * The generated logprobs are consistent with the generated tokens
    * The generated (prompt)logprobs are consistent with HuggingFace
      (prompt)logprobs, as a reference

    batch_logprobs_composition controls the logprobs configurations for
    requests in the batch under test.

    APC tests run two test iterations so that cache hits occur.

    To save time, only test one APC-enabled scenario
    (sample & prompt logprobs enabled, temperature>0.0).

    Args:
      hf_model: HuggingFace reference model fixture
      vllm_model: vLLM model fixture
      batch_logprobs_composition: logprobs configuration for test batch
      temperature: "temperature" sampling parameter
      example_prompts: example prompt fixture
    """
    do_apc = vllm_model.llm.llm_engine.cache_config.enable_prefix_caching
    if do_apc and (temperature < 2.0 or batch_logprobs_composition != SAMPLE_PROMPT):
        # Skip some test-cases to save time.
        pytest.skip(
            "with APC enabled only the SAMPLE_PROMPT / temperature=2.0 "
            "case is run, to save time"
        )
    test_prompts = example_prompts

    max_tokens = 5
    # The HF reference is always greedy, whatever temperature vLLM uses.
    hf_outputs = hf_model.generate_greedy(
        test_prompts,
        max_tokens=max_tokens,
    )
    hf_logprobs = hf_model.generate_greedy_logprobs(
        test_prompts,
        max_tokens=max_tokens,
    )

    # Batch has mixed sample params
    # (different logprobs/prompt logprobs combos)
    logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)

    # Ensure that each test prompt has a logprob config for testing
    logprob_prompt_logprob_list = _repeat_logprob_config(
        test_prompts, logprob_prompt_logprob_list
    )
    # Generate SamplingParams
    vllm_sampling_params = [
        SamplingParams(
            max_tokens=max_tokens,
            logprobs=num_lp,
            prompt_logprobs=num_plp,
            temperature=temperature,
            seed=1984,
        )
        for num_lp, num_plp in logprob_prompt_logprob_list
    ]
    # With APC the second pass runs the same prompts again so that cache
    # hits occur.
    for _ in range(2 if do_apc else 1):
        _run_and_validate(
            vllm_model=vllm_model,
            test_prompts=test_prompts,
            vllm_sampling_params=vllm_sampling_params,
            hf_logprobs=hf_logprobs,
            hf_outputs=hf_outputs,
            logprob_prompt_logprob_list=logprob_prompt_logprob_list,
            temperature=temperature,
            max_tokens=max_tokens,
            do_apc=do_apc,
        )


def test_max_logprobs():
    """vLLM v1 engine should fail a request with `logprobs > max_logprobs`
    Should also fail for `prompt_logprobs > max_logprobs`
    APC should not matter as this test checks basic request validation.

    NOTE(review): only the `logprobs` limit is exercised below; the
    `prompt_logprobs` limit is stated above but not asserted here.
    """
    # Use the runner as a context manager so the engine is released when
    # the test ends, including on failure.
    with VllmRunner(
        "facebook/opt-125m",
        max_logprobs=1,
        enable_prefix_caching=False,
        # 2 other llms alive during whole session
        gpu_memory_utilization=0.15,
        max_model_len=256,
    ) as runner:
        vllm_sampling_params = SamplingParams(logprobs=1)
        # should pass
        runner.generate(["Hello world"], sampling_params=vllm_sampling_params)

        bad_sampling_params = SamplingParams(logprobs=2)
        with pytest.raises(ValueError):
            runner.generate(["Hello world"], sampling_params=bad_sampling_params)


def test_none_logprobs(vllm_model, example_prompts):
    """Engine should return `logprobs` and `prompt_logprobs` as `None`

    Requests are sent with `logprobs=None` and `prompt_logprobs=None`;
    `cumulative_logprob` is expected to be `None` as well.

    Args:
      vllm_model: vLLM model fixture
      example_prompts: list of example prompts (test fixture)
    """
    max_tokens = 5

    sampling_params_logprobs_none = SamplingParams(
        max_tokens=max_tokens,
        logprobs=None,
        prompt_logprobs=None,
        temperature=0.0,
    )
    results_logprobs_none = vllm_model.llm.generate(
        example_prompts,
        sampling_params=sampling_params_logprobs_none,
    )

    for result in results_logprobs_none:
        completion = result.outputs[0]
        # Check sample logprobs are None
        assert completion.logprobs is None
        assert completion.cumulative_logprob is None
        # Check prompt logprobs are None
        assert result.prompt_logprobs is None


def test_zero_logprobs(vllm_model, example_prompts):
    """Engine should return sampled token and prompt token logprobs

    Requests are sent with `logprobs=0` and `prompt_logprobs=0`; the
    result must still carry one logprob entry per sampled token and per
    prompt token, plus a non-`None` `cumulative_logprob`.

    Args:
      vllm_model: vLLM model fixture
      example_prompts: list of example prompts (test fixture)
    """
    max_tokens = 5

    sampling_params_logprobs_zero = SamplingParams(
        max_tokens=max_tokens, logprobs=0, prompt_logprobs=0, temperature=0.0
    )
    results_logprobs_zero = vllm_model.llm.generate(
        example_prompts, sampling_params=sampling_params_logprobs_zero
    )

    for result in results_logprobs_zero:
        completion = result.outputs[0]
        # Check that there is one sample logprob dict for each
        # sample token
        logprobs = completion.logprobs
        prompt_logprobs = result.prompt_logprobs
        sampled_token_ids = completion.token_ids
        prompt_token_ids = result.prompt_token_ids
        assert logprobs is not None
        assert len(sampled_token_ids) == len(logprobs)
        assert completion.cumulative_logprob is not None
        # Check that there is one prompt logprob dict for each
        # prompt token
        assert prompt_logprobs is not None
        assert len(prompt_token_ids) == len(prompt_logprobs)


def test_all_logprobs(example_prompts):
    """Engine should return all vocabulary logprobs and prompt logprobs

    Requests are sent with `logprobs=-1` and `prompt_logprobs=-1` against
    an engine configured with `max_logprobs=-1`; every per-position dict
    must then contain one entry per vocabulary token.

    Args:
      example_prompts: list of example prompts (test fixture)
    """
    # Use the runner as a context manager so the engine is released when
    # the test ends, including on failure.
    with VllmRunner(
        "facebook/opt-125m",
        max_logprobs=-1,
        enable_prefix_caching=False,
        # 2 other llms alive during whole session
        gpu_memory_utilization=0.15,
        max_model_len=256,
    ) as runner:
        sampling_params_logprobs_all = SamplingParams(
            max_tokens=5, logprobs=-1, prompt_logprobs=-1
        )
        results_logprobs_all = runner.llm.generate(
            example_prompts, sampling_params=sampling_params_logprobs_all
        )
        vocab_size = runner.llm.llm_engine.model_config.get_vocab_size()

        for result in results_logprobs_all:
            logprobs = result.outputs[0].logprobs
            prompt_logprobs = result.prompt_logprobs
            assert logprobs is not None
            for logprob in logprobs:
                assert len(logprob) == vocab_size
            assert prompt_logprobs is not None
            # The first prompt position never has a logprob entry.
            assert prompt_logprobs[0] is None
            for prompt_logprob in prompt_logprobs[1:]:
                assert len(prompt_logprob) == vocab_size


@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
def test_logprobs_mode(logprobs_mode: LogprobsMode):
    """Test with LLM engine with different logprobs_mode.
    For logprobs, we should have non-positive values.
    For logits, we should expect at least one positive value.

    Args:
      logprobs_mode: one of the values of `LogprobsMode`
    """
    from vllm import LLM

    llm = LLM(
        "facebook/opt-125m",
        max_logprobs=5,
        enable_prefix_caching=False,
        # 2 other llms alive during whole session
        gpu_memory_utilization=0.05,
        max_model_len=16,
        logprobs_mode=logprobs_mode,
    )
    try:
        vllm_sampling_params = SamplingParams(logprobs=1)
        results = llm.generate(["Hello world"], sampling_params=vllm_sampling_params)

        total_token_with_logprobs = 0
        positive_values = 0
        for output in results[0].outputs:
            for logprobs in output.logprobs:
                for logprob in logprobs.values():
                    if logprobs_mode in ("raw_logprobs", "processed_logprobs"):
                        assert logprob.logprob <= 0
                    if logprob.logprob > 0:
                        positive_values += 1
                    total_token_with_logprobs += 1
        assert total_token_with_logprobs >= len(results[0].outputs)
        if logprobs_mode in ("raw_logits", "processed_logits"):
            assert positive_values > 0
    finally:
        # Drop the engine reference even when an assertion fails, so the
        # failing frame does not keep it alive.
        del llm


@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
@pytest.mark.parametrize(
    "model_setup",
    [
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-3.1-8B-Instruct",
                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
            ),
            marks=large_gpu_mark(min_gb=32),
        ),
    ],
)
def test_spec_decode_logprobs(
    logprobs_mode: LogprobsMode,
    model_setup: tuple[str, str, str],
):
    """Spec decode logprobs should match those of the base model.

    The base model and the speculative-decoding engine are run one after
    the other on the same prompt with greedy sampling, and their
    per-token logprob values, ranks and decoded tokens are compared.

    Args:
        logprobs_mode: logprobs mode.
        model_setup: Spec decode method, base model name, and
        draft model name.
    """
    from vllm import LLM

    prompt = "Hello world"
    sampling_params = SamplingParams(
        temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
    )
    method, model_name, spec_model_name = model_setup
    max_model_len = 256

    # Run base LLM.
    ref_llm = LLM(
        model=model_name,
        max_logprobs=5,
        max_model_len=max_model_len,
        seed=42,
        logprobs_mode=logprobs_mode,
        gpu_memory_utilization=0.4,
    )
    try:
        ref_results = ref_llm.generate([prompt], sampling_params)
        # Collect logprobs outputs from reference LLM.
        ref_logprobs = [
            logprob
            for output in ref_results[0].outputs
            for logprobs in output.logprobs
            for logprob in logprobs.values()
        ]
    finally:
        # Release the base model before the spec decode engine is built,
        # even if generation failed.
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    # Run spec decode LLM.
    spec_llm = LLM(
        model_name,
        speculative_config={
            "method": method,
            "model": spec_model_name,
            "num_speculative_tokens": 3,
            "max_model_len": max_model_len,
        },
        max_logprobs=5,
        max_model_len=max_model_len,
        seed=42,
        logprobs_mode=logprobs_mode,
        gpu_memory_utilization=0.4,
    )
    try:
        spec_results = spec_llm.generate([prompt], sampling_params)
        # Collect logprobs outputs from spec decode LLM.
        spec_logprobs = [
            logprob
            for output in spec_results[0].outputs
            for logprobs in output.logprobs
            for logprob in logprobs.values()
        ]
    finally:
        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    # Per-token logprobs are expected to be the same.
    assert len(ref_logprobs) == len(spec_logprobs)
    for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
        assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
        assert ref_logprob.rank == spec_logprob.rank
        assert ref_logprob.decoded_token == spec_logprob.decoded_token