# test_batch_invariance.py (38.4 KB)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import random

import pytest
import torch

from vllm import LLM, SamplingParams
11
12
from vllm.platforms import current_platform

13
14
15
16
17
# Skip marker shared by the GPU tests in this module: a test carrying it is
# skipped unless the current platform is CUDA and reports device
# capability 90 (Hopper / SM90).
hopper_only = pytest.mark.skipif(
    not current_platform.is_cuda() or not current_platform.is_device_capability(90),
    reason="Requires CUDA and Hopper (SM90)",
)

18
19
20
21

@pytest.fixture(autouse=True)
def enable_batch_invariant_mode():
    """Automatically enable batch invariant kernel overrides for all tests.

    Sets ``VLLM_BATCH_INVARIANT=1`` before the test body runs and puts the
    variable back to its previous state (including "unset") afterwards.
    """
    previous = os.environ.get("VLLM_BATCH_INVARIANT")
    os.environ["VLLM_BATCH_INVARIANT"] = "1"
    try:
        yield
    finally:
        # Restore the original value so the override never leaks out of the
        # test, even if the generator is closed abnormally.
        if previous is None:
            os.environ.pop("VLLM_BATCH_INVARIANT", None)
        else:
            os.environ["VLLM_BATCH_INVARIANT"] = previous
30
31
32


def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
    # Generate more realistic prompts that will actually produce varied tokens
    # Use a mix of common English text patterns

    prompt_templates = [
        # Question-answer style
        "Question: What is the capital of France?\nAnswer: The capital of France is",
        "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
        "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
        # Story/narrative style
        "Once upon a time in a distant galaxy, there lived",
        "The old man walked slowly down the street, remembering",
        "In the year 2157, humanity finally discovered",
        # Technical/code style
        "To implement a binary search tree in Python, first we need to",
        "The algorithm works by iterating through the array and",
        "Here's how to optimize database queries using indexing:",
        # Factual/informative style
        "The Renaissance was a period in European history that",
        "Climate change is caused by several factors including",
        "The human brain contains approximately 86 billion neurons which",
        # Conversational style
        "I've been thinking about getting a new laptop because",
        "Yesterday I went to the store and bought",
        "My favorite thing about summer is definitely",
57
58
    ]

59
60
61
    # Pick a random template
    base_prompt = random.choice(prompt_templates)

62
63
64
65
66
    if max_words < min_words:
        max_words = min_words
    target_words = random.randint(min_words, max_words)

    if target_words > 50:
