# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools
import math
from collections.abc import Generator
from typing import get_args

import pytest
import torch

from tests.utils import large_gpu_mark
from tests.v1.sample.utils import (
    BatchLogprobsComposition,
    BatchLogprobsSpecType,
    assert_incr_detok_str_matches_non_incr_detok_str,
    compute_correct_cumulative_logprob,
    get_test_batch,
)
from vllm import SamplingParams
from vllm.config.model import LogprobsMode
from vllm.distributed import cleanup_dist_env_and_memory

from ...conftest import HfRunner, VllmRunner

# Model and dtype shared by the vLLM fixture and the HF reference fixture.
MODEL = "meta-llama/Llama-3.2-1B-Instruct"
DTYPE = "half"

# Short aliases for the batch logprobs compositions under test.
NONE = BatchLogprobsComposition.NONE
SAMPLE = BatchLogprobsComposition.SAMPLE
PROMPT = BatchLogprobsComposition.PROMPT
SAMPLE_PROMPT = BatchLogprobsComposition.SAMPLE_PROMPT

@pytest.fixture(
    scope="module",
    # Parameterize APC (prefix caching) off / on
    params=[False, True],
)
def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
    """Module-scoped vLLM runner for ``MODEL``.

    Instantiated once with prefix caching disabled and once with it enabled
    (``request.param``). The engine limits are kept small so that chunked
    prefill is exercised.
    """
    with vllm_runner(
        MODEL,
        dtype=DTYPE,
        max_logprobs=7,
        # Very small number of batched tokens to ensure
        # that we test chunking.
        max_num_batched_tokens=16,
        max_num_seqs=16,
        max_model_len=128,
        enable_chunked_prefill=True,
        enforce_eager=True,
        enable_prefix_caching=request.param,
        gpu_memory_utilization=0.4,  # up to 2 alive concurrently
    ) as vllm_model:
        yield vllm_model


@pytest.fixture(scope="module")
def hf_model(hf_runner) -> Generator[HfRunner, None, None]:
    """Module-scoped HuggingFace reference runner for ``MODEL``/``DTYPE``."""
    with hf_runner(MODEL, dtype=DTYPE) as hf_model:
        yield hf_model


def _repeat_logprob_config(
    test_prompts,
    logprob_prompt_logprob_list: BatchLogprobsSpecType,
) -> BatchLogprobsSpecType:
    """Ensure each test prompt has a logprob config.

    A logprob config specifies the optional (i.e. may-be-`None`) number of
    sample logprobs and the optional number of prompt logprobs.

    If more test prompts than logprob configs are provided, the provided
    logprob configs are tiled to match the number of test prompts.

    If fewer test prompts than logprob configs are provided, the list of
    logprob configs is truncated to match the number of test prompts.

    Otherwise, an equal list of logprob configs is returned.

    Args:
      test_prompts: list of prompts under test
      logprob_prompt_logprob_list: list of
        (optional num sample logprob, optional num prompt logprob) tuples

    Returns:
      A new list of (optional num sample logprob, optional num prompt
      logprob) tuples with exactly one entry per test prompt. It is equal to
      `logprob_prompt_logprob_list`, a cyclic repetition of it, or a
      truncation of it.
    """
    num_test_prompts = len(test_prompts)
    # Cycle through the configs and take exactly one per test prompt
    logprob_prompt_logprob_list = list(
        itertools.islice(itertools.cycle(logprob_prompt_logprob_list), num_test_prompts)
    )
    # Now the number of prompts should match the number of sample params
    # combos (this fails if an empty config list was given for non-empty
    # prompts, since cycling an empty list yields nothing)
    assert num_test_prompts == len(logprob_prompt_logprob_list)
    return logprob_prompt_logprob_list


def _assert_topk_logprobs_structure(logprobs, token_id: int, num_top: int) -> None:
    """Check one position's logprobs dict against the requested top-k count.

    Shared by the sample and prompt logprob validators.

    Args:
      logprobs: mapping of token id -> Logprob for a single position
      token_id: the token actually sampled (or present in the prompt) at
        this position
      num_top: number of top logprobs requested for this position
    """
    assert logprobs is not None

    # Confirm that the actual token appears among the logprobs
    assert token_id in logprobs
    token_in_topk = logprobs[token_id].rank <= num_top

    # If the actual token is not included in the top K logprobs, it is
    # returned as one extra entry
    if token_in_topk and num_top != 0:
        assert len(logprobs) == num_top
    else:
        assert len(logprobs) == num_top + 1

    if num_top > 0:
        # We should have an entry for each of the topk ranks
        all_ranks = {lp.rank for lp in logprobs.values()}
        assert all(r in all_ranks for r in range(1, num_top + 1))


