Unverified Commit d5fa019c authored by Zhao Chen, committed by GitHub

feat: limit peak memory usage when computing logprobs (#6318)


Signed-off-by: Zhao Chen <zhaochen.zju@gmail.com>
Co-authored-by: 赵晨阳 <zhaochen20@outlook.com>
parent fef3a6b6
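
Note on the change: when computing token log-probabilities, materializing log-softmax over the full (num_tokens, vocab_size) logits tensor can dominate peak memory. The new environment variables below gate a chunked mode in the logits processor that walks the token dimension in chunks of SGLANG_LOGITS_PROCESSER_CHUNK_SIZE. A minimal sketch of the general idea (illustrative only, not the PR's actual implementation; `chunked_logprobs` and its signature are hypothetical):

```python
import torch

def chunked_logprobs(
    logits: torch.Tensor,     # (num_tokens, vocab_size)
    token_ids: torch.Tensor,  # (num_tokens,)
    chunk_size: int = 2048,
) -> torch.Tensor:
    """Gather per-token logprobs without materializing a full log-softmax tensor."""
    out = torch.empty(logits.shape[0], dtype=torch.float32, device=logits.device)
    for start in range(0, logits.shape[0], chunk_size):
        end = min(start + chunk_size, logits.shape[0])
        # Only a (chunk, vocab_size) temporary is allocated per iteration,
        # so peak memory is bounded by chunk_size * vocab_size.
        chunk = torch.log_softmax(logits[start:end].float(), dim=-1)
        out[start:end] = chunk.gather(
            -1, token_ids[start:end].unsqueeze(-1)
        ).squeeze(-1)
    return out
```

As a rough illustration, with chunk_size = 2048 and a 131072-entry vocabulary, the fp32 temporary is capped at 2048 × 131072 × 4 bytes = 1 GiB, independent of how many tokens are in flight.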
@@ -273,6 +273,10 @@ class Envs:
     # Sparse Embeddings
     SGLANG_EMBEDDINGS_SPARSE_HEAD = EnvStr(None)
 
+    # Logits processor
+    SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK = EnvBool(False)
+    SGLANG_LOGITS_PROCESSER_CHUNK_SIZE = EnvInt(2048)
+
     # Tool-Call behavior
     SGLANG_TOOL_STRICT_LEVEL = EnvInt(ToolStrictLevel.OFF)
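
For reference, the new mode is off by default; it is enabled through environment variables that are read at engine startup, which is how the test changes further below toggle it (values here are illustrative):

```python
import os

# Must be set before the engine reads its environment.
os.environ["SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK"] = "True"
os.environ["SGLANG_LOGITS_PROCESSER_CHUNK_SIZE"] = "512"
```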
@@ -85,6 +85,16 @@ MAX_LEN = 20000
 DEFAULT_BASELINE_PKL = "sglang_baseline_local.pkl"
 DEFAULT_META_JSON = "baseline_meta_preview.json"
 
+# Default engine configuration
+DEFAULT_ENGINE_CONFIG = {
+    "model_path": DENSE_MODEL_NAME,
+    "random_seed": 42,
+    "skip_tokenizer_init": True,
+    "mem_fraction_static": 0.8,
+    "enable_deterministic_inference": True,
+    "attention_backend": "flashinfer",
+}
+
 def generate_baseline(
     baseline_file=DEFAULT_BASELINE_PKL,
@@ -213,14 +223,7 @@ class TestLogprobsDense(unittest.TestCase):
     def setUpClass(cls):
         """Set up the test class - initialize the engine once for all tests."""
         print(f"Launching SGLang Engine with {DENSE_MODEL_NAME}...")
-        cls.engine = sgl.Engine(
-            model_path=DENSE_MODEL_NAME,
-            random_seed=42,
-            attention_backend="flashinfer",
-            enable_deterministic_inference=True,
-            skip_tokenizer_init=True,
-            mem_fraction_static=0.80,
-        )
+        cls.engine = sgl.Engine(**DEFAULT_ENGINE_CONFIG)
 
     @classmethod
     def tearDownClass(cls):
@@ -228,6 +231,26 @@ class TestLogprobsDense(unittest.TestCase):
         cls.engine.shutdown()
         torch.cuda.empty_cache()
 
+    @classmethod
+    def restart_engine_with_config(cls, **kwargs):
+        """Create engine with custom configuration"""
+        # Safely shutdown existing engine
+        cls.engine.shutdown()
+        torch.cuda.empty_cache()
+
+        # Set chunk size
+        chunk_size = kwargs.pop("chunk_size", None)
+        if chunk_size is not None:
+            print(f"Setting chunk size to {chunk_size}")
+            os.environ["SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK"] = "True"
+            os.environ["SGLANG_LOGITS_PROCESSER_CHUNK_SIZE"] = str(chunk_size)
+        else:
+            os.environ["SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK"] = "False"
+
+        # Create engine with merged configuration
+        engine_config = {**DEFAULT_ENGINE_CONFIG, **kwargs}
+        cls.engine = sgl.Engine(**engine_config)
+
     def load_test_data(self, baseline_file=None):
         """Load test data from local baseline file. In test mode, only local baseline is supported."""
         if not baseline_file:
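
Usage note on the helper above (a hypothetical invocation, mirroring the test loop further below): chunk_size travels through environment variables because the Envs values are read at engine startup, while the remaining keyword arguments are merged over DEFAULT_ENGINE_CONFIG:

```python
# Restart with chunking enabled and concurrency capped at 32 requests.
TestLogprobsDense.restart_engine_with_config(chunk_size=128, max_running_requests=32)
```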
@@ -281,13 +304,34 @@ class TestLogprobsDense(unittest.TestCase):
         # Load test data with retry mechanism
         records = self.load_test_data(baseline_file)
 
-        with self.subTest(
-            config={
-                "num_samples": NUM_SAMPLES,
-                "logprob_sample_ratio": LOGPROB_SAMPLE_RATIO,
-                "temperature": TEMPERATURE,
-            }
-        ):
+        # Fast configs for CI
+        test_configs = [
+            {"num_samples": NUM_SAMPLES},
+            {"num_samples": 42, "chunk_size": 1, "max_running_requests": 16},
+            {"num_samples": 42, "chunk_size": 2, "max_running_requests": 16},
+            {"num_samples": 42, "chunk_size": 3, "max_running_requests": 16},
+            {"num_samples": NUM_SAMPLES, "chunk_size": 16, "max_running_requests": 128},
+            {"num_samples": NUM_SAMPLES, "chunk_size": 128, "max_running_requests": 16},
+            {"num_samples": NUM_SAMPLES, "chunk_size": 128, "max_running_requests": 8},
+            {"num_samples": NUM_SAMPLES, "chunk_size": 128, "max_running_requests": 32},
+            {
+                "num_samples": NUM_SAMPLES,
+                "chunk_size": 128,
+                "max_running_requests": 128,
+            },
+            {"num_samples": NUM_SAMPLES, "chunk_size": 256, "max_running_requests": 8},
+            {"num_samples": NUM_SAMPLES, "chunk_size": 256, "max_running_requests": 32},
+            {
+                "num_samples": NUM_SAMPLES,
+                "chunk_size": 256,
+                "max_running_requests": 128,
+            },
+        ]
+
+        # Run tests
+        for config in test_configs:
+            with self.subTest(config=config):
+                print(f"Testing with config: {config}")
 
                 # Sample records for this config
                 test_records = random.sample(records, k=min(NUM_SAMPLES, len(records)))
@@ -314,7 +358,8 @@ class TestLogprobsDense(unittest.TestCase):
                     random.sample(range(len(test_records)), logprob_count)
                 )
                 return_logprob_array = [
-                    sample_idx in logprob_indices for sample_idx in range(len(test_records))
+                    sample_idx in logprob_indices
+                    for sample_idx in range(len(test_records))
                 ]
 
                 # Sampling param per request
@@ -328,6 +373,15 @@ class TestLogprobsDense(unittest.TestCase):
                     for _ in test_records
                 ]
 
+                # Some configs must restart the engine to take effect
+                chunk_size = config.get("chunk_size", None)
+                max_running_requests = config.get("max_running_requests", None)
+                if chunk_size is not None or max_running_requests is not None:
+                    self.restart_engine_with_config(
+                        chunk_size=chunk_size,
+                        max_running_requests=max_running_requests,
+                    )
+
                 outputs = self.engine.generate(
                     input_ids=input_ids,
                     sampling_params=sampling_params,
@@ -352,7 +406,9 @@ class TestLogprobsDense(unittest.TestCase):
                     baseline_meta = rec["meta"]
                     sglang_meta = meta_info
-                    max_diff, mean_diff = self.compare_meta(baseline_meta, sglang_meta)
+                    max_diff, mean_diff = self.compare_meta(
+                        baseline_meta, sglang_meta
+                    )
                     all_max.append(max_diff)
                     all_mean.append(mean_diff)
                     logprob_returned_count += 1
@@ -400,7 +456,9 @@ class TestLogprobsDense(unittest.TestCase):
                 # Tolerance checks with clear error messages
                 failed_samples = []
-                for sample_idx, (max_diff, mean_diff) in enumerate(zip(all_max, all_mean)):
+                for sample_idx, (max_diff, mean_diff) in enumerate(
+                    zip(all_max, all_mean)
+                ):
                     if max_diff > DENSE_TOLERANCE_MAX_DIFF:
                         failed_samples.append(
                             f"Sample {sample_idx}: max_diff={max_diff:.6g} > {DENSE_TOLERANCE_MAX_DIFF}"