Unverified Commit f7102fbd authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix mixed chunked prefill (#1850)

parent a7a0a688
...@@ -720,9 +720,11 @@ class Scheduler: ...@@ -720,9 +720,11 @@ class Scheduler:
# Mixed-style chunked prefill # Mixed-style chunked prefill
if self.is_mixed_chunk and self.running_batch is not None: if self.is_mixed_chunk and self.running_batch is not None:
self.running_batch.prepare_for_decode(self.enable_overlap) self.running_batch.filter_batch()
new_batch.mix_with_running(self.running_batch) if not self.running_batch.is_empty():
new_batch.decoding_reqs = self.running_batch.reqs self.running_batch.prepare_for_decode(self.enable_overlap)
new_batch.mix_with_running(self.running_batch)
new_batch.decoding_reqs = self.running_batch.reqs
self.running_batch = None self.running_batch = None
else: else:
new_batch.decoding_reqs = None new_batch.decoding_reqs = None
......
...@@ -7,6 +7,7 @@ import random ...@@ -7,6 +7,7 @@ import random
import subprocess import subprocess
import threading import threading
import time import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial from functools import partial
from types import SimpleNamespace from types import SimpleNamespace
from typing import Callable, List, Optional from typing import Callable, List, Optional
...@@ -656,11 +657,12 @@ def read_output(output_lines): ...@@ -656,11 +657,12 @@ def read_output(output_lines):
time.sleep(0.1) time.sleep(0.1)
def run_mmlu_test( def run_and_check_memory_leak(
workload_func,
disable_radix_cache, disable_radix_cache,
enable_mixed_chunk=False, enable_mixed_chunk,
enable_overlap=False, enable_overlap,
chunked_prefill_size=32, chunked_prefill_size,
): ):
other_args = ["--chunked-prefill-size", str(chunked_prefill_size)] other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
if disable_radix_cache: if disable_radix_cache:
...@@ -690,21 +692,8 @@ def run_mmlu_test( ...@@ -690,21 +692,8 @@ def run_mmlu_test(
t = threading.Thread(target=read_output, args=(output_lines,)) t = threading.Thread(target=read_output, args=(output_lines,))
t.start() t.start()
# Run the eval # Run the workload
args = SimpleNamespace( workload_func(base_url, model)
base_url=base_url,
model=model,
eval_name="mmlu",
num_examples=128,
num_threads=128,
)
try:
metrics = run_eval(args)
print(f"{metrics=}")
assert metrics["score"] >= 0.65
finally:
pass
# Clean up everything # Clean up everything
kill_child_process(process.pid, include_self=True) kill_child_process(process.pid, include_self=True)
...@@ -727,4 +716,63 @@ def run_mmlu_test( ...@@ -727,4 +716,63 @@ def run_mmlu_test(
has_leak = True has_leak = True
assert has_new_server assert has_new_server
# assert not has_leak assert not has_leak
def run_mmlu_test(
disable_radix_cache=False,
enable_mixed_chunk=False,
enable_overlap=False,
chunked_prefill_size=32,
):
def workload_func(base_url, model):
# Run the eval
args = SimpleNamespace(
base_url=base_url,
model=model,
eval_name="mmlu",
num_examples=128,
num_threads=128,
)
try:
metrics = run_eval(args)
print(f"{metrics=}")
assert metrics["score"] >= 0.65
finally:
pass
run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size)
def run_mulit_request_test(
disable_radix_cache=False,
enable_mixed_chunk=False,
enable_overlap=False,
chunked_prefill_size=32,
):
def workload_func(base_url, model):
def run_one(_):
prompt = """
System: You are a helpful assistant.
User: What is the capital of France?
Assistant: The capital of France is
"""
response = requests.post(
f"{base_url}/generate",
json={
"text": prompt,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 8,
},
},
)
ret = response.json()
with ThreadPoolExecutor(2) as executor:
list(executor.map(run_one, list(range(4))))
run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size)
...@@ -8,6 +8,7 @@ from sglang.test.test_utils import ( ...@@ -8,6 +8,7 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
run_bench_serving, run_bench_serving,
run_mmlu_test, run_mmlu_test,
run_mulit_request_test,
) )
...@@ -39,6 +40,12 @@ class TestChunkedPrefill(unittest.TestCase): ...@@ -39,6 +40,12 @@ class TestChunkedPrefill(unittest.TestCase):
assert res["completed"] == 10 assert res["completed"] == 10
def test_mixed_chunked_prefill_multi_requests(self):
run_mulit_request_test(
enable_mixed_chunk=True,
chunked_prefill_size=2048,
)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment