Unverified Commit f7102fbd authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix mixed chunked prefill (#1850)

parent a7a0a688
......@@ -720,9 +720,11 @@ class Scheduler:
# Mixed-style chunked prefill
if self.is_mixed_chunk and self.running_batch is not None:
self.running_batch.prepare_for_decode(self.enable_overlap)
new_batch.mix_with_running(self.running_batch)
new_batch.decoding_reqs = self.running_batch.reqs
self.running_batch.filter_batch()
if not self.running_batch.is_empty():
self.running_batch.prepare_for_decode(self.enable_overlap)
new_batch.mix_with_running(self.running_batch)
new_batch.decoding_reqs = self.running_batch.reqs
self.running_batch = None
else:
new_batch.decoding_reqs = None
......
......@@ -7,6 +7,7 @@ import random
import subprocess
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from types import SimpleNamespace
from typing import Callable, List, Optional
......@@ -656,11 +657,12 @@ def read_output(output_lines):
time.sleep(0.1)
def run_mmlu_test(
def run_and_check_memory_leak(
workload_func,
disable_radix_cache,
enable_mixed_chunk=False,
enable_overlap=False,
chunked_prefill_size=32,
enable_mixed_chunk,
enable_overlap,
chunked_prefill_size,
):
other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
if disable_radix_cache:
......@@ -690,21 +692,8 @@ def run_mmlu_test(
t = threading.Thread(target=read_output, args=(output_lines,))
t.start()
# Run the eval
args = SimpleNamespace(
base_url=base_url,
model=model,
eval_name="mmlu",
num_examples=128,
num_threads=128,
)
try:
metrics = run_eval(args)
print(f"{metrics=}")
assert metrics["score"] >= 0.65
finally:
pass
# Run the workload
workload_func(base_url, model)
# Clean up everything
kill_child_process(process.pid, include_self=True)
......@@ -727,4 +716,63 @@ def run_mmlu_test(
has_leak = True
assert has_new_server
# assert not has_leak
assert not has_leak
def run_mmlu_test(
disable_radix_cache=False,
enable_mixed_chunk=False,
enable_overlap=False,
chunked_prefill_size=32,
):
def workload_func(base_url, model):
# Run the eval
args = SimpleNamespace(
base_url=base_url,
model=model,
eval_name="mmlu",
num_examples=128,
num_threads=128,
)
try:
metrics = run_eval(args)
print(f"{metrics=}")
assert metrics["score"] >= 0.65
finally:
pass
run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size)
def run_mulit_request_test(
disable_radix_cache=False,
enable_mixed_chunk=False,
enable_overlap=False,
chunked_prefill_size=32,
):
def workload_func(base_url, model):
def run_one(_):
prompt = """
System: You are a helpful assistant.
User: What is the capital of France?
Assistant: The capital of France is
"""
response = requests.post(
f"{base_url}/generate",
json={
"text": prompt,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 8,
},
},
)
ret = response.json()
with ThreadPoolExecutor(2) as executor:
list(executor.map(run_one, list(range(4))))
run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size)
......@@ -8,6 +8,7 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
run_bench_serving,
run_mmlu_test,
run_mulit_request_test,
)
......@@ -39,6 +40,12 @@ class TestChunkedPrefill(unittest.TestCase):
assert res["completed"] == 10
def test_mixed_chunked_prefill_multi_requests(self):
run_mulit_request_test(
enable_mixed_chunk=True,
chunked_prefill_size=2048,
)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment