Unverified Commit 53fb229f authored by Night, committed by GitHub

[logprobs] Enable local deterministic logprobs testing with strict threshold (#10994)


Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 4fff1ec1
......@@ -86,7 +86,6 @@ suites = {
TestFile("test_input_embeddings.py", 38),
TestFile("test_io_struct.py", 8),
TestFile("test_jinja_template_utils.py", 1),
TestFile("test_logprobs.py", 55),
TestFile("test_mamba_unittest.py", 4),
TestFile("test_metrics.py", 32),
TestFile("test_metrics_utils.py", 1),
......
"""
Logprobs Accuracy Test for SGLang
======================
With deterministic, batch-invariant kernels, we can ensure that SGLang produces exactly the same
logprobs results for identical inputs. However, logprobs are highly sensitive to GPU hardware,
kernels, torch versions, and other factors, so we cannot maintain a unified logprobs baseline
across different machines.
This test is designed to be run locally by contributors to verify logprobs accuracy
when changing related code.
When submitting changes that affect logprobs computation, please:
1. Generate a baseline before your changes
2. Run the test against that baseline after your changes
3. Submit the results along with your change
We really appreciate your effort and contribution to SGLang!
======================
What does this test do?
This test fetches 1000 samples from the ShareGPT dataset, generates logprobs for each sample,
and saves them as a baseline. Then, in test mode, it validates logprobs accuracy by comparing
freshly generated logprobs against that baseline.
This test ensures that:
- the boundaries of logprob requests are correct, e.g., the start index from which tokens
  require logprobs is strictly respected;
- logprobs remain invariant across test runs, both before and after your code changes.
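For reference, each pickled baseline record is a plain dict (see generate_baseline below):
    dict(id=i, text=text, ids=ids, start_pos=start_pos, meta=meta)
where meta is the meta_info payload returned by engine.generate(...) and holds the
top-k logprobs that the test later compares.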
======================
Usage
Step 1: Generate Baseline (Before Code Changes)
```bash
python test/srt/test_logprobs.py gen
```
Step 2: Test Against Baseline (After Code Changes)
```bash
python test/srt/test_logprobs.py test
```
This tests your changes against the locally generated baseline from Step 1.
The test passes if the maximum and mean differences are within the tolerance thresholds.
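Illustratively, the per-sample check reduces to the following sketch (simplified from
TestLogprobsDense.compare_meta below; baseline_token_map and sglang_token_map are
{token_id: logprob} dicts built from the baseline and the fresh output):
    common = baseline_token_map.keys() & sglang_token_map.keys()  # must cover all TOP_K tokens
    diffs = [abs(baseline_token_map[t] - sglang_token_map[t]) for t in common]
    # the test aggregates these per-sample diffs and asserts that the max and mean
    # stay within DENSE_TOLERANCE_MAX_DIFF and DENSE_TOLERANCE_MEAN_DIFF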
======================
"""
import argparse
import json
import os
import pickle
import random
import time
import unittest
import numpy as np
import requests
import torch
from transformers import AutoTokenizer
import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
# Configuration
DENSE_MODEL_NAME = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
SHAREGPT_URL = (
    "https://huggingface.co/datasets/anon8231489123/"
    "ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
)
# Hardware-specific configuration
if torch.version.cuda is not None:
    print("Running on NVIDIA CUDA GPU")
    DENSE_TOLERANCE_MAX_DIFF = 1e-5
    DENSE_TOLERANCE_MEAN_DIFF = 1e-5
else:
    raise ValueError("No GPU backend (CPU only)")
# Common configuration
TOP_K = 20  # number of top logprobs requested per token position
MAX_RETRIES = 3
RETRY_DELAY = 2
NUM_SAMPLES = 1000  # number of baseline samples
LOGPROB_SAMPLE_RATIO = 0.5  # fraction of test requests that ask for logprobs
TEMPERATURE = 1.0
MAX_LEN = 20000  # maximum prompt length, in characters
# Default output files
DEFAULT_BASELINE_PKL = "sglang_baseline_local.pkl"
DEFAULT_META_JSON = "baseline_meta_preview.json"
def generate_baseline(
baseline_file=DEFAULT_BASELINE_PKL,
meta_file=DEFAULT_META_JSON,
num_samples=NUM_SAMPLES,
):
"""Generate a local baseline for logprobs testing.
Args:
baseline_file: Path to save the baseline pickle file
meta_file: Path to save the metadata preview JSON file
num_samples: Number of samples to generate
"""
print(f"SGLang version: {sgl.__version__}")
print("Downloading ShareGPT dataset...")
# Download ShareGPT dataset
try:
response = requests.get(SHAREGPT_URL, timeout=30)
response.raise_for_status()
data = response.json()
print(f"Dataset size: {len(data)}")
except requests.exceptions.RequestException as e:
raise Exception(f"Failed to download ShareGPT dataset: {e}") from e
# Filter and prepare texts
texts = []
for s in data:
if "conversations" in s and len(s["conversations"]) > 0:
try:
text = s["conversations"][0]["value"]
if isinstance(text, str) and len(text) <= MAX_LEN and len(text) >= 5500:
texts.append(text)
if len(texts) >= num_samples * 40: # Get more samples for filtering
break
except (KeyError, IndexError, TypeError) as e:
print(f"Warning: Skipping invalid conversation data: {e}")
continue
if not texts:
raise ValueError("No valid texts found in the dataset")
print(f"Loading tokenizer for {DENSE_MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(DENSE_MODEL_NAME, use_fast=True)
rng = np.random.default_rng(42)
print(f"Launching SGLang Engine with {DENSE_MODEL_NAME}...")
engine = sgl.Engine(
model_path=DENSE_MODEL_NAME,
attention_backend="flashinfer",
enable_deterministic_inference=True,
random_seed=42,
skip_tokenizer_init=True,
mem_fraction_static=0.8,
max_running_requests=1,
)
records = []
prompt_lengths = []
try:
for i, text in enumerate(texts):
if len(records) >= num_samples:
break
try:
ids = tokenizer.encode(text, add_special_tokens=False)
if len(ids) < 5:
continue
start_pos = int(rng.integers(0, max(1, len(ids) - 3)))
outputs = engine.generate(
input_ids=[ids],
sampling_params={
"temperature": 1.0,
"top_p": 1.0,
"top_k": TOP_K,
"max_new_tokens": 1,
},
return_logprob=True,
logprob_start_len=start_pos,
top_logprobs_num=TOP_K,
)
meta = outputs[0]["meta_info"]
records.append(
dict(id=i, text=text, ids=ids, start_pos=start_pos, meta=meta)
)
prompt_lengths.append(len(ids))
if (i + 1) % 50 == 0:
print(f"Processed {len(records)}/{num_samples} samples")
except Exception as e:
print(f"Warning: Failed to process sample {i}: {e}")
continue
if not records:
raise RuntimeError(
"Failed to generate any baseline records. Please check the warnings above for errors."
)
# Save baseline files
with open(baseline_file, "wb") as f:
pickle.dump(records, f)
with open(meta_file, "w", encoding="utf-8") as f:
json.dump(records[:2], f, ensure_ascii=False, indent=2)
print(f"✅ Saved {len(records)} samples to {baseline_file}")
print(f"✅ Meta preview saved to {meta_file}")
if prompt_lengths:
avg_prompt_length = sum(prompt_lengths) / len(prompt_lengths)
print(f"📊 Average prompt length: {avg_prompt_length:.2f} tokens")
finally:
engine.shutdown()
torch.cuda.empty_cache()
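# A quick local smoke run of the generator above can use a smaller baseline; this is
# an illustrative sketch only (hypothetical file names and sample count, not a
# supported configuration):
#
#     generate_baseline(
#         baseline_file="sglang_baseline_smoke.pkl",
#         meta_file="baseline_meta_smoke.json",
#         num_samples=50,
#     )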
class TestLogprobsDense(unittest.TestCase):
......@@ -48,6 +216,8 @@ class TestLogprobsDense(unittest.TestCase):
cls.engine = sgl.Engine(
model_path=DENSE_MODEL_NAME,
random_seed=42,
attention_backend="flashinfer",
enable_deterministic_inference=True,
skip_tokenizer_init=True,
mem_fraction_static=0.80,
)
......@@ -58,31 +228,24 @@ class TestLogprobsDense(unittest.TestCase):
cls.engine.shutdown()
torch.cuda.empty_cache()
    def load_test_data(self, baseline_file=None):
        """Load test data from local baseline file. In test mode, only local baseline is supported."""
        if not baseline_file:
            raise ValueError("baseline_file is required in test mode")
        if not os.path.exists(baseline_file):
            raise FileNotFoundError(
                f"Baseline file not found: {baseline_file}. Please run 'gen' mode first to generate the baseline."
            )
        print(f"Loading local baseline from {baseline_file}...")
        try:
            with open(baseline_file, "rb") as f:
                records = pickle.load(f)
            print(f"Successfully loaded {len(records)} records from local baseline")
            return records
        except (IOError, pickle.PickleError) as e:
            raise Exception(f"Failed to load local baseline: {e}") from e
def compare_meta(self, baseline_meta, sglang_meta):
"""Compare metadata between two outputs and return max and mean differences."""
......@@ -102,19 +265,21 @@ class TestLogprobsDense(unittest.TestCase):
common_tokens = baseline_token_map.keys() & sglang_token_map.keys()
self.assertGreaterEqual(
len(common_tokens),
            TOP_K,
            f"there are only {len(common_tokens)} common top-k tokens that match",
)
for token_id in common_tokens:
diffs.append(
abs(baseline_token_map[token_id] - sglang_token_map[token_id])
)
if not diffs:
return 0.0, 0.0
return max(diffs), float(np.mean(diffs))
    def test_logprobs_comparison(self, baseline_file=None):
"""Test the logprobs comparison functionality with different parameter combinations."""
        # Load test data from the local baseline
        records = self.load_test_data(baseline_file)
with self.subTest(
config={
......@@ -224,15 +389,6 @@ class TestLogprobsDense(unittest.TestCase):
f"Expected {logprob_count} samples with logprobs, got {logprob_returned_count}",
)
# Basic validation
self.assertIsInstance(all_max, list)
self.assertIsInstance(all_mean, list)
......@@ -261,5 +417,52 @@ class TestLogprobsDense(unittest.TestCase):
)
def main():
"""Main function to handle command line arguments and run either generation or testing."""
parser = argparse.ArgumentParser(
description="SGLang Logprobs Test and Baseline Generation"
)
parser.add_argument(
"mode",
choices=["gen", "test"],
help="Mode to run: 'gen' to generate baseline, 'test' to run tests",
)
args = parser.parse_args()
if args.mode == "gen":
print("🚀 Generating baseline...")
generate_baseline()
print(f"\n✅ Baseline generation complete!")
print(f"📁 Baseline saved to: {DEFAULT_BASELINE_PKL}")
print(f"📁 Metadata preview saved to: {DEFAULT_META_JSON}")
print(f"\n💡 Next steps:")
print(f" 1. Make your code changes")
print(f" 2. Run: python {__file__} test")
elif args.mode == "test":
print("🧪 Running logprobs test...")
if not os.path.exists(DEFAULT_BASELINE_PKL):
print(f"❌ Baseline file not found: {DEFAULT_BASELINE_PKL}")
print(f"💡 Generate baseline first by running:")
print(f" python {__file__} gen")
print(f" This will download ShareGPT data and generate a local baseline.")
return 1
# Set environment variable for testing
os.environ["RETURN_ORIGINAL_LOGPROB"] = "True"
# Create test instance and run
test_instance = TestLogprobsDense()
test_instance.setUpClass()
try:
test_instance.test_logprobs_comparison(baseline_file=DEFAULT_BASELINE_PKL)
print("\n✅ Test completed successfully!")
finally:
test_instance.tearDownClass()
return 0
if __name__ == "__main__":
    exit(main())