test_batch_invariance.py 10.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import random
import string

import pytest
import torch

from vllm import LLM, SamplingParams


def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
    """Return a pseudo-random filler prompt built from NATO-alphabet words.

    The word count is drawn with ``random.randint(min_words, max_words)``
    (both bounds inclusive). There is a 50% chance the first word is
    capitalized. There is a 20% chance a random 5-letter lowercase token is
    appended. The prompt ends with one of ".", "?", "!", "..." or nothing.
    All randomness comes from the global ``random`` module, so the output is
    reproducible after ``random.seed``.
    """
    nato_words = ("alpha bravo charlie delta echo foxtrot golf hotel india "
                  "juliet kilo lima mike november oscar papa quebec romeo "
                  "sierra tango uniform victor whiskey xray yankee "
                  "zulu").split()

    word_count = random.randint(min_words, max_words)
    tokens = random.choices(nato_words, k=word_count)

    # Light noise so prompts differ in more than just length.
    capitalize_first = random.random() < 0.5
    if capitalize_first:
        tokens[0] = tokens[0].capitalize()
    add_junk_token = random.random() < 0.2
    if add_junk_token:
        junk = "".join(random.choices(string.ascii_lowercase, k=5))
        tokens.append(junk)

    ending = random.choice([".", "?", "!", "...", ""])
    return " ".join(tokens) + ending


