test_logprobs.py 17.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""
Logprobs Accuracy Test for SGLang

======================
With deterministic/batch invariant kernels, we can ensure that SGLang produces exactly the same
logprobs results for identical inputs. However, logprobs are highly sensitive to GPU hardware,
kernels, torch versions, and other factors, so we cannot maintain a unified logprobs baseline
across different machines.

This test is designed to be run locally by contributors to verify that logprobs stay
accurate when they change related code.
When submitting changes that affect logprobs computation, please:
1. Generate a baseline before making your changes.
2. Run the test against that baseline after making your changes.
3. Include the test results in your submission.

We really appreciate your effort and contribution to SGLang!

======================
What does this test do?
This test fetches 1000 samples from the ShareGPT dataset, generates logprobs for each sample,
and saves them as a baseline. Then, by running the test mode, it validates the accuracy of
logprobs by comparing them against the baseline.

This test ensures that:
- the boundaries of logprob requests are respected, e.g., logprobs are returned only for
  requests that ask for them, starting at each request's logprob start index;
- logprobs remain invariant between test runs, and before and after your code changes.

======================
Usage

Step 1: Generate Baseline (Before Code Changes)
```bash
python test/srt/test_logprobs.py gen
```

Step 2: Test Against Baseline (After Code Changes)
```bash
python test/srt/test_logprobs.py test
```
This tests your changes against the locally generated baseline from Step 1.
The test passes if, for every compared sample, the maximum and mean logprob differences
are within the tolerance thresholds.
======================
"""

import argparse
import json
48
49
50
51
52
53
54
55
import os
import pickle
import random
import unittest

import numpy as np
import requests
import torch
56
from transformers import AutoTokenizer
57
58

import sglang as sgl
59
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
60

61
# Configuration
DENSE_MODEL_NAME = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
SHAREGPT_URL = (
    "https://huggingface.co/datasets/anon8231489123/"
    "ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
)

# Hardware-specific configuration.
# Logprob baselines are only comparable on the same hardware/kernel stack, so the
# tolerances live here per backend. `torch.version.cuda` reflects how PyTorch was
# built (not whether a GPU is present); any non-CUDA build aborts at import time.
if torch.version.cuda is not None:
    print("Running on NVIDIA CUDA GPU")
    # With deterministic kernels the baseline should be reproduced (near-)exactly.
    DENSE_TOLERANCE_MAX_DIFF = 1e-5
    DENSE_TOLERANCE_MEAN_DIFF = 1e-5
else:
    print("No GPU backend (CPU only)")
    raise ValueError("No GPU backend (CPU only)")

# Common configuration
TOP_K = 20  # top logprobs requested per position (also the top_k sampling param)
NUM_SAMPLES = 1000  # number of baseline samples to generate / replay
LOGPROB_SAMPLE_RATIO = 0.5  # fraction of test requests that ask for logprobs
TEMPERATURE = 1.0  # sampling temperature
MAX_LEN = 20000  # maximum accepted prompt length, in characters

# Default output files
DEFAULT_BASELINE_PKL = "sglang_baseline_local.pkl"
DEFAULT_META_JSON = "baseline_meta_preview.json"


def _filter_sharegpt_texts(data, limit, min_len, max_len):
    """Collect first-turn ShareGPT prompts whose length lies in [min_len, max_len].

    Args:
        data: Parsed ShareGPT JSON records. A usable record holds a non-empty
            "conversations" list whose first item has a "value" string.
        limit: Stop scanning once this many texts have been collected.
        min_len: Minimum prompt length in characters (inclusive).
        max_len: Maximum prompt length in characters (inclusive).

    Returns:
        The matching prompt strings, in dataset order.
    """
    texts = []
    for s in data:
        if "conversations" not in s or len(s["conversations"]) == 0:
            continue
        try:
            text = s["conversations"][0]["value"]
        except (KeyError, IndexError, TypeError) as e:
            print(f"Warning: Skipping invalid conversation data: {e}")
            continue
        if isinstance(text, str) and min_len <= len(text) <= max_len:
            texts.append(text)
            if len(texts) >= limit:
                break
    return texts


def generate_baseline(
    baseline_file=DEFAULT_BASELINE_PKL,
    meta_file=DEFAULT_META_JSON,
    num_samples=NUM_SAMPLES,
):
    """Generate a local baseline for logprobs testing.

    Downloads ShareGPT, keeps long first-turn prompts, and runs each through a
    deterministic SGLang engine one request at a time, recording the returned
    ``meta_info`` (which holds the requested top-k input/output logprobs).

    Args:
        baseline_file: Path to save the baseline pickle file
        meta_file: Path to save the metadata preview JSON file (first 2 records)
        num_samples: Number of samples to generate

    Raises:
        Exception: If the ShareGPT dataset cannot be downloaded.
        ValueError: If no prompt in the dataset passes the length filter.
        RuntimeError: If no sample could be processed by the engine.
    """
    print(f"SGLang version: {sgl.__version__}")
    print("Downloading ShareGPT dataset...")

    # Download ShareGPT dataset
    try:
        response = requests.get(SHAREGPT_URL, timeout=30)
        response.raise_for_status()
        data = response.json()
        print(f"Dataset size: {len(data)}")
    except requests.exceptions.RequestException as e:
        raise Exception(f"Failed to download ShareGPT dataset: {e}") from e

    # Over-collect candidates (40x) because some may be skipped or fail below.
    texts = _filter_sharegpt_texts(
        data, limit=num_samples * 40, min_len=5500, max_len=MAX_LEN
    )
    if not texts:
        raise ValueError("No valid texts found in the dataset")

    print(f"Loading tokenizer for {DENSE_MODEL_NAME}...")
    tokenizer = AutoTokenizer.from_pretrained(DENSE_MODEL_NAME, use_fast=True)

    # Seeded so the chosen logprob start positions are reproducible.
    rng = np.random.default_rng(42)

    print(f"Launching SGLang Engine with {DENSE_MODEL_NAME}...")
    # max_running_requests=1: the baseline is produced one request at a time,
    # while the test later replays all prompts in a single batched call.
    engine = sgl.Engine(
        model_path=DENSE_MODEL_NAME,
        attention_backend="flashinfer",
        enable_deterministic_inference=True,
        random_seed=42,
        skip_tokenizer_init=True,
        mem_fraction_static=0.8,
        max_running_requests=1,
    )

    records = []
    prompt_lengths = []

    try:
        for i, text in enumerate(texts):
            if len(records) >= num_samples:
                break

            try:
                ids = tokenizer.encode(text, add_special_tokens=False)
                if len(ids) < 5:
                    continue

                # Random input-logprob start index in [0, max(1, len(ids) - 3)).
                start_pos = int(rng.integers(0, max(1, len(ids) - 3)))

                outputs = engine.generate(
                    input_ids=[ids],
                    sampling_params={
                        # Shared with test mode so both runs use the same setting.
                        "temperature": TEMPERATURE,
                        "top_p": 1.0,
                        "top_k": TOP_K,
                        "max_new_tokens": 1,
                    },
                    return_logprob=True,
                    logprob_start_len=start_pos,
                    top_logprobs_num=TOP_K,
                )
                meta = outputs[0]["meta_info"]

                records.append(
                    dict(id=i, text=text, ids=ids, start_pos=start_pos, meta=meta)
                )
                prompt_lengths.append(len(ids))

                # Report progress in terms of records actually produced.
                if len(records) % 50 == 0:
                    print(f"Processed {len(records)}/{num_samples} samples")

            except Exception as e:
                # Best effort: a failing sample is skipped, not fatal.
                print(f"Warning: Failed to process sample {i}: {e}")
                continue

        if not records:
            raise RuntimeError(
                "Failed to generate any baseline records. Please check the warnings above for errors."
            )

        # Save baseline files
        with open(baseline_file, "wb") as f:
            pickle.dump(records, f)
        with open(meta_file, "w", encoding="utf-8") as f:
            json.dump(records[:2], f, ensure_ascii=False, indent=2)

        print(f"✅ Saved {len(records)} samples to {baseline_file}")
        print(f"✅ Meta preview saved to {meta_file}")

        if prompt_lengths:
            avg_prompt_length = sum(prompt_lengths) / len(prompt_lengths)
            print(f"📊 Average prompt length: {avg_prompt_length:.2f} tokens")

    finally:
        engine.shutdown()
        torch.cuda.empty_cache()
208
209
210
211
212
213
214
215
216
217
218


class TestLogprobsDense(unittest.TestCase):
    """Check SGLang logprobs against a locally generated baseline.

    All baseline prompts are replayed in one batched ``generate`` call in which
    only a random subset of requests asks for logprobs. Returned top-k logprobs
    are compared with the baseline token id by token id.
    """

    @classmethod
    def setUpClass(cls):
        """Set up the test class - initialize the engine once for all tests."""
        print(f"Launching SGLang Engine with {DENSE_MODEL_NAME}...")
        cls.engine = sgl.Engine(
            model_path=DENSE_MODEL_NAME,
            random_seed=42,
            attention_backend="flashinfer",
            enable_deterministic_inference=True,
            skip_tokenizer_init=True,
            mem_fraction_static=0.80,
        )

    @classmethod
    def tearDownClass(cls):
        """Clean up after all tests - shutdown the engine and free cached GPU memory."""
        cls.engine.shutdown()
        torch.cuda.empty_cache()

    def load_test_data(self, baseline_file=None):
        """Load baseline records from a local pickle file.

        Args:
            baseline_file: Path to the baseline pickle written by ``gen`` mode.

        Returns:
            The list of baseline records stored in the file.

        Raises:
            ValueError: If no path is given.
            FileNotFoundError: If the file does not exist.
            Exception: If the file cannot be read or unpickled (including a
                truncated file).
        """
        if not baseline_file:
            raise ValueError("baseline_file is required in test mode")

        if not os.path.exists(baseline_file):
            raise FileNotFoundError(
                f"Baseline file not found: {baseline_file}. Please run 'gen' mode first to generate the baseline."
            )

        print(f"Loading local baseline from {baseline_file}...")
        try:
            # NOTE: unpickling can execute arbitrary code; only load baselines
            # you generated yourself.
            with open(baseline_file, "rb") as f:
                records = pickle.load(f)
        except (OSError, EOFError, pickle.PickleError) as e:
            # EOFError is what pickle raises for an empty/truncated file.
            raise Exception(f"Failed to load local baseline: {e}") from e
        print(f"Successfully loaded {len(records)} records from local baseline")
        return records

    def compare_meta(self, baseline_meta, sglang_meta):
        """Compare top-k logprobs of two ``meta_info`` dicts.

        For both "input_top_logprobs" and "output_top_logprobs", positions are
        paired up; each entry is a sequence of (logprob, token_id, text) tuples,
        and logprobs are compared for token ids present in both entries.
        Positions where either side is empty/None are skipped.

        Returns:
            (max_diff, mean_diff) over all compared tokens, or (0.0, 0.0) if
            nothing was compared.
        """
        diffs = []
        for key in ["input_top_logprobs", "output_top_logprobs"]:
            baseline_logprobs, sglang_logprobs = baseline_meta[key], sglang_meta[key]
            self.assertEqual(
                len(baseline_logprobs),
                len(sglang_logprobs),
                f"Length of {key} is not equal, sglang did not return the correct number of log probs(should be top 20)",
            )
            for baseline_entry, sglang_entry in zip(baseline_logprobs, sglang_logprobs):
                if not baseline_entry or not sglang_entry:
                    continue
                baseline_token_map = {tid: lp for lp, tid, _ in baseline_entry}
                sglang_token_map = {tid: lp for lp, tid, _ in sglang_entry}
                common_tokens = baseline_token_map.keys() & sglang_token_map.keys()
                # The two top-k sets must coincide (at least TOP_K shared ids).
                self.assertGreaterEqual(
                    len(common_tokens),
                    TOP_K,
                    f"there are only {len(common_tokens)} common topk tokens that matches",
                )
                for token_id in common_tokens:
                    diffs.append(
                        abs(baseline_token_map[token_id] - sglang_token_map[token_id])
                    )

        if not diffs:
            return 0.0, 0.0

        return max(diffs), float(np.mean(diffs))

    def test_logprobs_comparison(self, baseline_file=None):
        """Replay the baseline prompts in one batch and compare logprobs.

        Args:
            baseline_file: Baseline pickle to compare against. Defaults to
                ``DEFAULT_BASELINE_PKL`` so the test also runs under a plain
                unittest/pytest runner, which passes no arguments.
        """
        records = self.load_test_data(baseline_file or DEFAULT_BASELINE_PKL)

        config = {
            "num_samples": NUM_SAMPLES,
            "logprob_sample_ratio": LOGPROB_SAMPLE_RATIO,
            "temperature": TEMPERATURE,
        }
        with self.subTest(config=config):
            # random.sample already returns its picks in random order.
            test_records = random.sample(records, k=min(NUM_SAMPLES, len(records)))

            # Calculate how many samples should return logprobs
            logprob_count = int(len(test_records) * LOGPROB_SAMPLE_RATIO)
            print(
                f"Testing with {len(test_records)} samples, temperature={TEMPERATURE}"
            )
            print(
                f"Will return logprobs for {logprob_count} samples (ratio: {LOGPROB_SAMPLE_RATIO})"
            )

            all_max, all_mean = [], []
            compared_ids = []  # baseline record id for each all_max/all_mean entry
            logprob_returned_count = 0

            # Process all records in a single batched request
            input_ids = [rec["ids"] for rec in test_records]
            logprob_start_lens = [rec["start_pos"] for rec in test_records]

            # Randomly choose which requests ask for logprobs
            logprob_indices = set(
                random.sample(range(len(test_records)), logprob_count)
            )
            return_logprob_array = [
                idx in logprob_indices for idx in range(len(test_records))
            ]

            # Sampling param per request
            sampling_params = [
                {
                    "temperature": TEMPERATURE,
                    "top_p": 1.0,
                    "top_k": TOP_K,
                    "max_new_tokens": 1,
                }
                for _ in test_records
            ]

            outputs = self.engine.generate(
                input_ids=input_ids,
                sampling_params=sampling_params,
                return_logprob=return_logprob_array,
                logprob_start_len=logprob_start_lens,
                top_logprobs_num=TOP_K,
            )

            for sample_idx, (rec, output) in enumerate(zip(test_records, outputs)):
                meta_info = output.get("meta_info")
                input_top_logprobs = (
                    meta_info.get("input_top_logprobs") if meta_info else None
                )

                if sample_idx in logprob_indices:
                    self.assertIsNotNone(
                        input_top_logprobs,
                        f"return_logprob enabled on this sample, but input_top_logprobs is None (length: {len(input_top_logprobs) if input_top_logprobs is not None else 'N/A'})",
                    )
                    max_diff, mean_diff = self.compare_meta(rec["meta"], meta_info)
                    all_max.append(max_diff)
                    all_mean.append(mean_diff)
                    compared_ids.append(rec["id"])
                    logprob_returned_count += 1
                else:
                    # Requests that did not ask for logprobs must not get them.
                    output_token_ids_logprobs = (
                        meta_info.get("output_token_ids_logprobs")
                        if meta_info
                        else None
                    )
                    self.assertFalse(
                        input_top_logprobs,
                        f"return_logprob is disabled on this sample, Sample {sample_idx} should not have logprobs, content: {output_token_ids_logprobs}",
                    )

            max_of_max = max(all_max) if all_max else 0.0
            mean_of_mean = np.mean(all_mean) if all_mean else 0.0

            print(f"max Δ={max_of_max:.6g}")
            print(f"mean Δ={mean_of_mean:.6g}")
            print(
                f"logprobs returned for {logprob_returned_count} samples (expected: {logprob_count})"
            )

            # Verify correct number of logprobs returned
            self.assertEqual(
                logprob_returned_count,
                logprob_count,
                f"Expected {logprob_count} samples with logprobs, got {logprob_returned_count}",
            )
            self.assertGreater(
                len(all_max),
                0,
                f"No test samples processed for config {config}",
            )

            # Tolerance checks; failures name the baseline record id so the
            # offending prompt can be located in the baseline file.
            failures = []
            failed_sample_count = 0
            for rec_id, max_diff, mean_diff in zip(compared_ids, all_max, all_mean):
                sample_failed = False
                if max_diff > DENSE_TOLERANCE_MAX_DIFF:
                    failures.append(
                        f"Sample id={rec_id}: max_diff={max_diff:.6g} > {DENSE_TOLERANCE_MAX_DIFF}"
                    )
                    sample_failed = True
                if mean_diff > DENSE_TOLERANCE_MEAN_DIFF:
                    failures.append(
                        f"Sample id={rec_id}: mean_diff={mean_diff:.6g} > {DENSE_TOLERANCE_MEAN_DIFF}"
                    )
                    sample_failed = True
                failed_sample_count += sample_failed

            if failures:
                self.fail(
                    f"Config {config} - Tolerance exceeded in {failed_sample_count} samples:\n"
                    + "\n".join(failures[:5])
                )


420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
def main():
    """Parse the command line and either generate a baseline or run the test.

    Modes:
        gen:  download ShareGPT prompts and write the local baseline files.
        test: compare fresh logprobs against the local baseline.

    Returns:
        Process exit code: 0 on success, 1 when ``test`` is requested but the
        baseline file does not exist. Test failures propagate as exceptions.
    """
    parser = argparse.ArgumentParser(
        description="SGLang Logprobs Test and Baseline Generation"
    )
    parser.add_argument(
        "mode",
        choices=["gen", "test"],
        help="Mode to run: 'gen' to generate baseline, 'test' to run tests",
    )

    args = parser.parse_args()

    # Use the same logprob setting for BOTH modes, before any engine is
    # launched, so the baseline and the test run under one configuration.
    os.environ["RETURN_ORIGINAL_LOGPROB"] = "True"

    if args.mode == "gen":
        print("🚀 Generating baseline...")
        generate_baseline()
        print("\n✅ Baseline generation complete!")
        print(f"📁 Baseline saved to: {DEFAULT_BASELINE_PKL}")
        print(f"📁 Metadata preview saved to: {DEFAULT_META_JSON}")
        print("\n💡 Next steps:")
        print("   1. Make your code changes")
        print(f"   2. Run: python {__file__} test")

    elif args.mode == "test":
        print("🧪 Running logprobs test...")
        if not os.path.exists(DEFAULT_BASELINE_PKL):
            print(f"❌ Baseline file not found: {DEFAULT_BASELINE_PKL}")
            print("💡 Generate baseline first by running:")
            print(f"   python {__file__} gen")
            print("   This will download ShareGPT data and generate a local baseline.")
            return 1

        # Create test instance and run it outside a unittest runner
        test_instance = TestLogprobsDense()
        test_instance.setUpClass()
        try:
            test_instance.test_logprobs_comparison(baseline_file=DEFAULT_BASELINE_PKL)
            print("\n✅ Test completed successfully!")
        finally:
            test_instance.tearDownClass()

    return 0


467
if __name__ == "__main__":
    # sys.exit rather than the site-module `exit` helper, which is meant for
    # interactive use and is absent when Python runs with -S.
    import sys

    sys.exit(main())