Unverified commit d5fa019c, authored by Zhao Chen, committed by GitHub

feat: limit peak memory usage when computing logprobs (#6318)


Signed-off-by: Zhao Chen <zhaochen.zju@gmail.com>
Co-authored-by: 赵晨阳 <zhaochen20@outlook.com>
parent fef3a6b6
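
The commit title names the goal but not the mechanism. For orientation, here is a minimal sketch of the chunking idea, assuming a plain language-model head; `prompt_logprobs_chunked`, its argument names, and the shapes are illustrative assumptions, not sglang's actual implementation:

```python
import torch
import torch.nn.functional as F

def prompt_logprobs_chunked(
    hidden: torch.Tensor,          # [seq_len, hidden_dim] last-layer states
    lm_head_weight: torch.Tensor,  # [vocab, hidden_dim] output projection
    target_ids: torch.Tensor,      # [seq_len] token ids to score
    chunk_size: int = 2048,
) -> torch.Tensor:
    """Score target_ids chunk by chunk so only a [chunk_size, vocab]
    logits tensor is ever alive, instead of a full [seq_len, vocab] one."""
    picked = []
    for start in range(0, hidden.size(0), chunk_size):
        h = hidden[start : start + chunk_size]
        logits = h @ lm_head_weight.T                     # [chunk, vocab]
        logprobs = F.log_softmax(logits.float(), dim=-1)  # [chunk, vocab]
        ids = target_ids[start : start + chunk_size].unsqueeze(-1)
        picked.append(logprobs.gather(-1, ids).squeeze(-1))
        # the [chunk, vocab] tensors go out of scope before the next iteration
    return torch.cat(picked)
```

Peak logits memory drops from seq_len × vocab to chunk_size × vocab, which is the knob the new `SGLANG_LOGITS_PROCESSER_CHUNK_SIZE` variable below exposes.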
@@ -273,6 +273,10 @@ class Envs:
     # Sparse Embeddings
     SGLANG_EMBEDDINGS_SPARSE_HEAD = EnvStr(None)

+    # Logits processor
+    SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK = EnvBool(False)
+    SGLANG_LOGITS_PROCESSER_CHUNK_SIZE = EnvInt(2048)
+
     # Tool-Call behavior
     SGLANG_TOOL_STRICT_LEVEL = EnvInt(ToolStrictLevel.OFF)
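The two new environment variables appear to be read when the engine is launched (which is why the test below restarts the engine to apply them). A minimal opt-in from Python, using the env names added above, might look like this; the chunk value 512 is arbitrary, not a recommendation:

```python
import os

# Enable chunked logits processing before constructing sgl.Engine.
os.environ["SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK"] = "True"
os.environ["SGLANG_LOGITS_PROCESSER_CHUNK_SIZE"] = "512"
```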
@@ -85,6 +85,16 @@ MAX_LEN = 20000
 DEFAULT_BASELINE_PKL = "sglang_baseline_local.pkl"
 DEFAULT_META_JSON = "baseline_meta_preview.json"

+# Default engine configuration
+DEFAULT_ENGINE_CONFIG = {
+    "model_path": DENSE_MODEL_NAME,
+    "random_seed": 42,
+    "skip_tokenizer_init": True,
+    "mem_fraction_static": 0.8,
+    "enable_deterministic_inference": True,
+    "attention_backend": "flashinfer",
+}
+

 def generate_baseline(
     baseline_file=DEFAULT_BASELINE_PKL,
@@ -213,14 +223,7 @@ class TestLogprobsDense(unittest.TestCase):
     def setUpClass(cls):
         """Set up the test class - initialize the engine once for all tests."""
         print(f"Launching SGLang Engine with {DENSE_MODEL_NAME}...")
-        cls.engine = sgl.Engine(
-            model_path=DENSE_MODEL_NAME,
-            random_seed=42,
-            attention_backend="flashinfer",
-            enable_deterministic_inference=True,
-            skip_tokenizer_init=True,
-            mem_fraction_static=0.80,
-        )
+        cls.engine = sgl.Engine(**DEFAULT_ENGINE_CONFIG)

     @classmethod
     def tearDownClass(cls):
@@ -228,6 +231,26 @@ class TestLogprobsDense(unittest.TestCase):
         cls.engine.shutdown()
         torch.cuda.empty_cache()

+    @classmethod
+    def restart_engine_with_config(cls, **kwargs):
+        """Create engine with custom configuration"""
+        # Safely shutdown existing engine
+        cls.engine.shutdown()
+        torch.cuda.empty_cache()
+
+        # Set chunk size
+        chunk_size = kwargs.pop("chunk_size", None)
+        if chunk_size is not None:
+            print(f"Setting chunk size to {chunk_size}")
+            os.environ["SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK"] = "True"
+            os.environ["SGLANG_LOGITS_PROCESSER_CHUNK_SIZE"] = str(chunk_size)
+        else:
+            os.environ["SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK"] = "False"
+
+        # Create engine with merged configuration
+        engine_config = {**DEFAULT_ENGINE_CONFIG, **kwargs}
+        cls.engine = sgl.Engine(**engine_config)
+
     def load_test_data(self, baseline_file=None):
         """Load test data from local baseline file. In test mode, only local baseline is supported."""
         if not baseline_file:
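The new helper tears down the running engine, flips the env vars, and relaunches with any overrides merged into DEFAULT_ENGINE_CONFIG. A hypothetical call from inside the test class, with values mirroring one of the CI configs below:

```python
# Hypothetical: assumes cls.engine is already running, as after setUpClass.
cls.restart_engine_with_config(chunk_size=128, max_running_requests=32)
```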
@@ -281,141 +304,176 @@ class TestLogprobsDense(unittest.TestCase):
         # Load test data with retry mechanism
         records = self.load_test_data(baseline_file)

-        with self.subTest(
-            config={
-                "num_samples": NUM_SAMPLES,
-                "logprob_sample_ratio": LOGPROB_SAMPLE_RATIO,
-                "temperature": TEMPERATURE,
-            }
-        ):
-            # Sample records for this config
-            test_records = random.sample(records, k=min(NUM_SAMPLES, len(records)))
-            random.shuffle(test_records)
-
-            # Calculate how many samples should return logprobs
-            logprob_count = int(len(test_records) * LOGPROB_SAMPLE_RATIO)
-            print(
-                f"Testing with {len(test_records)} samples, temperature={TEMPERATURE}"
-            )
-            print(
-                f"Will return logprobs for {logprob_count} samples (ratio: {LOGPROB_SAMPLE_RATIO})"
-            )
-
-            all_max, all_mean = [], []
-            logprob_returned_count = 0
-
-            # Process all records at once
-            input_ids = [rec["ids"] for rec in test_records]
-            logprob_start_lens = [rec["start_pos"] for rec in test_records]
-
-            # Determine which samples should return logprobs (randomly selected)
-            logprob_indices = set(
-                random.sample(range(len(test_records)), logprob_count)
-            )
-            return_logprob_array = [
-                sample_idx in logprob_indices for sample_idx in range(len(test_records))
-            ]
-
-            # Sampling param per request
-            sampling_params = [
-                {
-                    "temperature": TEMPERATURE,
-                    "top_p": 1.0,
-                    "top_k": TOP_K,
-                    "max_new_tokens": 1,
-                }
-                for _ in test_records
-            ]
-
-            outputs = self.engine.generate(
-                input_ids=input_ids,
-                sampling_params=sampling_params,
-                return_logprob=return_logprob_array,
-                logprob_start_len=logprob_start_lens,
-                top_logprobs_num=TOP_K,
-            )
-
-            for sample_idx, (rec, output) in enumerate(zip(test_records, outputs)):
-                # Only compare logprobs for samples that should have them
-                if sample_idx in logprob_indices:
-                    # Safe access to meta_info and input_top_logprobs
-                    meta_info = output.get("meta_info")
-                    input_top_logprobs = (
-                        meta_info.get("input_top_logprobs") if meta_info else None
-                    )
-                    self.assertIsNotNone(
-                        input_top_logprobs,
-                        f"return_logprob enabled on this sample, but input_top_logprobs is None (length: {len(input_top_logprobs) if input_top_logprobs is not None else 'N/A'})",
-                    )
-                    baseline_meta = rec["meta"]
-                    sglang_meta = meta_info
-                    max_diff, mean_diff = self.compare_meta(baseline_meta, sglang_meta)
-                    all_max.append(max_diff)
-                    all_mean.append(mean_diff)
-                    logprob_returned_count += 1
-                else:
-                    # Verify that logprobs were not returned for this sample
-                    meta_info = output.get("meta_info")
-                    input_top_logprobs = (
-                        meta_info.get("input_top_logprobs") if meta_info else None
-                    )
-                    output_token_ids_logprobs = (
-                        meta_info.get("output_token_ids_logprobs")
-                        if meta_info
-                        else None
-                    )
-                    self.assertFalse(
-                        input_top_logprobs,
-                        f"return_logprob is disabled on this sample, Sample {sample_idx} should not have logprobs, content: {output_token_ids_logprobs}",
-                    )
-
-            max_of_max = max(all_max) if all_max else 0.0
-            mean_of_mean = np.mean(all_mean) if all_mean else 0.0
-
-            print(f"max Δ={max_of_max:.6g}")
-            print(f"mean Δ={mean_of_mean:.6g}")
-            print(
-                f"logprobs returned for {logprob_returned_count} samples (expected: {logprob_count})"
-            )
-
-            # Verify correct number of logprobs returned
-            self.assertEqual(
-                logprob_returned_count,
-                logprob_count,
-                f"Expected {logprob_count} samples with logprobs, got {logprob_returned_count}",
-            )
-
-            # Basic validation
-            self.assertIsInstance(all_max, list)
-            self.assertIsInstance(all_mean, list)
-            self.assertGreater(
-                len(all_max),
-                0,
-                f"No test samples processed for config {{'num_samples': {NUM_SAMPLES}, 'logprob_sample_ratio': {LOGPROB_SAMPLE_RATIO}, 'temperature': {TEMPERATURE}}}",
-            )
-
-            # Tolerance checks with clear error messages
-            failed_samples = []
-            for sample_idx, (max_diff, mean_diff) in enumerate(zip(all_max, all_mean)):
-                if max_diff > DENSE_TOLERANCE_MAX_DIFF:
-                    failed_samples.append(
-                        f"Sample {sample_idx}: max_diff={max_diff:.6g} > {DENSE_TOLERANCE_MAX_DIFF}"
-                    )
-                if mean_diff > DENSE_TOLERANCE_MEAN_DIFF:
-                    failed_samples.append(
-                        f"Sample {sample_idx}: mean_diff={mean_diff:.6g} > {DENSE_TOLERANCE_MEAN_DIFF}"
-                    )
-
-            if failed_samples:
-                self.fail(
-                    f"Config {{'num_samples': {NUM_SAMPLES}, 'logprob_sample_ratio': {LOGPROB_SAMPLE_RATIO}, 'temperature': {TEMPERATURE}}} - Tolerance exceeded in {len(failed_samples)} samples:\n"
-                    + "\n".join(failed_samples[:5])
-                )
+        # Fast configs for CI
+        test_configs = [
+            {"num_samples": NUM_SAMPLES},
+            {"num_samples": 42, "chunk_size": 1, "max_running_requests": 16},
+            {"num_samples": 42, "chunk_size": 2, "max_running_requests": 16},
+            {"num_samples": 42, "chunk_size": 3, "max_running_requests": 16},
+            {"num_samples": NUM_SAMPLES, "chunk_size": 16, "max_running_requests": 128},
+            {"num_samples": NUM_SAMPLES, "chunk_size": 128, "max_running_requests": 16},
+            {"num_samples": NUM_SAMPLES, "chunk_size": 128, "max_running_requests": 8},
+            {"num_samples": NUM_SAMPLES, "chunk_size": 128, "max_running_requests": 32},
+            {
+                "num_samples": NUM_SAMPLES,
+                "chunk_size": 128,
+                "max_running_requests": 128,
+            },
+            {"num_samples": NUM_SAMPLES, "chunk_size": 256, "max_running_requests": 8},
+            {"num_samples": NUM_SAMPLES, "chunk_size": 256, "max_running_requests": 32},
+            {
+                "num_samples": NUM_SAMPLES,
+                "chunk_size": 256,
+                "max_running_requests": 128,
+            },
+        ]
+
+        # Run tests
+        for config in test_configs:
+            with self.subTest(config=config):
+                print(f"Testing with config: {config}")
+                # Sample records for this config
+                test_records = random.sample(records, k=min(NUM_SAMPLES, len(records)))
+                random.shuffle(test_records)
+
+                # Calculate how many samples should return logprobs
+                logprob_count = int(len(test_records) * LOGPROB_SAMPLE_RATIO)
+                print(
+                    f"Testing with {len(test_records)} samples, temperature={TEMPERATURE}"
+                )
+                print(
+                    f"Will return logprobs for {logprob_count} samples (ratio: {LOGPROB_SAMPLE_RATIO})"
+                )
+
+                # Process all records at once
+                input_ids = [rec["ids"] for rec in test_records]
+                logprob_start_lens = [rec["start_pos"] for rec in test_records]
+
+                all_max, all_mean = [], []
+                logprob_returned_count = 0
+
+                # Determine which samples should return logprobs (randomly selected)
+                logprob_indices = set(
+                    random.sample(range(len(test_records)), logprob_count)
+                )
+                return_logprob_array = [
+                    sample_idx in logprob_indices
+                    for sample_idx in range(len(test_records))
+                ]
+
+                # Sampling param per request
+                sampling_params = [
+                    {
+                        "temperature": TEMPERATURE,
+                        "top_p": 1.0,
+                        "top_k": TOP_K,
+                        "max_new_tokens": 1,
+                    }
+                    for _ in test_records
+                ]
+
+                # Some configs must restart the engine to take effect
+                chunk_size = config.get("chunk_size", None)
+                max_running_requests = config.get("max_running_requests", None)
+                if chunk_size is not None or max_running_requests is not None:
+                    self.restart_engine_with_config(
+                        chunk_size=chunk_size,
+                        max_running_requests=max_running_requests,
+                    )
+
+                outputs = self.engine.generate(
+                    input_ids=input_ids,
+                    sampling_params=sampling_params,
+                    return_logprob=return_logprob_array,
+                    logprob_start_len=logprob_start_lens,
+                    top_logprobs_num=TOP_K,
+                )
+
+                for sample_idx, (rec, output) in enumerate(zip(test_records, outputs)):
+                    # Only compare logprobs for samples that should have them
+                    if sample_idx in logprob_indices:
+                        # Safe access to meta_info and input_top_logprobs
+                        meta_info = output.get("meta_info")
+                        input_top_logprobs = (
+                            meta_info.get("input_top_logprobs") if meta_info else None
+                        )
+                        self.assertIsNotNone(
+                            input_top_logprobs,
+                            f"return_logprob enabled on this sample, but input_top_logprobs is None (length: {len(input_top_logprobs) if input_top_logprobs is not None else 'N/A'})",
+                        )
+                        baseline_meta = rec["meta"]
+                        sglang_meta = meta_info
+                        max_diff, mean_diff = self.compare_meta(
+                            baseline_meta, sglang_meta
+                        )
+                        all_max.append(max_diff)
+                        all_mean.append(mean_diff)
+                        logprob_returned_count += 1
+                    else:
+                        # Verify that logprobs were not returned for this sample
+                        meta_info = output.get("meta_info")
+                        input_top_logprobs = (
+                            meta_info.get("input_top_logprobs") if meta_info else None
+                        )
+                        output_token_ids_logprobs = (
+                            meta_info.get("output_token_ids_logprobs")
+                            if meta_info
+                            else None
+                        )
+                        self.assertFalse(
+                            input_top_logprobs,
+                            f"return_logprob is disabled on this sample, Sample {sample_idx} should not have logprobs, content: {output_token_ids_logprobs}",
+                        )
+
+                max_of_max = max(all_max) if all_max else 0.0
+                mean_of_mean = np.mean(all_mean) if all_mean else 0.0
+
+                print(f"max Δ={max_of_max:.6g}")
+                print(f"mean Δ={mean_of_mean:.6g}")
+                print(
+                    f"logprobs returned for {logprob_returned_count} samples (expected: {logprob_count})"
+                )
+
+                # Verify correct number of logprobs returned
+                self.assertEqual(
+                    logprob_returned_count,
+                    logprob_count,
+                    f"Expected {logprob_count} samples with logprobs, got {logprob_returned_count}",
+                )
+
+                # Basic validation
+                self.assertIsInstance(all_max, list)
+                self.assertIsInstance(all_mean, list)
+                self.assertGreater(
+                    len(all_max),
+                    0,
+                    f"No test samples processed for config {{'num_samples': {NUM_SAMPLES}, 'logprob_sample_ratio': {LOGPROB_SAMPLE_RATIO}, 'temperature': {TEMPERATURE}}}",
+                )
+
+                # Tolerance checks with clear error messages
+                failed_samples = []
+                for sample_idx, (max_diff, mean_diff) in enumerate(
+                    zip(all_max, all_mean)
+                ):
+                    if max_diff > DENSE_TOLERANCE_MAX_DIFF:
+                        failed_samples.append(
+                            f"Sample {sample_idx}: max_diff={max_diff:.6g} > {DENSE_TOLERANCE_MAX_DIFF}"
+                        )
+                    if mean_diff > DENSE_TOLERANCE_MEAN_DIFF:
+                        failed_samples.append(
+                            f"Sample {sample_idx}: mean_diff={mean_diff:.6g} > {DENSE_TOLERANCE_MEAN_DIFF}"
+                        )
+
+                if failed_samples:
+                    self.fail(
+                        f"Config {{'num_samples': {NUM_SAMPLES}, 'logprob_sample_ratio': {LOGPROB_SAMPLE_RATIO}, 'temperature': {TEMPERATURE}}} - Tolerance exceeded in {len(failed_samples)} samples:\n"
+                        + "\n".join(failed_samples[:5])
+                    )

 def main():
     """Main function to handle command line arguments and run either generation or testing."""