Unverified Commit d5fa019c authored by Zhao Chen's avatar Zhao Chen Committed by GitHub
Browse files

feat: limit peak memory usage when computing logprobs (#6318)


Signed-off-by: Zhao Chen <zhaochen.zju@gmail.com>
Co-authored-by: 赵晨阳 <zhaochen20@outlook.com>
parent fef3a6b6
......@@ -273,6 +273,10 @@ class Envs:
# Sparse Embeddings
SGLANG_EMBEDDINGS_SPARSE_HEAD = EnvStr(None)
# Logits processor
SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK = EnvBool(False)
SGLANG_LOGITS_PROCESSER_CHUNK_SIZE = EnvInt(2048)
# Tool-Call behavior
SGLANG_TOOL_STRICT_LEVEL = EnvInt(ToolStrictLevel.OFF)
......
......@@ -85,6 +85,16 @@ MAX_LEN = 20000
DEFAULT_BASELINE_PKL = "sglang_baseline_local.pkl"
DEFAULT_META_JSON = "baseline_meta_preview.json"

# Baseline sgl.Engine keyword arguments shared by every test; per-test
# overrides are merged on top of this mapping before engine construction.
DEFAULT_ENGINE_CONFIG = dict(
    model_path=DENSE_MODEL_NAME,
    random_seed=42,
    skip_tokenizer_init=True,
    mem_fraction_static=0.8,
    enable_deterministic_inference=True,
    attention_backend="flashinfer",
)
def generate_baseline(
baseline_file=DEFAULT_BASELINE_PKL,
......@@ -213,14 +223,7 @@ class TestLogprobsDense(unittest.TestCase):
def setUpClass(cls):
    """Set up the test class - initialize the engine once for all tests."""
    print(f"Launching SGLang Engine with {DENSE_MODEL_NAME}...")
    # Construct the engine once from the shared default config so every
    # test starts from the same baseline settings; building it twice
    # (as the pre-refactor inline kwargs did alongside this call) would
    # leak the first engine's GPU memory.
    cls.engine = sgl.Engine(**DEFAULT_ENGINE_CONFIG)
@classmethod
def tearDownClass(cls):
......@@ -228,6 +231,26 @@ class TestLogprobsDense(unittest.TestCase):
cls.engine.shutdown()
torch.cuda.empty_cache()
@classmethod
def restart_engine_with_config(cls, **kwargs):
    """Tear down the current engine and create a new one with custom config.

    Toggles the logits-processor chunking environment variables according
    to ``chunk_size`` (the engine reads them at construction time), then
    builds a fresh engine from ``DEFAULT_ENGINE_CONFIG`` merged with the
    remaining keyword overrides.

    Args:
        chunk_size: Optional chunk size; when not ``None``, enables
            ``SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK`` and sets
            ``SGLANG_LOGITS_PROCESSER_CHUNK_SIZE``.
        **kwargs: Extra ``sgl.Engine`` keyword arguments overriding the
            defaults. ``None`` values are ignored so callers may pass
            e.g. ``max_running_requests=None`` without clobbering the
            engine's own default.
    """
    # Safely shutdown existing engine and release its GPU memory.
    cls.engine.shutdown()
    torch.cuda.empty_cache()

    # Set chunk size via environment variables read by the new engine.
    chunk_size = kwargs.pop("chunk_size", None)
    if chunk_size is not None:
        print(f"Setting chunk size to {chunk_size}")
        os.environ["SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK"] = "True"
        os.environ["SGLANG_LOGITS_PROCESSER_CHUNK_SIZE"] = str(chunk_size)
    else:
        os.environ["SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK"] = "False"

    # Drop explicit None overrides before merging so they do not replace
    # the defaults, then create the engine with the merged configuration.
    overrides = {k: v for k, v in kwargs.items() if v is not None}
    cls.engine = sgl.Engine(**{**DEFAULT_ENGINE_CONFIG, **overrides})
def load_test_data(self, baseline_file=None):
"""Load test data from local baseline file. In test mode, only local baseline is supported."""
if not baseline_file:
......@@ -281,13 +304,34 @@ class TestLogprobsDense(unittest.TestCase):
# Load test data with retry mechanism
records = self.load_test_data(baseline_file)
with self.subTest(
config={
# Fast configs for CI
test_configs = [
{"num_samples": NUM_SAMPLES},
{"num_samples": 42, "chunk_size": 1, "max_running_requests": 16},
{"num_samples": 42, "chunk_size": 2, "max_running_requests": 16},
{"num_samples": 42, "chunk_size": 3, "max_running_requests": 16},
{"num_samples": NUM_SAMPLES, "chunk_size": 16, "max_running_requests": 128},
{"num_samples": NUM_SAMPLES, "chunk_size": 128, "max_running_requests": 16},
{"num_samples": NUM_SAMPLES, "chunk_size": 128, "max_running_requests": 8},
{"num_samples": NUM_SAMPLES, "chunk_size": 128, "max_running_requests": 32},
{
"num_samples": NUM_SAMPLES,
"logprob_sample_ratio": LOGPROB_SAMPLE_RATIO,
"temperature": TEMPERATURE,
}
):
"chunk_size": 128,
"max_running_requests": 128,
},
{"num_samples": NUM_SAMPLES, "chunk_size": 256, "max_running_requests": 8},
{"num_samples": NUM_SAMPLES, "chunk_size": 256, "max_running_requests": 32},
{
"num_samples": NUM_SAMPLES,
"chunk_size": 256,
"max_running_requests": 128,
},
]
# Run tests
for config in test_configs:
with self.subTest(config=config):
print(f"Testing with config: {config}")
# Sample records for this config
test_records = random.sample(records, k=min(NUM_SAMPLES, len(records)))
......@@ -314,7 +358,8 @@ class TestLogprobsDense(unittest.TestCase):
random.sample(range(len(test_records)), logprob_count)
)
return_logprob_array = [
sample_idx in logprob_indices for sample_idx in range(len(test_records))
sample_idx in logprob_indices
for sample_idx in range(len(test_records))
]
# Sampling param per request
......@@ -328,6 +373,15 @@ class TestLogprobsDense(unittest.TestCase):
for _ in test_records
]
# Some configs must restart the engine to take effect
chunk_size = config.get("chunk_size", None)
max_running_requests = config.get("max_running_requests", None)
if chunk_size is not None or max_running_requests is not None:
self.restart_engine_with_config(
chunk_size=chunk_size,
max_running_requests=max_running_requests,
)
outputs = self.engine.generate(
input_ids=input_ids,
sampling_params=sampling_params,
......@@ -352,7 +406,9 @@ class TestLogprobsDense(unittest.TestCase):
baseline_meta = rec["meta"]
sglang_meta = meta_info
max_diff, mean_diff = self.compare_meta(baseline_meta, sglang_meta)
max_diff, mean_diff = self.compare_meta(
baseline_meta, sglang_meta
)
all_max.append(max_diff)
all_mean.append(mean_diff)
logprob_returned_count += 1
......@@ -400,7 +456,9 @@ class TestLogprobsDense(unittest.TestCase):
# Tolerance checks with clear error messages
failed_samples = []
for sample_idx, (max_diff, mean_diff) in enumerate(zip(all_max, all_mean)):
for sample_idx, (max_diff, mean_diff) in enumerate(
zip(all_max, all_mean)
):
if max_diff > DENSE_TOLERANCE_MAX_DIFF:
failed_samples.append(
f"Sample {sample_idx}: max_diff={max_diff:.6g} > {DENSE_TOLERANCE_MAX_DIFF}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment