test_batch_invariance.py 10.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import random
import string

import pytest
import torch

from vllm import LLM, SamplingParams


def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
    """Return a pseudo-random filler prompt built from NATO-alphabet words.

    The prompt holds between ``min_words`` and ``max_words`` words (inclusive)
    drawn with replacement, optionally followed by one extra 5-letter noise
    word, and ends with a randomly chosen punctuation suffix (possibly
    empty). All randomness comes from the module-level ``random`` generator,
    so seeding it beforehand makes the output reproducible.
    """
    nato_words = (
        "alpha bravo charlie delta echo foxtrot golf hotel india juliet "
        "kilo lima mike november oscar papa quebec romeo sierra tango "
        "uniform victor whiskey xray yankee zulu"
    ).split()

    word_count = random.randint(min_words, max_words)
    tokens = random.choices(nato_words, k=word_count)

    # The draws below must stay in this order so that a seeded generator
    # keeps producing the same prompts.
    capitalize_first = random.random() < 0.5
    if capitalize_first:
        tokens[0] = tokens[0].capitalize()

    append_noise = random.random() < 0.2
    if append_noise:
        noise_letters = random.choices(string.ascii_lowercase, k=5)
        tokens.append("".join(noise_letters))

    ending = random.choice([".", "?", "!", "...", ""])
    return f"{' '.join(tokens)}{ending}"


@pytest.mark.timeout(1000)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
    """
    Ensures that the same request (the 'needle' prompt) yields identical output
    whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
    using the high-level v1 LLM() API only (no manual batching).

    Strategy:
    - Create two LLM engines with identical config except max_num_seqs: 1 vs N.
    - Compute a baseline output for the needle prompt with the bs=1 engine.
    - For many trials, generate a batch (size N) where the needle appears at a
      random position among random filler prompts using the bs=N engine.
    - Track how many trials match vs mismatch, and report totals at the end.
      The test fails if any mismatches occur, but we still dump pass/fail
      counts.

    Notes:
    - Sampling always uses a fixed seed. The default temperature is 0.0;
      set VLLM_NEEDLE_TEMPERATURE / VLLM_NEEDLE_TOP_P to exercise seeded
      stochastic sampling instead, which must be equally deterministic.
    - Keep max_tokens and max_model_len bounded for speed and memory use.
    """
    # Fix the Python RNG so filler prompts and needle positions are
    # reproducible from run to run.
    random.seed(12345)

    # Allow overrides from environment (useful for CI tuning)
    # "facebook/opt-125m" is too small, doesn't reliably test determinism
    model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
    batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64"))
    assert batch_size >= 2, "Batch size should be >= 2 to mix needle."

    # Keep GPU memory usage low to avoid startup allocation failures.
    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3"))
    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096"))
    swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))

    # Sampling parameters: all overridable from the environment, and always
    # paired with a fixed seed so results stay reproducible.
    temperature = float(os.getenv("VLLM_NEEDLE_TEMPERATURE", "0.0"))
    top_p = float(os.getenv("VLLM_NEEDLE_TOP_P", "0.95"))
    max_tokens = int(os.getenv("VLLM_NEEDLE_MAX_TOKENS", "128"))

    sampling = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        seed=20240919,
    )

    needle_prompt = "There once was a "

    # Both handles start as None so the `finally` clause can tell which
    # engines were actually constructed.
    llm_single = None
    llm_batched = None
    try:
        # Engine with bs=1 behavior
        llm_single = LLM_with_max_seqs(
            model=model,
            max_num_seqs=1,
            gpu_memory_utilization=gpu_mem_util,
            max_model_len=max_model_len,
            swap_space=swap_space_gb,
        )

        # Baseline generation for the needle prompt alone.
        baseline_out = llm_single.generate([needle_prompt], sampling)
        assert len(baseline_out) == 1
        assert len(baseline_out[0].outputs) >= 1
        baseline_text = baseline_out[0].outputs[0].text

        # Engine with larger batch limit (e.g., 64)
        llm_batched = LLM_with_max_seqs(
            model=model,
            max_num_seqs=batch_size,
            gpu_memory_utilization=gpu_mem_util,
            max_model_len=max_model_len,
            swap_space=swap_space_gb,
        )

        mismatches = 0

        for trial in range(num_trials):
            # Create a batch of size `batch_size` and insert the needle at
            # a random index; every other slot gets a fresh filler prompt.
            needle_pos = random.randint(0, batch_size - 1)
            prompts: list[str] = [
                needle_prompt if i == needle_pos else _random_prompt()
                for i in range(batch_size)
            ]

            # Generate with the larger-batch engine
            outputs = llm_batched.generate(prompts, sampling)
            # Find the needle output by position
            needle_output = outputs[needle_pos]
            assert needle_output.prompt == needle_prompt
            assert len(needle_output.outputs) >= 1
            text = needle_output.outputs[0].text

            if text != baseline_text:
                mismatches += 1
                # Record where the divergence happened to ease debugging.
                print(
                    f"[determinism] mismatch in trial {trial} "
                    f"(needle_pos={needle_pos})"
                )

        passes = num_trials - mismatches
        # Dump how many passed vs failed
        print(
            f"[determinism] total={num_trials}, passed={passes}, "
            f"failed={mismatches}, batch_size={batch_size}"
        )

        if mismatches > 0:
            pytest.fail(
                f"Nondeterministic outputs detected: {mismatches} failed out "
                f"of {num_trials} trials (batch_size={batch_size})."
            )

    finally:
        # Ensure engines are shutdown to free GPU/VRAM across test sessions.
        # Shutdown is best-effort: errors here must not mask the test result.
        if llm_single is not None:
            with contextlib.suppress(Exception):
                llm_single.shutdown()
        if llm_batched is not None:
            with contextlib.suppress(Exception):
                llm_batched.shutdown()


def _extract_step_logprobs(request_output):
    """Return the logprob of each generated token as a float32 tensor.

    Only the first entry of ``request_output.outputs`` is inspected. For each
    generation step ``i`` the logprob recorded for the token actually emitted
    (``token_ids[i]``) is looked up in ``logprobs[i]``.

    Returns ``None`` when ``outputs`` is missing or empty, or when the first
    output carries no ``logprobs``.
    """
    completions = getattr(request_output, "outputs", None)
    if not completions:
        return None

    first = request_output.outputs[0]
    if getattr(first, "logprobs", None) is None:
        return None

    chosen_logprobs = [
        first.logprobs[step][token_id].logprob
        for step, token_id in enumerate(first.token_ids)
    ]
    return torch.tensor(chosen_logprobs, dtype=torch.float32)


@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="Requires CUDA to match production inference path.",
)
def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
    """
    Checks that per-step logprobs of the generated tokens are bitwise
    identical whether each prompt is run alone (BS=1) or both prompts are
    submitted together (BS=2) on the same engine.

    The model and tensor-parallel size can be overridden with
    VLLM_TEST_MODEL and VLLM_TEST_TP_SIZE. The test is skipped if the
    engine does not return logprobs.
    """
    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

    # No dtype override is passed, so the engine's default dtype applies;
    # the comparison below is bitwise, so precision is not a tolerance knob.
    llm = LLM(
        model=model_name,
        tensor_parallel_size=tp_size,
        enforce_eager=True,  # helps reduce nondeterminism from some backends
    )

    skip_reason = (
        "Logprobs are not available on RequestOutput; "
        "enable logprobs return to run this test."
    )

    try:
        prompts = [
            "The capital of France is",
            "The capital of Germany is",
        ]

        sp = SamplingParams(
            temperature=0.0,
            top_p=1.0,
            max_tokens=8,
            # Seed shouldn't matter at temperature=0, but keeping it stable
            # anyway.
            seed=1234,
            logprobs=5,
        )

        # BS=1: run prompts individually and collect logprobs per step.
        bs1_logprobs_per_prompt = []
        for p in prompts:
            outs = llm.generate([p], sp, use_tqdm=False)
            assert len(outs) == 1
            step_logprobs = _extract_step_logprobs(outs[0])
            if step_logprobs is None:
                pytest.skip(skip_reason)
            bs1_logprobs_per_prompt.append(step_logprobs)

        # BS=2: run prompts in a batch and collect logprobs per step for each
        # prompt.
        outs_batched = llm.generate(prompts, sp, use_tqdm=False)
        assert len(outs_batched) == len(prompts)
        bs2_logprobs_per_prompt = []
        for o in outs_batched:
            step_logprobs = _extract_step_logprobs(o)
            if step_logprobs is None:
                pytest.skip(skip_reason)
            bs2_logprobs_per_prompt.append(step_logprobs)

        # Compare step-by-step logprobs for each prompt between BS=1 and BS=2
        # runs.
        for i, (logprobs_bs1, logprobs_bs2) in enumerate(
            zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)
        ):
            assert len(logprobs_bs1) == len(logprobs_bs2), (
                f"Different number of generation steps for prompt index {i}: "
                f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)"
            )
            # Walk step by step so a failure pinpoints the first diverging
            # generation step rather than just the prompt.
            for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)):
                assert a.shape == b.shape, (
                    f"Logprobs shape mismatch at prompt {i}, step {t}: "
                    f"{a.shape} vs {b.shape}"
                )
                # Bitwise exact equality.
                assert torch.equal(a, b), (
                    f"Bitwise logprobs mismatch at prompt {i}, step {t} "
                    f"(dtype={a.dtype}, shape={a.shape})."
                )
    finally:
        # Best-effort shutdown to free GPU/VRAM for subsequent tests, matching
        # the cleanup used by the needle test in this file.
        with contextlib.suppress(Exception):
            llm.shutdown()
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294