67
68
69
        # For longer prompts, repeat context
        padding_text = (
            " This is an interesting topic that deserves more explanation. "
70
            * (target_words // 50)
71
72
73
74
        )
        base_prompt = base_prompt + padding_text

    return base_prompt
75
76


77
@hopper_only
@pytest.mark.timeout(1000)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
    """
    Ensures that the same request (the 'needle' prompt) yields identical output
    whether run alone or mixed into a larger batch, using the high-level v1
    LLM() API only (no manual batching).

    Strategy:
    - Create two LLM engines through ``LLM_with_max_seqs`` with the same
      arguments (both currently use ``max_num_seqs=max_batch_size``).
    - Compute a baseline output by sending only the needle prompt to the
      first engine.
    - For ``num_trials`` trials, build a batch whose size is drawn from
      ``[max_batch_size // 2, max_batch_size]`` where the needle appears at a
      random position among random filler prompts, and run it on the second
      engine.
    - Track how many trials match vs mismatch, and report totals at the end.
      The test fails if any mismatches occur, but we still dump pass/fail
      counts.

    Notes:
    - Sampling uses a fixed ``seed``; temperature defaults to 0.0 and top_p to
      0.95, both overridable via ``VLLM_NEEDLE_TEMPERATURE`` /
      ``VLLM_NEEDLE_TOP_P``.
    - Keep max_tokens and max_model_len bounded for speed and memory use.
    """
    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)

    # Allow overrides from environment (useful for CI tuning)
    # "facebook/opt-125m" is too small, doesn't reliably test determinism
    model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
    max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
    min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
    max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
    assert max_batch_size >= 2, "Batch size should be >= 2 to mix needle."

    # Keep GPU memory usage low to avoid startup allocation failures.
    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))

    # Sampling parameters; the fixed seed keeps any stochastic sampling
    # reproducible when temperature is overridden to a non-zero value.
    temperature = float(os.getenv("VLLM_NEEDLE_TEMPERATURE", "0.0"))
    top_p = float(os.getenv("VLLM_NEEDLE_TOP_P", "0.95"))
    max_tokens = int(os.getenv("VLLM_NEEDLE_MAX_TOKENS", "128"))

    sampling = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        seed=20240919,
    )

    needle_prompt = "There once was a "

    llm_bs1 = None
    llm_bsN = None
    try:
        # Baseline engine: only ever fed the single needle prompt.
        # NOTE(review): it is built with max_num_seqs=max_batch_size, not 1;
        # confirm this is intentional.
        llm_bs1 = LLM_with_max_seqs(
            model=model,
            max_num_seqs=max_batch_size,
            gpu_memory_utilization=gpu_mem_util,
            max_model_len=max_model_len,
        )

        # Baseline generation for the needle prompt alone.
        baseline_out = llm_bs1.generate([needle_prompt], sampling)
        assert len(baseline_out) == 1
        assert len(baseline_out[0].outputs) >= 1
        baseline_text = baseline_out[0].outputs[0].text

        # Engine that receives the mixed batches.
        llm_bsN = LLM_with_max_seqs(
            model=model,
            max_num_seqs=max_batch_size,
            gpu_memory_utilization=gpu_mem_util,
            max_model_len=max_model_len,
        )

        mismatches = 0

        for _ in range(num_trials):
            # Draw a batch size in [max_batch_size // 2, max_batch_size] and
            # insert the needle at a random index; every other slot gets a
            # random filler prompt (RNG draws happen in slot order).
            batch_size = random.randint(max_batch_size // 2, max_batch_size)
            needle_pos = random.randint(0, batch_size - 1)
            prompts: list[str] = [
                needle_prompt
                if i == needle_pos
                else _random_prompt(min_random_prompt, max_random_prompt)
                for i in range(batch_size)
            ]

            # Generate with the larger-batch engine
            outputs = llm_bsN.generate(prompts, sampling)
            # Find the needle output by position
            needle_output = outputs[needle_pos]
            assert needle_output.prompt == needle_prompt
            assert len(needle_output.outputs) >= 1
            text = needle_output.outputs[0].text

            if text != baseline_text:
                print(f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
                mismatches += 1

        passes = num_trials - mismatches
        # Dump how many passed vs failed
        print(
            f"[determinism] total={num_trials}, passed={passes}, "
            f"failed={mismatches}, max_batch_size={max_batch_size}"
        )

        if mismatches > 0:
            pytest.fail(
                f"Nondeterministic outputs detected: {mismatches} failed out "
                f"of {num_trials} trials (max_batch_size={max_batch_size})."
            )

    finally:
        # Ensure engines are shutdown to free GPU/VRAM across test sessions.
        # Shutdown is best-effort: errors are suppressed so they cannot mask
        # the real test outcome.
        if llm_bs1 is not None:
            with contextlib.suppress(Exception):
                llm_bs1.shutdown()
        if llm_bsN is not None:
            with contextlib.suppress(Exception):
                llm_bsN.shutdown()


def _extract_step_logprobs(request_output):
    if getattr(request_output, "outputs", None):
        inner = request_output.outputs[0]
        if hasattr(inner, "logprobs") and inner.logprobs is not None:
            t = torch.tensor(
                [
                    inner.logprobs[i][tid].logprob
                    for i, tid in enumerate(inner.token_ids)
                ],
                dtype=torch.float32,
            )
217
            return t, inner.token_ids
218

219
    return None, None
220
221


222
@hopper_only
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
@pytest.mark.forked
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
    """
    Check that per-token logprobs are bitwise identical between running each
    prompt on its own and running all prompts in one ``generate`` call.

    The same engine and the same ``SamplingParams`` are used for both passes.
    A prompt is recorded as failing when the number of steps, the sampled
    token ids, or any per-step logprob differs; all failures are printed and
    the test fails with a summary if there is at least one.

    The ``backend`` parameter is overridden by ``VLLM_ATTENTION_BACKEND`` when
    that variable is already set, and the chosen value is exported through the
    same variable.
    """
    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
    os.environ["VLLM_ATTENTION_BACKEND"] = backend

    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)
    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

    # For batch invariance, disable custom all-reduce to ensure deterministic
    # all-reduce operations (custom all-reduce may not be deterministic)
    from vllm.model_executor.layers.batch_invariant import (
        vllm_is_batch_invariant,
    )

    # NOTE(review): this flag only drives the banner below; it is not
    # forwarded to LLM() in this test - confirm the engine disables custom
    # all-reduce on its own in batch-invariant mode.
    disable_custom_ar = vllm_is_batch_invariant()

    if disable_custom_ar:
        print(f"\n{'=' * 80}")
        print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})")
        print(f"{'=' * 80}\n")

    llm = LLM(
        model=model_name,
        tensor_parallel_size=tp_size,
        enable_prefix_caching=False,
        max_num_seqs=32,
        max_model_len=8192,
        dtype="bfloat16",  # not everything is supported
    )

    # Use more realistic prompts for better token generation
    prompts = [_random_prompt(10, 50) for _ in range(32)]

    sp = SamplingParams(
        temperature=0.6,
        top_p=1.0,
        max_tokens=8,
        seed=1234,
        logprobs=5,
    )

    # BS=1: run prompts individually and collect logprobs per step.
    print("\n" + "=" * 80)
    print("STARTING BS=1 RUNS (each prompt individually)")
    print("=" * 80 + "\n")

    bs1_logprobs_per_prompt = []
    bs1_tokens_per_prompt = []
    for idx, p in enumerate(prompts):
        print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...")
        outs = llm.generate([p], sp, use_tqdm=False)
        assert len(outs) == 1
        step_logprobs, token_ids = _extract_step_logprobs(outs[0])
        if step_logprobs is None:
            pytest.skip(
                "Logits are not available on RequestOutput; "
                "enable logprobs return to run this test."
            )
        bs1_logprobs_per_prompt.append(step_logprobs)
        bs1_tokens_per_prompt.append(token_ids)
        print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")

    # BS=N: run prompts in a batch and collect logprobs per step for each
    # prompt.
    print("\n" + "=" * 80)
    print(f"STARTING BS={len(prompts)} RUN (all prompts batched)")
    print("=" * 80 + "\n")

    outs_batched = llm.generate(prompts, sp, use_tqdm=False)
    assert len(outs_batched) == len(prompts)
    bsN_logprobs_per_prompt = []
    bsN_tokens_per_prompt = []

    print(f"\n[BS={len(prompts)}] Processing batched outputs...")
    for idx, o in enumerate(outs_batched):
        tokens = o.outputs[0].token_ids if o.outputs else "N/A"
        print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
        step_logprobs, token_ids = _extract_step_logprobs(o)
        if step_logprobs is None:
            pytest.skip(
                "Logits are not available on RequestOutput; "
                "enable logprobs return to run this test."
            )
        bsN_logprobs_per_prompt.append(step_logprobs)
        bsN_tokens_per_prompt.append(token_ids)

    # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
    # At most one failure record is appended per prompt.
    failed_prompts = []
    for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
        zip(
            bs1_logprobs_per_prompt,
            bsN_logprobs_per_prompt,
            bs1_tokens_per_prompt,
            bsN_tokens_per_prompt,
        )
    ):
        if len(logprobs_bs1) != len(logprobs_bsN):
            reason = (
                f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
                f"vs {len(logprobs_bsN)} (BS=N)"
            )
            failed_prompts.append(
                {
                    "prompt_idx": i,
                    "step": "all",
                    "reason": reason,
                    "prompt_preview": prompts[i][:100],
                    "bs1_tokens": tokens_bs1,
                    "bsN_tokens": tokens_bsN,
                }
            )
            continue

        # Check if tokens match first
        if tokens_bs1 != tokens_bsN:
            failed_prompts.append(
                {
                    "prompt_idx": i,
                    "step": "sampling",
                    "reason": "Different tokens sampled",
                    "prompt_preview": prompts[i][:100],
                    "bs1_tokens": tokens_bs1,
                    "bsN_tokens": tokens_bsN,
                    "bs1_all_logprobs": logprobs_bs1.tolist(),
                    "bsN_all_logprobs": logprobs_bsN.tolist(),
                }
            )
            continue

        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
            if a.shape != b.shape:
                failed_prompts.append(
                    {
                        "prompt_idx": i,
                        "step": t,
                        "reason": f"Shape mismatch: {a.shape} vs {b.shape}",
                        "prompt_preview": prompts[i][:100],
                        "bs1_tokens": tokens_bs1,
                        "bsN_tokens": tokens_bsN,
                    }
                )
                break

            # torch.equal is an exact comparison: any difference counts.
            if not torch.equal(a, b):
                max_diff = torch.abs(a - b).max().item()
                # Print which token failed
                print(f"\n[DIVERGENCE] Prompt {i}, Token {t}: max_diff={max_diff:.6e}")
                bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
                bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
                print(f"  Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
                print(f"  BS=1 logprob: {a.tolist()}")
                print(f"  BS=N logprob: {b.tolist()}")
                failed_prompts.append(
                    {
                        "prompt_idx": i,
                        "step": t,
                        "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
                        "prompt_preview": prompts[i][:100],
                        "bs1_tokens": tokens_bs1,
                        "bsN_tokens": tokens_bsN,
                        "bs1_all_logprobs": logprobs_bs1.tolist(),
                        "bsN_all_logprobs": logprobs_bsN.tolist(),
                    }
                )
                break

    # Print summary of all failures
    if failed_prompts:
        print(f"\n{'=' * 80}")
        fail_msg = (
            f"BATCH INVARIANCE FAILURES: {len(failed_prompts)}/"
            f"{len(prompts)} prompts failed"
        )
        print(fail_msg)
        print(f"{'=' * 80}")
        for fail in failed_prompts:
            print(f"\nPrompt {fail['prompt_idx']} (step {fail['step']}):")
            print(f"  Reason: {fail['reason']}")
            print(f"  Preview: {fail['prompt_preview']}...")

            # Always show the tokens
            if "bs1_tokens" in fail:
                print(f"  BS=1 tokens: {fail['bs1_tokens']}")
            if "bsN_tokens" in fail:
                print(f"  BS=N tokens: {fail['bsN_tokens']}")

            if "bs1_all_logprobs" in fail:
                print(f"  BS=1 logprobs for all {len(fail['bs1_all_logprobs'])} steps:")
                for step_idx, logprobs in enumerate(fail["bs1_all_logprobs"]):
                    print(f"    Step {step_idx}: {logprobs}")
                print(f"  BS=N logprobs for all {len(fail['bsN_all_logprobs'])} steps:")
                for step_idx, logprobs in enumerate(fail["bsN_all_logprobs"]):
                    print(f"    Step {step_idx}: {logprobs}")
        print(f"{'=' * 80}\n")

        # Fail the test with summary
        msg = (
            f"Batch invariance violated in {len(failed_prompts)}/"
            f"{len(prompts)} prompts. See output above for details."
        )
        pytest.fail(msg)


437
@hopper_only
def test_simple_generation():
    """
    Simple test that runs the model with a basic prompt and prints the output.
    Useful for quick smoke testing and debugging.

    The only assertion is that exactly one output is returned for the single
    prompt.  The model comes from ``VLLM_TEST_MODEL``; the tensor-parallel
    size comes from ``VLLM_TEST_TP_SIZE`` (the variable the other tests in
    this module read), falling back to the older ``VLLM_TP_SIZE`` and then 1.
    """
    model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    # Prefer the module-wide variable name, but keep honouring the name this
    # test originally read so existing setups keep working.
    tp_size = int(os.getenv("VLLM_TEST_TP_SIZE") or os.getenv("VLLM_TP_SIZE", "1"))

    llm = LLM(
        model=model,
        max_num_seqs=1,
        tensor_parallel_size=tp_size,
        enforce_eager=True,
        gpu_memory_utilization=0.9,
        max_model_len=2048,
        dtype="bfloat16",
        enable_prefix_caching=False,
    )

    prompt = "the capital of france is"
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=20,
    )

    print(f"\n{'=' * 80}")
    print("Running simple generation test")
    print(f"Prompt: '{prompt}'")
    print(f"{'=' * 80}\n")

    try:
        outputs = llm.generate([prompt], sampling_params)

        assert len(outputs) == 1
        output_text = outputs[0].outputs[0].text

        print(f"Output: '{output_text}'")
        print(f"\n{'=' * 80}")
        print(f"Full completion: '{prompt}{output_text}'")
        print(f"{'=' * 80}\n")

    finally:
        # Best-effort shutdown; errors here must not mask the test result.
        with contextlib.suppress(Exception):
            llm.shutdown()