@pytest.mark.timeout(1000)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
    """
    Ensures that the same request (the 'needle' prompt) yields identical output
    whether run alone (bs=1) or mixed into a larger batch (up to N =
    VLLM_NEEDLE_BATCH_SIZE, default 128), using the high-level v1 LLM() API
    only (no manual batching).

    Strategy:
    - Create two LLM engines with identical config except max_num_seqs:
      1 vs N.
    - Compute a baseline output for the needle prompt with the bs=1 engine.
    - For each trial, build a batch whose size is drawn from [N // 2, N],
      insert the needle at a random position among random filler prompts,
      and generate it with the bs=N engine.
    - Track how many trials match vs mismatch and report totals at the end.
      The test fails if any mismatch occurs, but pass/fail counts are always
      printed.

    Notes:
    - SamplingParams uses a fixed seed. The default temperature is 0.0
      (greedy); set VLLM_NEEDLE_TEMPERATURE > 0 to exercise seeded stochastic
      sampling instead.
    - max_tokens and max_model_len are kept bounded for speed and memory use.
    """
    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)

    # Allow overrides from environment (useful for CI tuning)
    # "facebook/opt-125m" is too small, doesn't reliably test determinism
    model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
    max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
    min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
    max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
    assert max_batch_size >= 2, "Batch size should be >= 2 to mix needle."

    # Keep GPU memory usage low to avoid startup allocation failures.
    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))
    swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))

    # Sampling parameters. The fixed seed keeps output reproducible even when
    # the temperature is raised above the greedy default.
    temperature = float(os.getenv("VLLM_NEEDLE_TEMPERATURE", "0.0"))
    top_p = float(os.getenv("VLLM_NEEDLE_TOP_P", "0.95"))
    max_tokens = int(os.getenv("VLLM_NEEDLE_MAX_TOKENS", "128"))

    sampling = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        seed=20240919,
    )

    needle_prompt = "There once was a "

    llm_bs1 = None
    llm_bsN = None
    try:
        # Engine with bs=1 behavior: the scheduler may run only one sequence
        # at a time.
        llm_bs1 = LLM_with_max_seqs(
            model=model,
            max_num_seqs=1,
            gpu_memory_utilization=gpu_mem_util,
            max_model_len=max_model_len,
            swap_space=swap_space_gb,
        )

        # Baseline generation for the needle prompt alone.
        baseline_out = llm_bs1.generate([needle_prompt], sampling)
        assert len(baseline_out) == 1
        assert len(baseline_out[0].outputs) >= 1
        baseline_text = baseline_out[0].outputs[0].text

        # Engine whose batch limit matches the largest batch we submit.
        llm_bsN = LLM_with_max_seqs(
            model=model,
            max_num_seqs=max_batch_size,
            gpu_memory_utilization=gpu_mem_util,
            max_model_len=max_model_len,
            swap_space=swap_space_gb,
        )

        mismatches = 0

        for _trial in range(num_trials):
            # Build a batch of random size in [max_batch_size // 2,
            # max_batch_size] with the needle at a random index.
            prompts: list[str] = []
            batch_size = random.randint(max_batch_size // 2, max_batch_size)
            needle_pos = random.randint(0, batch_size - 1)
            for i in range(batch_size):
                if i == needle_pos:
                    prompts.append(needle_prompt)
                else:
                    prompts.append(
                        _random_prompt(min_random_prompt, max_random_prompt))

            # Generate with the larger-batch engine
            outputs = llm_bsN.generate(prompts, sampling)
            # Outputs are returned in prompt order; find the needle by index.
            needle_output = outputs[needle_pos]
            assert needle_output.prompt == needle_prompt
            assert len(needle_output.outputs) >= 1
            text = needle_output.outputs[0].text

            if text != baseline_text:
                print(
                    f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
                mismatches += 1

        passes = num_trials - mismatches
        # Dump how many passed vs failed
        print(f"[determinism] total={num_trials}, passed={passes}, "
              f"failed={mismatches}, max_batch_size={max_batch_size}")

        if mismatches > 0:
            pytest.fail(
                f"Nondeterministic outputs detected: {mismatches} failed out "
                f"of {num_trials} trials (max_batch_size={max_batch_size}).")

    finally:
        # Ensure engines are shutdown to free GPU/VRAM across test sessions
        if llm_bs1 is not None:
            with contextlib.suppress(Exception):
                llm_bs1.shutdown()
        if llm_bsN is not None:
            with contextlib.suppress(Exception):
                llm_bsN.shutdown()


def _extract_step_logprobs(request_output):
    if getattr(request_output, "outputs", None):
        inner = request_output.outputs[0]
        if hasattr(inner, "logprobs") and inner.logprobs is not None:
            t = torch.tensor(
                [
                    inner.logprobs[i][tid].logprob
                    for i, tid in enumerate(inner.token_ids)
                ],
                dtype=torch.float32,
            )
            return t

    return None


@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="Requires CUDA to match production inference path.",
)
@pytest.mark.parametrize("backend", ["FLEX_ATTENTION", "FLASHINFER"])
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
    """
    Checks that per-step logprobs are bitwise identical whether each prompt is
    generated alone (BS=1) or all prompts are generated in one batch (BS=N).

    The attention backend is taken from VLLM_ATTENTION_BACKEND when it is
    already set, otherwise from the parametrized ``backend``. On exit, the
    previous value of VLLM_ATTENTION_BACKEND is restored and the engine is shut
    down. This keeps parametrized runs and later tests in the session from
    inheriting this test's backend or GPU memory.
    """
    # An explicit VLLM_ATTENTION_BACKEND in the environment overrides the
    # parametrized value.
    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
    prev_backend = os.environ.get("VLLM_ATTENTION_BACKEND")
    os.environ["VLLM_ATTENTION_BACKEND"] = backend

    llm = None
    try:
        seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
        random.seed(seed)

        model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
        tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

        # No dtype override: the engine's default dtype for the model is used.
        llm = LLM(
            model=model_name,
            tensor_parallel_size=tp_size,
            # helps reduce nondeterminism from some backends
            enforce_eager=True,
        )

        prompts = [
            "The capital of France is",
            "The capital of Germany is",
            *(_random_prompt(10, 1024) for _ in range(5)),
        ]

        # temperature > 0 with a fixed seed: sampling is stochastic but
        # reproducible, so BS=1 and BS=N should pick the same tokens.
        sp = SamplingParams(
            temperature=0.6,
            top_p=1.0,
            max_tokens=8,
            seed=1234,
            logprobs=5,
        )

        # BS=1: run prompts individually and collect logprobs per step.
        bs1_logprobs_per_prompt = []
        for p in prompts:
            outs = llm.generate([p], sp, use_tqdm=False)
            assert len(outs) == 1
            step_logprobs = _extract_step_logprobs(outs[0])
            if step_logprobs is None:
                pytest.skip("Logits are not available on RequestOutput; "
                            "enable logprobs return to run this test.")
            bs1_logprobs_per_prompt.append(step_logprobs)

        # BS=N: run prompts in a batch and collect logprobs per step for each
        # prompt.
        outs_batched = llm.generate(prompts, sp, use_tqdm=False)
        assert len(outs_batched) == len(prompts)
        bsN_logprobs_per_prompt = []
        for o in outs_batched:
            step_logprobs = _extract_step_logprobs(o)
            if step_logprobs is None:
                pytest.skip("Logits are not available on RequestOutput; "
                            "enable logprobs return to run this test.")
            bsN_logprobs_per_prompt.append(step_logprobs)

        # Compare step-by-step logprobs for each prompt between BS=1 and BS=N.
        for i, (logprobs_bs1, logprobs_bsN) in enumerate(
                zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt)):
            assert len(logprobs_bs1) == len(logprobs_bsN), (
                f"Different number of generation steps for prompt index {i}: "
                f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)")
            for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
                assert a.shape == b.shape, (
                    f"Logprobs shape mismatch at prompt {i}, step {t}: "
                    f"{a.shape} vs {b.shape}")
                # Bitwise exact equality.
                assert torch.equal(a, b), (
                    f"Bitwise logprobs mismatch at prompt {i}, step {t} "
                    f"(dtype={a.dtype}, shape={a.shape}).")
    finally:
        # Free the engine so a following parametrized run can allocate memory.
        if llm is not None:
            with contextlib.suppress(Exception):
                llm.shutdown()
        # Restore the caller's environment so later tests are unaffected.
        if prev_backend is None:
            os.environ.pop("VLLM_ATTENTION_BACKEND", None)
        else:
            os.environ["VLLM_ATTENTION_BACKEND"] = prev_backend


def LLM_with_max_seqs(
    model: str,
    max_num_seqs: int,
    gpu_memory_utilization: float,
    max_model_len: int,
    swap_space: int,
) -> LLM:
    """
    Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
    using the high-level v1 LLM API, while constraining memory usage.

    Tensor-parallel size is read from VLLM_TP_SIZE. If that is unset, it falls
    back to VLLM_TEST_TP_SIZE (the variable the logprobs test in this file
    reads), and then to 1. Remote code is trusted only when
    VLLM_TRUST_REMOTE_CODE == "1".
    """
    tp_size = int(
        os.getenv("VLLM_TP_SIZE", os.getenv("VLLM_TEST_TP_SIZE", "1")))
    return LLM(
        model=model,
        max_num_seqs=max_num_seqs,
        # Constrain GPU memory pool so test can run even on busy GPUs.
        gpu_memory_utilization=gpu_memory_utilization,
        # Keep KV cache footprint small while allowing longer outputs.
        max_model_len=max_model_len,
        # Allow some CPU offload if needed.
        swap_space=swap_space,
        # Keep things lean and CI-friendly.
        dtype="float16",
        # Single-GPU by default; override externally if desired.
        tensor_parallel_size=tp_size,
        trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1",
    )