def _validate_sample_logprobs(
    vllm_result,
    hf_logprob: list[torch.Tensor],
    num_top_logprobs: int,
    temperature: float,
    max_tokens: int,
) -> None:
    """Validate one request's sample logprobs.

    Checks the per-position structure, that the first entry of each position
    detokenizes to the output text, agreement with the HF reference logprobs,
    and that `cumulative_logprob` matches the recomputed sum.
    """
    completion = vllm_result.outputs[0]

    # Confirm that the structure of the sample logprobs in the result is
    # correct
    assert completion.logprobs is not None
    assert len(completion.logprobs) == max_tokens
    for logprobs, token_id in zip(completion.logprobs, completion.token_ids):
        _assert_topk_logprobs_structure(logprobs, token_id, num_top_logprobs)

    # Joining the first entry of every position must reproduce the output
    # text, i.e. the first entry is expected to be the token that was emitted
    output_string_from_most_likely_tokens = "".join(
        next(iter(top_logprobs.values())).decoded_token
        for top_logprobs in completion.logprobs
    )
    assert_incr_detok_str_matches_non_incr_detok_str(
        completion.text,
        output_string_from_most_likely_tokens,
        "The output text from the top logprob for each token "
        "position should be the same as the output text in the "
        "result.",
    )

    # Compare vLLM sample logprobs to HF. The HF reference is greedy, so when
    # sampling (temperature > 0) only the first step shares a common prefix.
    for i, top_logprobs in enumerate(completion.logprobs):
        for token_id, sample_logprob in top_logprobs.items():
            if temperature == 0.0 or i == 0:
                torch.testing.assert_close(
                    sample_logprob.logprob,
                    hf_logprob[i][-1][token_id].item(),
                    atol=1e-2,
                    rtol=1e-2,
                )
            assert isinstance(sample_logprob.decoded_token, str), (
                "The token should be decoded by the time it is"
                " returned to the user."
            )

    # At this point we know the sample logprobs are correct for this
    # request. Validate that cumulative_logprob is actually the sum.
    torch.testing.assert_close(
        completion.cumulative_logprob,
        compute_correct_cumulative_logprob(completion),
        atol=1e-6,
        rtol=1e-6,
    )


def _validate_prompt_logprobs(
    vllm_result,
    hf_logprob: list[torch.Tensor],
    num_top_prompt_logprobs: int,
) -> None:
    """Validate one request's prompt logprobs for structure and HF agreement."""
    prompt_logprobs = vllm_result.prompt_logprobs
    prompt_token_ids = vllm_result.prompt_token_ids

    # Confirm that structure of prompt logprobs in result is correct
    assert prompt_logprobs is not None
    # - The first prompt logprob is always None
    assert prompt_logprobs[0] is None
    # - Prompt logprobs are returned for all indices in the prompt
    assert len(prompt_logprobs) == len(prompt_token_ids)
    for position_logprobs, prompt_token_id in zip(
        prompt_logprobs[1:], prompt_token_ids[1:]
    ):
        _assert_topk_logprobs_structure(
            position_logprobs, prompt_token_id, num_top_prompt_logprobs
        )

    # Compare prompt logprobs to HF, skipping the always-None first entry
    for i, position_logprobs in enumerate(prompt_logprobs[1:]):
        for token_id, logprob in position_logprobs.items():
            torch.testing.assert_close(
                logprob.logprob,
                hf_logprob[0][i][token_id].item(),
                atol=2e-2,
                rtol=2e-2,
            )