483
@hopper_only
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
@pytest.mark.forked
def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
    """
    This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
    It DISABLES batch invariance mode and expects to see non-deterministic behavior
    between BS=1 and BS=N runs. This demonstrates that batch invariance is actually
    doing something useful.

    The test will PASS if we detect differences (proving batch invariance matters).
    The test will FAIL if everything matches (suggesting batch invariance isn't needed).

    ``VLLM_BATCH_INVARIANT`` is set to "0" for the duration of the test and
    restored to its previous state afterwards.
    """
    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
    os.environ["VLLM_ATTENTION_BACKEND"] = backend

    # CRITICAL: Disable batch invariance for this test
    old_value = os.environ.get("VLLM_BATCH_INVARIANT")
    os.environ["VLLM_BATCH_INVARIANT"] = "0"

    try:
        seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
        random.seed(seed)
        model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
        tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

        print(f"\n{'=' * 80}")
        print("BATCH INVARIANCE DISABLED: Expecting non-deterministic behavior")
        print(f"{'=' * 80}\n")

        llm = LLM(
            model=model_name,
            tensor_parallel_size=tp_size,
            enable_prefix_caching=False,
            max_num_seqs=32,
            max_model_len=8192,
            dtype="bfloat16",
        )

        # build ragged prompts to change shapes significantly across BS=1 vs BS=N
        long_min = int(os.getenv("VLLM_MIN_PROMPT", "768"))
        long_max = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
        prompts: list[str] = []
        options = [
            (max(long_min, 1536), max(long_max, 3072)),  # very long
            (max(1024, long_min), max(2048, long_max)),  # long
            (256, 512),  # mid
            (10, 20),  # short
        ]

        # Each prompt draws its own size bucket, so the batch mixes lengths.
        for _ in range(32):
            lo, hi = random.choice(options)
            prompts.append(_random_prompt(lo, hi))

        sp = SamplingParams(
            temperature=0.6,
            top_p=1.0,
            max_tokens=8,
            seed=1234,
            logprobs=5,
        )

        # BS=1: run prompts individually and collect logprobs per step.
        print("\n" + "=" * 80)
        print("STARTING BS=1 RUNS (each prompt individually)")
        print("=" * 80 + "\n")

        bs1_logprobs_per_prompt = []
        bs1_tokens_per_prompt = []
        for idx, p in enumerate(prompts):
            print(
                f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
            )
            outs = llm.generate([p], sp, use_tqdm=False)
            assert len(outs) == 1
            step_logprobs, token_ids = _extract_step_logprobs(outs[0])
            if step_logprobs is None:
                pytest.skip(
                    "Logits are not available on RequestOutput; "
                    "enable logprobs return to run this test."
                )
            bs1_logprobs_per_prompt.append(step_logprobs)
            bs1_tokens_per_prompt.append(token_ids)
            print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")

        # BS=N: run prompts in a batch and collect logprobs per step for each prompt.
        print("\n" + "=" * 80)
        print(f"STARTING BS={len(prompts)} RUN (all prompts batched)")
        print("=" * 80 + "\n")

        outs_batched = llm.generate(prompts, sp, use_tqdm=False)
        assert len(outs_batched) == len(prompts)
        bsN_logprobs_per_prompt = []
        bsN_tokens_per_prompt = []

        print(f"\n[BS={len(prompts)}] Processing batched outputs...")
        for idx, o in enumerate(outs_batched):
            tokens = o.outputs[0].token_ids if o.outputs else "N/A"
            print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
            step_logprobs, token_ids = _extract_step_logprobs(o)
            if step_logprobs is None:
                pytest.skip(
                    "Logits are not available on RequestOutput; "
                    "enable logprobs return to run this test."
                )
            bsN_logprobs_per_prompt.append(step_logprobs)
            bsN_tokens_per_prompt.append(token_ids)

        # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
        # At most one difference record is appended per prompt.
        differences_found = []
        for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
            zip(
                bs1_logprobs_per_prompt,
                bsN_logprobs_per_prompt,
                bs1_tokens_per_prompt,
                bsN_tokens_per_prompt,
            )
        ):
            if len(logprobs_bs1) != len(logprobs_bsN):
                reason = (
                    f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
                    f"vs {len(logprobs_bsN)} (BS=N)"
                )
                differences_found.append(
                    {
                        "prompt_idx": i,
                        "step": "all",
                        "reason": reason,
                        "prompt_preview": prompts[i][:100],
                        "bs1_tokens": tokens_bs1,
                        "bsN_tokens": tokens_bsN,
                    }
                )
                continue

            # Check if tokens match first
            if tokens_bs1 != tokens_bsN:
                differences_found.append(
                    {
                        "prompt_idx": i,
                        "step": "sampling",
                        "reason": "Different tokens sampled",
                        "prompt_preview": prompts[i][:100],
                        "bs1_tokens": tokens_bs1,
                        "bsN_tokens": tokens_bsN,
                    }
                )
                continue

            for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
                if a.shape != b.shape:
                    differences_found.append(
                        {
                            "prompt_idx": i,
                            "step": t,
                            "reason": f"Shape mismatch: {a.shape} vs {b.shape}",
                            "prompt_preview": prompts[i][:100],
                            "bs1_tokens": tokens_bs1,
                            "bsN_tokens": tokens_bsN,
                        }
                    )
                    break

                # torch.equal is an exact comparison: any difference counts.
                if not torch.equal(a, b):
                    max_diff = torch.abs(a - b).max().item()
                    print(
                        f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, "
                        f"Token {t}: max_diff={max_diff:.6e}"
                    )
                    bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
                    bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
                    print(f"  Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
                    print(f"  BS=1 logprob: {a.tolist()}")
                    print(f"  BS=N logprob: {b.tolist()}")
                    differences_found.append(
                        {
                            "prompt_idx": i,
                            "step": t,
                            "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
                            "prompt_preview": prompts[i][:100],
                            "bs1_tokens": tokens_bs1,
                            "bsN_tokens": tokens_bsN,
                        }
                    )
                    break

        # Print summary
        print(f"\n{'=' * 80}")
        if differences_found:
            success_msg = (
                f"✓ SUCCESS: Batch invariance is doing something! "
                f"Found {len(differences_found)}/{len(prompts)} prompts "
                f"with differences when batch invariance was DISABLED."
            )
            print(success_msg)
            print(f"{'=' * 80}")
            for diff in differences_found:
                print(f"\nPrompt {diff['prompt_idx']} (step {diff['step']}):")
                print(f"  Reason: {diff['reason']}")
                print(f"  Preview: {diff['prompt_preview']}...")
                if "bs1_tokens" in diff:
                    print(f"  BS=1 tokens: {diff['bs1_tokens']}")
                if "bsN_tokens" in diff:
                    print(f"  BS=N tokens: {diff['bsN_tokens']}")
            print(f"{'=' * 80}\n")
            # Test PASSES because we found differences (batch invariance matters!)
            return
        else:
            # Test FAILS because everything matched even without batch invariance
            fail_msg = (
                f"✗ UNEXPECTED: All {len(prompts)} prompts matched "
                f"between BS=1 and BS=N even with batch invariance DISABLED. "
                f"This suggests batch invariance might not be necessary, "
                f"or the test needs more sensitive prompts."
            )
            print(fail_msg)
            print(f"{'=' * 80}\n")
            pytest.fail(fail_msg)

    finally:
        # Restore original value
        if old_value is None:
            os.environ.pop("VLLM_BATCH_INVARIANT", None)
        else:
            os.environ["VLLM_BATCH_INVARIANT"] = old_value
708
709


710
def _report_decode_prefill_failures(failed_comparisons: list[dict]) -> None:
    """
    Print a per-prompt breakdown of decode-vs-prefill mismatches.

    Args:
        failed_comparisons: Mismatch records as built by
            test_decode_logprobs_match_prefill_logprobs. Every record carries
            "prompt_idx", "token_idx", "reason", "decode_token",
            "prefill_token", "decode_logprob", "prefill_logprob",
            "prompt_text" and "prefix_text". Logprob mismatches additionally
            carry "diff", "decode_all_tokens" and "decode_all_logprobs".
    """
    # Only this diagnostic dump needs struct, so the import stays local.
    import struct

    print(f"DECODE-PREFILL MISMATCH: {len(failed_comparisons)} failures detected")
    print(f"{'=' * 80}")

    # Group failures by prompt for better readability
    failures_by_prompt: dict[int, list[dict]] = {}
    for fail in failed_comparisons:
        failures_by_prompt.setdefault(fail["prompt_idx"], []).append(fail)

    for prompt_idx, failures in failures_by_prompt.items():
        print(f"\n{'=' * 80}")
        print(f"PROMPT {prompt_idx}: {failures[0]['prompt_text']}...")
        print(f"{'=' * 80}")
        print(f"Total failures for this prompt: {len(failures)}")

        # Show where mismatches occur (which token positions)
        mismatch_positions = [f["token_idx"] for f in failures]
        print(f"Mismatch at token positions: {mismatch_positions}")

        # Show only the first 5 failures per prompt in detail
        for i, fail in enumerate(failures[:5], start=1):
            print(f"\n  [Failure {i}] Token position {fail['token_idx']}:")
            print(f"    Reason: {fail['reason']}")
            print(f"    Prefix text: '{fail['prefix_text']}...'")
            print(
                f"    Decode:  token={fail['decode_token']}, "
                f"logprob={fail['decode_logprob']:.10f}"
            )
            print(
                f"    Prefill: token={fail['prefill_token']}, "
                f"logprob={fail['prefill_logprob']:.10f}"
            )
            if "diff" in fail:
                print(f"    Difference: {fail['diff']:.10e}")
                # Show the float32 byte pattern to expose the bitwise difference
                decode_hex = struct.pack("f", fail["decode_logprob"]).hex()
                prefill_hex = struct.pack("f", fail["prefill_logprob"]).hex()
                print(f"    Decode logprob (hex):  0x{decode_hex}")
                print(f"    Prefill logprob (hex): 0x{prefill_hex}")

            # If we have all tokens/logprobs, show the context
            if "decode_all_tokens" in fail and "decode_all_logprobs" in fail:
                token_idx = fail["token_idx"]
                all_tokens = fail["decode_all_tokens"]
                all_logprobs = fail["decode_all_logprobs"]

                # Show context: 2 tokens before and after
                start = max(0, token_idx - 2)
                end = min(len(all_tokens), token_idx + 3)

                print(f"    Context (tokens {start} to {end - 1}):")
                for j in range(start, end):
                    marker = " <-- MISMATCH" if j == token_idx else ""
                    print(
                        f"      [{j}] token={all_tokens[j]}, "
                        f"logprob={all_logprobs[j]:.8f}{marker}"
                    )

        if len(failures) > 5:
            print(f"\n  ... and {len(failures) - 5} more failures for this prompt")

    print(f"\n{'=' * 80}\n")


@hopper_only
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
@pytest.mark.forked
def test_decode_logprobs_match_prefill_logprobs(backend):
    """
    Test that verifies decode logprobs match prefill logprobs.

    1. Run decode on each prompt to generate N tokens and collect the logprob
       of every generated token.
    2. For each position i in [0, N):
       - Build a text prefix: the prompt itself for i == 0, otherwise the
         prompt plus the text of a fresh greedy generation of i tokens.
       - Run a one-token generation on that prefix, so that the token at
         position i is produced directly after prefilling the prefix.
       - Verify that this token and its logprob equal the decode token and
         logprob at position i (exact float equality).

    All mismatches are collected and reported together; the test fails if any
    were found. The test is skipped if the decode output exposes no logprobs.

    Environment overrides: VLLM_ATTENTION_BACKEND, VLLM_TEST_SEED,
    VLLM_TEST_MODEL, VLLM_TEST_TP_SIZE, VLLM_DECODE_PREFILL_NUM_PROMPTS,
    VLLM_DECODE_PREFILL_MAX_TOKENS.
    """
    # An already-exported backend wins over the parametrized default.
    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
    os.environ["VLLM_ATTENTION_BACKEND"] = backend

    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)
    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

    from vllm.model_executor.layers.batch_invariant import (
        vllm_is_batch_invariant,
    )

    # NOTE(review): this flag is only reported here; nothing all-reduce related
    # is passed to LLM(...) below. Presumably vLLM applies it internally when
    # batch-invariant mode is on -- confirm.
    disable_custom_ar = vllm_is_batch_invariant()

    if disable_custom_ar:
        print(f"\n{'=' * 80}")
        print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})")
        print(f"{'=' * 80}\n")

    llm = LLM(
        model=model_name,
        tensor_parallel_size=tp_size,
        enable_prefix_caching=False,
        max_num_seqs=32,
        max_model_len=8192,
        dtype="bfloat16",
    )

    # Use a few test prompts
    num_test_prompts = int(os.getenv("VLLM_DECODE_PREFILL_NUM_PROMPTS", "4"))
    prompts = [_random_prompt(10, 50) for _ in range(num_test_prompts)]

    # Generate longer sequences to test multiple decode steps
    max_tokens = int(os.getenv("VLLM_DECODE_PREFILL_MAX_TOKENS", "16"))

    sp = SamplingParams(
        temperature=0.0,  # Greedy for determinism
        max_tokens=max_tokens,
        logprobs=5,
    )

    print("\n" + "=" * 80)
    print("STEP 1: Running decode to generate tokens and collect logprobs")
    print("=" * 80 + "\n")

    # Step 1: Run decode and collect logprobs
    decode_outputs = llm.generate(prompts, sp, use_tqdm=False)

    failed_comparisons: list[dict] = []

    for prompt_idx, (prompt, decode_output) in enumerate(zip(prompts, decode_outputs)):
        print(f"\n[Prompt {prompt_idx}] Testing: {prompt[:80]}...")

        # Extract decode logprobs and tokens
        decode_logprobs, token_ids = _extract_step_logprobs(decode_output)
        if decode_logprobs is None:
            pytest.skip(
                "Logprobs are not available on RequestOutput; "
                "enable logprobs return to run this test."
            )

        print(f"[Prompt {prompt_idx}] Generated {len(token_ids)} tokens: {token_ids}")
        print(f"[Prompt {prompt_idx}] Decode logprobs: {decode_logprobs.tolist()}")

        # Step 2: For each token position, run prefill and compare
        print(f"\n[Prompt {prompt_idx}] Verifying each token via prefill...")

        for token_idx, current_token in enumerate(token_ids):
            if token_idx == 0:
                prefix_prompt = prompt
            else:
                # The output carries no per-token text, so the text of the
                # first token_idx tokens is obtained by re-running a greedy
                # generation capped at token_idx tokens. This is approximate
                # but should work for verification.
                prefix_sp = SamplingParams(
                    temperature=0.0,
                    max_tokens=token_idx,
                    logprobs=1,
                )
                prefix_output = llm.generate([prompt], prefix_sp, use_tqdm=False)[0]
                prefix_prompt = prompt + prefix_output.outputs[0].text

            # Now run prefill with max_tokens=1 to get the logprob of the next token
            prefill_sp = SamplingParams(
                temperature=0.0,
                max_tokens=1,
                logprobs=5,
            )

            print(
                f"  [Token {token_idx}] Running prefill for prefix "
                f"(len={len(prefix_prompt)})..."
            )
            prefill_output = llm.generate([prefix_prompt], prefill_sp, use_tqdm=False)[
                0
            ]
            prefill_logprobs, prefill_token_ids = _extract_step_logprobs(prefill_output)

            if prefill_logprobs is None:
                print(f"  [Token {token_idx}] Warning: No prefill logprobs available")
                continue

            # The first token from prefill should match the current token
            prefill_token = prefill_token_ids[0]
            prefill_logprob = prefill_logprobs[0].item()
            decode_logprob = decode_logprobs[token_idx].item()

            print(
                f"  [Token {token_idx}] Decode token: {current_token}, "
                f"logprob: {decode_logprob:.8f}"
            )
            print(
                f"  [Token {token_idx}] Prefill token: {prefill_token}, "
                f"logprob: {prefill_logprob:.8f}"
            )

            # Fields shared by both kinds of mismatch record
            mismatch = {
                "prompt_idx": prompt_idx,
                "token_idx": token_idx,
                "decode_token": current_token,
                "prefill_token": prefill_token,
                "decode_logprob": decode_logprob,
                "prefill_logprob": prefill_logprob,
                "prompt_text": prompt[:100],
                "prefix_text": prefix_prompt[:100],
            }

            # Check if tokens match; logprobs of different tokens are not
            # comparable, so skip the logprob check on a token mismatch
            if current_token != prefill_token:
                failed_comparisons.append({**mismatch, "reason": "Token mismatch"})
                print(f"  [Token {token_idx}] ✗ TOKEN MISMATCH!")
                continue

            # Check if logprobs match bitwise
            if decode_logprob != prefill_logprob:
                diff = abs(decode_logprob - prefill_logprob)
                failed_comparisons.append(
                    {
                        **mismatch,
                        "reason": "Logprob mismatch",
                        "diff": diff,
                        "decode_all_tokens": token_ids,
                        "decode_all_logprobs": decode_logprobs.tolist(),
                    }
                )
                print(f"  [Token {token_idx}] ✗ LOGPROB MISMATCH! diff={diff:.8e}")
            else:
                print(f"  [Token {token_idx}] ✓ Match (bitwise equal)")

    # Print summary
    print(f"\n{'=' * 80}")
    if not failed_comparisons:
        print("✓ SUCCESS: All decode logprobs match prefill logprobs bitwise!")
        print(f"{'=' * 80}\n")
        return

    _report_decode_prefill_failures(failed_comparisons)

    pytest.fail(
        f"Decode logprobs do not match prefill logprobs: "
        f"{len(failed_comparisons)} mismatches found."
    )
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993


