"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "a92ae3688afad51245d135a3f361fb7e20364d6d"
Unverified commit d5fa019c, authored by Zhao Chen, committed by GitHub

feat: limit peak memory usage when computing logprobs (#6318)


Signed-off-by: Zhao Chen <zhaochen.zju@gmail.com>
Co-authored-by: 赵晨阳 (Zhao Chenyang) <zhaochen20@outlook.com>
parent fef3a6b6
@@ -273,6 +273,10 @@ class Envs:
    # Sparse Embeddings
    SGLANG_EMBEDDINGS_SPARSE_HEAD = EnvStr(None)
    # Logits processor
    SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK = EnvBool(False)
    SGLANG_LOGITS_PROCESSER_CHUNK_SIZE = EnvInt(2048)
    # Tool-Call behavior
    SGLANG_TOOL_STRICT_LEVEL = EnvInt(ToolStrictLevel.OFF)
...
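The two flags above gate the feature named in the commit title: instead of materializing prompt logits for the whole sequence at once, the logprob computation can presumably walk the prompt in chunks of SGLANG_LOGITS_PROCESSER_CHUNK_SIZE tokens, so peak memory scales with the chunk rather than the prompt length. The following is only a minimal sketch of that idea in plain PyTorch, not SGLang's actual logits-processor code; the function name and tensor layout are assumptions.

import torch


def chunked_prompt_logprobs(
    hidden_states: torch.Tensor,   # [seq_len, hidden_dim] prompt hidden states
    lm_head_weight: torch.Tensor,  # [vocab_size, hidden_dim]
    target_ids: torch.Tensor,      # [seq_len] token ids to score
    chunk_size: int = 2048,        # mirrors the SGLANG_LOGITS_PROCESSER_CHUNK_SIZE default
) -> torch.Tensor:
    """Per-token logprobs without building a full [seq_len, vocab_size] tensor."""
    out = []
    for start in range(0, hidden_states.size(0), chunk_size):
        end = min(start + chunk_size, hidden_states.size(0))
        # Peak memory is bounded by [chunk_size, vocab_size] instead of
        # [seq_len, vocab_size].
        logits = hidden_states[start:end] @ lm_head_weight.T
        logprobs = torch.log_softmax(logits.float(), dim=-1)
        out.append(logprobs.gather(-1, target_ids[start:end, None]).squeeze(-1))
    return torch.cat(out)


# Toy shapes, just to exercise the sketch.
h = torch.randn(10, 16)
w = torch.randn(32, 16)
ids = torch.randint(0, 32, (10,))
print(chunked_prompt_logprobs(h, w, ids, chunk_size=4).shape)  # torch.Size([10])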
@@ -85,6 +85,16 @@ MAX_LEN = 20000
DEFAULT_BASELINE_PKL = "sglang_baseline_local.pkl"
DEFAULT_META_JSON = "baseline_meta_preview.json"

# Default engine configuration
DEFAULT_ENGINE_CONFIG = {
    "model_path": DENSE_MODEL_NAME,
    "random_seed": 42,
    "skip_tokenizer_init": True,
    "mem_fraction_static": 0.8,
    "enable_deterministic_inference": True,
    "attention_backend": "flashinfer",
}


def generate_baseline(
    baseline_file=DEFAULT_BASELINE_PKL,
@@ -213,14 +223,7 @@ class TestLogprobsDense(unittest.TestCase):
    def setUpClass(cls):
        """Set up the test class - initialize the engine once for all tests."""
        print(f"Launching SGLang Engine with {DENSE_MODEL_NAME}...")
        cls.engine = sgl.Engine(**DEFAULT_ENGINE_CONFIG)

    @classmethod
    def tearDownClass(cls):
@@ -228,6 +231,26 @@ class TestLogprobsDense(unittest.TestCase):
        cls.engine.shutdown()
        torch.cuda.empty_cache()

    @classmethod
    def restart_engine_with_config(cls, **kwargs):
        """Create engine with custom configuration"""
        # Safely shutdown existing engine
        cls.engine.shutdown()
        torch.cuda.empty_cache()

        # Set chunk size
        chunk_size = kwargs.pop("chunk_size", None)
        if chunk_size is not None:
            print(f"Setting chunk size to {chunk_size}")
            os.environ["SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK"] = "True"
            os.environ["SGLANG_LOGITS_PROCESSER_CHUNK_SIZE"] = str(chunk_size)
        else:
            os.environ["SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK"] = "False"

        # Create engine with merged configuration
        engine_config = {**DEFAULT_ENGINE_CONFIG, **kwargs}
        cls.engine = sgl.Engine(**engine_config)

    def load_test_data(self, baseline_file=None):
        """Load test data from local baseline file. In test mode, only local baseline is supported."""
        if not baseline_file:
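The restart helper above shuts the engine down before touching the environment because the chunking flags appear to be read only when a new engine is constructed (hence the restart in the test rather than an in-place toggle). A minimal standalone sketch of the same pattern, reusing the DEFAULT_ENGINE_CONFIG dict defined earlier in this file; the override values are illustrative only.

import os

import sglang as sgl

# Enable chunked logprob computation before the engine is built.
os.environ["SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK"] = "True"
os.environ["SGLANG_LOGITS_PROCESSER_CHUNK_SIZE"] = "128"

engine = sgl.Engine(**{**DEFAULT_ENGINE_CONFIG, "max_running_requests": 32})
try:
    ...  # issue engine.generate(..., return_logprob=True) calls here
finally:
    engine.shutdown()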
@@ -281,141 +304,176 @@ class TestLogprobsDense(unittest.TestCase):
        # Load test data with retry mechanism
        records = self.load_test_data(baseline_file)

        # Fast configs for CI
        test_configs = [
            {"num_samples": NUM_SAMPLES},
            {"num_samples": 42, "chunk_size": 1, "max_running_requests": 16},
            {"num_samples": 42, "chunk_size": 2, "max_running_requests": 16},
            {"num_samples": 42, "chunk_size": 3, "max_running_requests": 16},
            {"num_samples": NUM_SAMPLES, "chunk_size": 16, "max_running_requests": 128},
            {"num_samples": NUM_SAMPLES, "chunk_size": 128, "max_running_requests": 16},
            {"num_samples": NUM_SAMPLES, "chunk_size": 128, "max_running_requests": 8},
            {"num_samples": NUM_SAMPLES, "chunk_size": 128, "max_running_requests": 32},
            {
                "num_samples": NUM_SAMPLES,
                "chunk_size": 128,
                "max_running_requests": 128,
            },
            {"num_samples": NUM_SAMPLES, "chunk_size": 256, "max_running_requests": 8},
            {"num_samples": NUM_SAMPLES, "chunk_size": 256, "max_running_requests": 32},
            {
                "num_samples": NUM_SAMPLES,
                "chunk_size": 256,
                "max_running_requests": 128,
            },
        ]

        # Run tests
        for config in test_configs:
            with self.subTest(config=config):
                print(f"Testing with config: {config}")

                # Sample records for this config
                test_records = random.sample(records, k=min(NUM_SAMPLES, len(records)))
                random.shuffle(test_records)

                # Calculate how many samples should return logprobs
                logprob_count = int(len(test_records) * LOGPROB_SAMPLE_RATIO)
                print(
                    f"Testing with {len(test_records)} samples, temperature={TEMPERATURE}"
                )
                print(
                    f"Will return logprobs for {logprob_count} samples (ratio: {LOGPROB_SAMPLE_RATIO})"
                )

                all_max, all_mean = [], []
                logprob_returned_count = 0

                # Process all records at once
                input_ids = [rec["ids"] for rec in test_records]
                logprob_start_lens = [rec["start_pos"] for rec in test_records]

                # Determine which samples should return logprobs (randomly selected)
                logprob_indices = set(
                    random.sample(range(len(test_records)), logprob_count)
                )
                return_logprob_array = [
                    sample_idx in logprob_indices
                    for sample_idx in range(len(test_records))
                ]
                # Sampling param per request
                sampling_params = [
                    {
                        "temperature": TEMPERATURE,
                        "top_p": 1.0,
                        "top_k": TOP_K,
                        "max_new_tokens": 1,
                    }
                    for _ in test_records
                ]

                # Some configs must restart the engine to take effect
                chunk_size = config.get("chunk_size", None)
                max_running_requests = config.get("max_running_requests", None)
                if chunk_size is not None or max_running_requests is not None:
                    self.restart_engine_with_config(
                        chunk_size=chunk_size,
                        max_running_requests=max_running_requests,
                    )

                outputs = self.engine.generate(
                    input_ids=input_ids,
                    sampling_params=sampling_params,
                    return_logprob=return_logprob_array,
                    logprob_start_len=logprob_start_lens,
                    top_logprobs_num=TOP_K,
                )

                for sample_idx, (rec, output) in enumerate(zip(test_records, outputs)):
                    # Only compare logprobs for samples that should have them
                    if sample_idx in logprob_indices:
                        # Safe access to meta_info and input_top_logprobs
                        meta_info = output.get("meta_info")
                        input_top_logprobs = (
                            meta_info.get("input_top_logprobs") if meta_info else None
                        )
                        self.assertIsNotNone(
                            input_top_logprobs,
                            f"return_logprob enabled on this sample, but input_top_logprobs is None (length: {len(input_top_logprobs) if input_top_logprobs is not None else 'N/A'})",
                        )
                        baseline_meta = rec["meta"]
                        sglang_meta = meta_info

                        max_diff, mean_diff = self.compare_meta(
                            baseline_meta, sglang_meta
                        )
                        all_max.append(max_diff)
                        all_mean.append(mean_diff)
                        logprob_returned_count += 1
                    else:
                        # Verify that logprobs were not returned for this sample
                        meta_info = output.get("meta_info")
                        input_top_logprobs = (
                            meta_info.get("input_top_logprobs") if meta_info else None
                        )
                        output_token_ids_logprobs = (
                            meta_info.get("output_token_ids_logprobs")
                            if meta_info
                            else None
                        )
                        self.assertFalse(
                            input_top_logprobs,
                            f"return_logprob is disabled on this sample, Sample {sample_idx} should not have logprobs, content: {output_token_ids_logprobs}",
                        )
                max_of_max = max(all_max) if all_max else 0.0
                mean_of_mean = np.mean(all_mean) if all_mean else 0.0

                print(f"max Δ={max_of_max:.6g}")
                print(f"mean Δ={mean_of_mean:.6g}")
                print(
                    f"logprobs returned for {logprob_returned_count} samples (expected: {logprob_count})"
                )

                # Verify correct number of logprobs returned
                self.assertEqual(
                    logprob_returned_count,
                    logprob_count,
                    f"Expected {logprob_count} samples with logprobs, got {logprob_returned_count}",
                )
                # Basic validation
                self.assertIsInstance(all_max, list)
                self.assertIsInstance(all_mean, list)
                self.assertGreater(
                    len(all_max),
                    0,
                    f"No test samples processed for config {{'num_samples': {NUM_SAMPLES}, 'logprob_sample_ratio': {LOGPROB_SAMPLE_RATIO}, 'temperature': {TEMPERATURE}}}",
                )
                # Tolerance checks with clear error messages
                failed_samples = []
                for sample_idx, (max_diff, mean_diff) in enumerate(
                    zip(all_max, all_mean)
                ):
                    if max_diff > DENSE_TOLERANCE_MAX_DIFF:
                        failed_samples.append(
                            f"Sample {sample_idx}: max_diff={max_diff:.6g} > {DENSE_TOLERANCE_MAX_DIFF}"
                        )
                    if mean_diff > DENSE_TOLERANCE_MEAN_DIFF:
                        failed_samples.append(
                            f"Sample {sample_idx}: mean_diff={mean_diff:.6g} > {DENSE_TOLERANCE_MEAN_DIFF}"
                        )
                if failed_samples:
                    self.fail(
                        f"Config {{'num_samples': {NUM_SAMPLES}, 'logprob_sample_ratio': {LOGPROB_SAMPLE_RATIO}, 'temperature': {TEMPERATURE}}} - Tolerance exceeded in {len(failed_samples)} samples:\n"
                        + "\n".join(failed_samples[:5])
                    )

def main():
    """Main function to handle command line arguments and run either generation or testing."""
...
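compare_meta is referenced throughout the hunk above but is not part of this diff. Judging only from how it is called there (a baseline record's meta against the returned meta_info, yielding a maximum and a mean difference that are later checked against DENSE_TOLERANCE_MAX_DIFF and DENSE_TOLERANCE_MEAN_DIFF), a plausible shape is sketched below; this is a hypothetical stand-in, and the (logprob, token_id, ...) entry layout is an assumption.

import numpy as np


def compare_meta_sketch(baseline_meta: dict, sglang_meta: dict) -> tuple[float, float]:
    """Hypothetical stand-in for TestLogprobsDense.compare_meta (not in this diff)."""
    diffs = []
    for base_tok, cur_tok in zip(
        baseline_meta["input_top_logprobs"], sglang_meta["input_top_logprobs"]
    ):
        # Assumed layout: each per-token entry is a list of (logprob, token_id, ...)
        # tuples; only the logprob values are compared.
        for (base_lp, *_), (cur_lp, *_) in zip(base_tok, cur_tok):
            diffs.append(abs(base_lp - cur_lp))
    if not diffs:
        return 0.0, 0.0
    return float(np.max(diffs)), float(np.mean(diffs))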