def _run_and_validate(
    vllm_model: VllmRunner,
    test_prompts: list[str],
    vllm_sampling_params: SamplingParams | list[SamplingParams],
    hf_logprobs: list[list[torch.Tensor]],
    hf_outputs: list[tuple[list[int], str]],
    logprob_prompt_logprob_list: BatchLogprobsSpecType,
    temperature: float,
    max_tokens: int,
    do_apc: bool,
) -> None:
    """Generate with vLLM and validate every request against the HF reference.

    Args:
      vllm_model: vLLM runner under test
      test_prompts: prompts to generate from
      vllm_sampling_params: sampling params passed to `generate` (the test
        passes one `SamplingParams` per prompt)
      hf_logprobs: per-prompt HF reference logprobs; `hf_logprob[i]` is
        indexed as generation step i, and `hf_logprob[0]` is also indexed
        per prompt position (assumed layout of HfRunner output — confirm)
      hf_outputs: per-prompt HF greedy outputs; element 0 is the full
        prompt + completion token id list
      logprob_prompt_logprob_list: per-prompt (num sample logprobs,
        num prompt logprobs) configs; `None` means not requested
      temperature: sampling temperature used for every request
      max_tokens: number of tokens each request generates
      do_apc: whether prefix caching is enabled; not read here (the caller
        repeats the call to exercise cache hits)
    """
    vllm_results = vllm_model.llm.generate(
        test_prompts, sampling_params=vllm_sampling_params
    )

    for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip(
        vllm_results, hf_logprobs, hf_outputs, logprob_prompt_logprob_list
    ):
        # Extract request-level (prompt)logprobs config
        num_top_logprobs, num_top_prompt_logprobs = logprob_prompt_logprob
        prompt_token_ids = vllm_result.prompt_token_ids
        completion = vllm_result.outputs[0]

        # Test whether sampled token output is consistent between vLLM and HF
        if temperature == 0.0:
            # vLLM prompt+completion should match HF output
            assert prompt_token_ids + completion.token_ids == hf_output[0]
        else:
            # Sampled tokens won't match if not greedy; compare prompts only
            assert prompt_token_ids == hf_output[0][: len(prompt_token_ids)]

        # Validate sample logprobs
        if num_top_logprobs is not None:
            _validate_sample_logprobs(
                vllm_result, hf_logprob, num_top_logprobs, temperature, max_tokens
            )
        else:
            # Logprobs disabled for this request; should be None
            assert completion.logprobs is None

        # Validate prompt logprobs
        if num_top_prompt_logprobs is not None:
            _validate_prompt_logprobs(vllm_result, hf_logprob, num_top_prompt_logprobs)
        else:
            assert vllm_result.prompt_logprobs is None


@pytest.mark.parametrize(
    "batch_logprobs_composition", [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT]
)
@pytest.mark.parametrize("temperature", [0.0, 2.0])
def test_get_logprobs_and_prompt_logprobs(
    hf_model,
    vllm_model,
    batch_logprobs_composition: BatchLogprobsComposition,
    temperature: float,
    example_prompts: list[str],
) -> None:
    """Test V1 Engine logprobs & prompt logprobs

    Exercise a variety of combinations of `logprobs` and `prompt_logprobs`
    settings and validate that
    * The generated logprobs and prompt logprobs are consistent with the
      configuration settings, in terms of whether or not the logprobs
      (of either type) were requested and how many were requested
    * The generated logprobs are consistent with the generated tokens
    * The generated (prompt)logprobs are consistent with HuggingFace
      (prompt)logprobs, as a reference

    batch_logprobs_composition controls the logprobs configurations for
    requests in the batch under test.

    APC tests run two test iterations so that cache hits occur.

    To save time, only test one APC-enabled scenario
    (sample & prompt logprobs enabled, temperature>0.0).

    Args:
      hf_model: HuggingFace reference model fixture
      vllm_model: vLLM model fixture
      batch_logprobs_composition: logprobs configuration for test batch
      temperature: "temperature" sampling parameter
      example_prompts: example prompt fixture
    """
    do_apc = vllm_model.llm.llm_engine.cache_config.enable_prefix_caching
    if do_apc and (temperature < 2.0 or batch_logprobs_composition != SAMPLE_PROMPT):
        # Skip some test-cases to save time.
        pytest.skip("APC is only exercised for SAMPLE_PROMPT at temperature>0")
    test_prompts = example_prompts

    max_tokens = 5
    hf_outputs = hf_model.generate_greedy(
        test_prompts,
        max_tokens=max_tokens,
    )
    hf_logprobs = hf_model.generate_greedy_logprobs(
        test_prompts,
        max_tokens=max_tokens,
    )

    # Batch has mixed sample params
    # (different logprobs/prompt logprobs combos)
    logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)

    # Ensure that each test prompt has a logprob config for testing
    logprob_prompt_logprob_list = _repeat_logprob_config(
        test_prompts, logprob_prompt_logprob_list
    )
    # Generate one SamplingParams per prompt
    vllm_sampling_params = [
        SamplingParams(
            max_tokens=max_tokens,
            logprobs=num_lp,
            prompt_logprobs=num_plp,
            temperature=temperature,
            seed=1984,
        )
        for num_lp, num_plp in logprob_prompt_logprob_list
    ]
    # With APC, the second pass over the same prompts can hit the prefix cache
    for _ in range(2 if do_apc else 1):
        _run_and_validate(
            vllm_model=vllm_model,
            test_prompts=test_prompts,
            vllm_sampling_params=vllm_sampling_params,
            hf_logprobs=hf_logprobs,
            hf_outputs=hf_outputs,
            logprob_prompt_logprob_list=logprob_prompt_logprob_list,
            temperature=temperature,
            max_tokens=max_tokens,
            do_apc=do_apc,
        )