def LLM_with_max_seqs(
    model: str,
    max_num_seqs: int,
    gpu_memory_utilization: float,
    max_model_len: int,
    swap_space: int,
) -> LLM:
    """
    Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
    using the high-level v1 LLM API, while constraining memory usage.

    Args:
        model: Model name or path forwarded to ``LLM``.
        max_num_seqs: Upper bound on concurrently scheduled sequences; this
            is the only setting the batch-invariance test varies.
        gpu_memory_utilization: Fraction of GPU memory the engine may claim.
        max_model_len: Maximum context length, kept small to bound the KV
            cache footprint.
        swap_space: Forwarded as ``swap_space``; the caller in this file
            passes a value it names as gigabytes.

    Returns:
        A newly constructed ``LLM`` engine.

    Environment:
        VLLM_TP_SIZE: tensor-parallel size (default ``"1"``).
        VLLM_TRUST_REMOTE_CODE: set to ``"1"`` to trust remote model code.
    """
    # Read the overrides at call time so each engine sees the current env.
    tp_size = int(os.getenv("VLLM_TP_SIZE", "1"))
    trust_remote_code = os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1"

    return LLM(
        model=model,
        max_num_seqs=max_num_seqs,
        # Constrain GPU memory pool so test can run even on busy GPUs.
        gpu_memory_utilization=gpu_memory_utilization,
        # Keep KV cache footprint small while allowing longer outputs.
        max_model_len=max_model_len,
        # Allow some CPU offload if needed.
        swap_space=swap_space,
        # Keep things lean and CI-friendly.
        dtype="auto",
        # Single-GPU by default; override externally if desired.
        tensor_parallel_size=tp_size,
        trust_remote_code=trust_remote_code,
        # NOTE(review): presumably disabled so repeated needle prompts are
        # always recomputed rather than served from cache; confirm.
        enable_prefix_caching=False,
        # Enable for MOE models
        # enable_expert_parallel=True,
    )