def LLM_with_max_seqs(
    model: str,
    max_num_seqs: int,
    gpu_memory_utilization: float,
    max_model_len: int,
) -> LLM:
    """
    Build an LLM through the high-level API with a capped batch size.

    The caller chooses the model, the max_num_seqs batch-size limit, the GPU
    memory fraction and the maximum model length. The remaining engine
    options are fixed here: bfloat16 weights, prefix caching off and eager
    execution enforced. The tensor-parallel size is read from the VLLM_TP_SIZE
    environment variable and defaults to 1.
    """
    # NOTE(review): this helper reads VLLM_TP_SIZE, whereas
    # test_decode_logprobs_match_prefill_logprobs reads VLLM_TEST_TP_SIZE --
    # confirm the two different variable names are intentional.
    tp_size = int(os.getenv("VLLM_TP_SIZE", "1"))

    return LLM(
        # Caller-controlled settings
        model=model,
        max_num_seqs=max_num_seqs,
        gpu_memory_utilization=gpu_memory_utilization,
        max_model_len=max_model_len,
        # Fixed settings
        dtype="bfloat16",
        tensor_parallel_size=tp_size,
        enable_prefix_caching=False,
        enforce_eager=True,
        # Enable for MOE models
        # enable_expert_parallel=True,
    )