def test_max_logprobs():
    """vLLM v1 engine should fail a request with `logprobs > max_logprobs`
    Should also fail for `prompt_logprobs > max_logprobs`
    APC should not matter as this test checks basic request validation.
    """
    # Use the runner as a context manager so the engine is released even if
    # an assertion fails.
    with VllmRunner(
        "facebook/opt-125m",
        max_logprobs=1,
        enable_prefix_caching=False,
        # 2 other llms alive during whole session
        gpu_memory_utilization=0.15,
        max_model_len=256,
    ) as runner:
        vllm_sampling_params = SamplingParams(logprobs=1)
        # should pass
        runner.generate(["Hello world"], sampling_params=vllm_sampling_params)

        bad_sampling_params = SamplingParams(logprobs=2)
        with pytest.raises(ValueError):
            runner.generate(["Hello world"], sampling_params=bad_sampling_params)

        bad_prompt_sampling_params = SamplingParams(prompt_logprobs=2)
        with pytest.raises(ValueError):
            runner.generate(
                ["Hello world"], sampling_params=bad_prompt_sampling_params
            )


def test_none_logprobs(vllm_model, example_prompts):
    """Engine should return `logprobs` and `prompt_logprobs` as `None`

    when neither is requested; `cumulative_logprob` should then be `None`
    as well.

    Args:
      vllm_model: vLLM model fixture
      example_prompts: list of example prompts (test fixture)
    """
    max_tokens = 5

    sampling_params_logprobs_none = SamplingParams(
        max_tokens=max_tokens,
        logprobs=None,
        prompt_logprobs=None,
        temperature=0.0,
    )
    results_logprobs_none = vllm_model.llm.generate(
        example_prompts,
        sampling_params=sampling_params_logprobs_none,
    )

    for result in results_logprobs_none:
        completion = result.outputs[0]
        # Check sample logprobs are None
        assert completion.logprobs is None
        assert completion.cumulative_logprob is None
        # Check prompt logprobs are None
        assert result.prompt_logprobs is None


def test_zero_logprobs(vllm_model, example_prompts):
    """Engine should return sampled token and prompt token logprobs

    With `logprobs=0` and `prompt_logprobs=0`, no extra top-k entries are
    requested. Each sampled token, and each prompt token after the first,
    should still carry its own logprob entry.

    Args:
      vllm_model: vLLM model fixture
      example_prompts: list of example prompts (test fixture)
    """
    max_tokens = 5

    sampling_params_logprobs_zero = SamplingParams(
        max_tokens=max_tokens, logprobs=0, prompt_logprobs=0, temperature=0.0
    )
    results_logprobs_zero = vllm_model.llm.generate(
        example_prompts, sampling_params=sampling_params_logprobs_zero
    )

    for result in results_logprobs_zero:
        completion = result.outputs[0]
        logprobs = completion.logprobs
        prompt_logprobs = result.prompt_logprobs
        sampled_token_ids = completion.token_ids
        prompt_token_ids = result.prompt_token_ids

        # Check that there is one sample logprob dict for each
        # sample token, and that it contains that token
        assert logprobs is not None
        assert len(sampled_token_ids) == len(logprobs)
        for position_logprobs, token_id in zip(logprobs, sampled_token_ids):
            assert token_id in position_logprobs
        assert completion.cumulative_logprob is not None

        # Check that there is one prompt logprob dict for each
        # prompt token (the first position is always None)
        assert prompt_logprobs is not None
        assert len(prompt_token_ids) == len(prompt_logprobs)
        assert prompt_logprobs[0] is None
        for position_logprobs, token_id in zip(
            prompt_logprobs[1:], prompt_token_ids[1:]
        ):
            assert token_id in position_logprobs


def test_all_logprobs(example_prompts):
    """Engine should return all vocabulary logprobs and prompt logprobs

    With `max_logprobs=-1` and `logprobs=-1` / `prompt_logprobs=-1`, every
    sampled position and every prompt position after the first should carry
    one entry per vocabulary token.

    Args:
      example_prompts: list of example prompts (test fixture)
    """
    # Use the runner as a context manager so the engine is released even if
    # an assertion fails.
    with VllmRunner(
        "facebook/opt-125m",
        max_logprobs=-1,
        enable_prefix_caching=False,
        # 2 other llms alive during whole session
        gpu_memory_utilization=0.15,
        max_model_len=256,
    ) as runner:
        sampling_params_logprobs_all = SamplingParams(
            max_tokens=5, logprobs=-1, prompt_logprobs=-1
        )
        results_logprobs_all = runner.llm.generate(
            example_prompts, sampling_params=sampling_params_logprobs_all
        )
        # Read the vocab size while the engine is still alive.
        vocab_size = runner.llm.llm_engine.model_config.get_vocab_size()

    for result in results_logprobs_all:
        logprobs = result.outputs[0].logprobs
        prompt_logprobs = result.prompt_logprobs
        assert logprobs is not None
        for logprob in logprobs:
            assert len(logprob) == vocab_size
        assert prompt_logprobs is not None
        assert prompt_logprobs[0] is None
        for prompt_logprob in prompt_logprobs[1:]:
            assert len(prompt_logprob) == vocab_size


@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
def test_logprobs_mode(logprobs_mode: LogprobsMode):
    """Test with LLM engine with different logprobs_mode.

    For the logprob modes ("raw_logprobs", "processed_logprobs"), every
    returned value must be non-positive. For the logit modes ("raw_logits",
    "processed_logits"), at least one returned value should be positive.
    """
    from vllm import LLM

    llm = LLM(
        "facebook/opt-125m",
        max_logprobs=5,
        enable_prefix_caching=False,
        # 2 other llms alive during whole session
        gpu_memory_utilization=0.05,
        max_model_len=16,
        logprobs_mode=logprobs_mode,
    )
    try:
        vllm_sampling_params = SamplingParams(logprobs=1)
        results = llm.generate(["Hello world"], sampling_params=vllm_sampling_params)
    finally:
        # Release the engine before asserting, so that neither a generate
        # error nor a failed assertion keeps it alive.
        del llm

    total_token_with_logprobs = 0
    positive_values = 0
    for output in results[0].outputs:
        for logprobs in output.logprobs:
            for logprob in logprobs.values():
                if logprobs_mode in ("raw_logprobs", "processed_logprobs"):
                    assert logprob.logprob <= 0
                if logprob.logprob > 0:
                    positive_values += 1
                total_token_with_logprobs += 1
    assert total_token_with_logprobs >= len(results[0].outputs)
    if logprobs_mode in ("raw_logits", "processed_logits"):
        assert positive_values > 0


@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
@pytest.mark.parametrize(
    "model_setup",
    [
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-3.2-1B-Instruct",
                "nm-testing/Llama3_2_1B_speculator.eagle3",
            ),
            marks=large_gpu_mark(min_gb=32),
        ),
    ],
)
@pytest.mark.parametrize("top_logprobs", [0, 3])
def test_spec_decode_logprobs(
    logprobs_mode: LogprobsMode,
    model_setup: tuple[str, str, str],
    top_logprobs: int,
):
    """Spec decode logprobs should match those of the base model.

    Args:
        logprobs_mode: logprobs mode.
        model_setup: Spec decode method, base model name, and
            draft model name.
        top_logprobs: number of top logprobs requested per sampled token.
    """
    from vllm import LLM

    def _collect_logprobs(results) -> list:
        # Flatten the Logprob entries of the first request, in output order,
        # so that the reference run and the spec decode run can be zipped.
        return [
            logprob
            for output in results[0].outputs
            for position_logprobs in output.logprobs
            for logprob in position_logprobs.values()
        ]

    prompt = "Hello world " * 50
    sampling_params = SamplingParams(
        temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
    )
    method, model_name, spec_model_name = model_setup
    max_model_len = 256

    # Run base LLM.
    ref_llm = LLM(
        model=model_name,
        max_logprobs=5,
        max_model_len=max_model_len,
        seed=42,
        logprobs_mode=logprobs_mode,
        gpu_memory_utilization=0.4,
    )
    # Collect logprobs outputs from reference LLM.
    ref_logprobs = _collect_logprobs(ref_llm.generate([prompt], sampling_params))
    del ref_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    # Run spec decode LLM.
    spec_llm = LLM(
        model_name,
        speculative_config={
            "method": method,
            "model": spec_model_name,
            "num_speculative_tokens": 3,
            "max_model_len": max_model_len,
        },
        max_logprobs=5,
        max_model_len=max_model_len,
        seed=42,
        logprobs_mode=logprobs_mode,
        gpu_memory_utilization=0.4,
        # Force prefill chunking
        enable_chunked_prefill=True,
        max_num_batched_tokens=32,
    )
    # Collect logprobs outputs from spec decode LLM.
    spec_logprobs = _collect_logprobs(spec_llm.generate([prompt], sampling_params))
    del spec_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    # Per-token logprobs are expected to be the same.
    assert len(ref_logprobs) == len(spec_logprobs)
    for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
        assert math.isclose(
            ref_logprob.logprob, spec_logprob.logprob, rel_tol=5e-2, abs_tol=1e-1
        )
        assert ref_logprob.rank == spec_logprob.rank
        assert ref_logprob.decoded_token == spec_logprob.decoded_token


def test_prompt_logprobs_with_chunking_and_preemption():
    """Test that prompt logprobs are correctly returned when using
    both chunked prefill and preemption.

    This test ensures that the num_prompt_logprobs tracking persists
    across preemptions and prefill chunks.
    """

    def _num_preemptions(metrics) -> float:
        # Value of the preemption counter, or 0 if it is not reported.
        return next(
            (m.value for m in metrics if m.name == "vllm:num_preemptions"), 0
        )

    # Create prompts that will trigger chunking and preemption
    prompts = [
        "The following numbers of the sequence "
        + ", ".join(str(i) for i in range(10))
        + " are:",
        "In one word, the capital of France is ",
    ] + [f"Tell me about the number {i}: " for i in range(32)]

    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=40,
        min_tokens=20,
        prompt_logprobs=2,  # Request prompt logprobs
    )
    num_prompt_logprobs = sampling_params.prompt_logprobs

    with VllmRunner(
        "Qwen/Qwen3-0.6B",
        max_model_len=512,
        enable_chunked_prefill=True,
        max_num_batched_tokens=48,  # Force prefill chunking
        num_gpu_blocks_override=32,  # Force preemptions
        disable_log_stats=False,
        gpu_memory_utilization=0.25,
    ) as vllm_model:
        metrics_before = vllm_model.llm.get_metrics()

        # With include_prompt_token_ids=True, generate_w_logprobs yields
        # 5-tuples whose last two entries are the prompt token ids and the
        # prompt logprobs.
        outputs = vllm_model.generate_w_logprobs(
            prompts, sampling_params=sampling_params, include_prompt_token_ids=True
        )

        # Verify that all outputs have prompt logprobs
        for i, output in enumerate(outputs):
            _, _, _, prompt_token_ids, prompt_logprobs = output
            assert prompt_logprobs is not None and len(prompt_logprobs) > 0, (
                f"Output {i} missing prompt logprobs"
            )
            assert len(prompt_logprobs) == len(prompt_token_ids), (
                "Unexpected number of prompt logprob positions"
            )

            # Each position should have the requested number of logprobs,
            # plus possibly one extra entry for the actual prompt token
            for pos, logprobs_dict in enumerate(prompt_logprobs):
                if logprobs_dict is not None:  # First token may be None
                    assert (
                        num_prompt_logprobs
                        <= len(logprobs_dict)
                        <= num_prompt_logprobs + 1
                    ), (
                        f"Output {i} position {pos} has {len(logprobs_dict)} "
                        f"logprobs, expected {num_prompt_logprobs} or "
                        f"{num_prompt_logprobs + 1}"
                    )

        # Check that we actually had preemptions
        metrics_after = vllm_model.llm.get_metrics()
        preemptions = _num_preemptions(metrics_after) - _num_preemptions(
            metrics_before
        )
        assert preemptions > 0, "Test did not trigger any preemptions"

        print(f"Test passed with {preemptions} preemptions")