Unverified commit 9116b289, authored by Lianmin Zheng, committed by GitHub

Add a new event loop (#1677)

parent a5114b6f
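
This commit adds an experimental overlapped event loop to the scheduler: instead of waiting for each batch's results before scheduling the next one, the CPU launches the next batch first and processes the previous batch's results one iteration later, keeping (batch, result) pairs in a deque. A reduced sketch of that pattern, where schedule_next, launch_on_gpu, and process_results are illustrative placeholders rather than SGLang APIs:

    from collections import deque

    def overlap_loop(schedule_next, launch_on_gpu, process_results):
        # One-iteration deferral: launch batch N, then process the results of
        # batch N-1 while batch N is still running on the device.
        result_queue = deque()
        last_batch = None
        while True:
            batch = schedule_next()  # CPU-side scheduling
            if batch is not None:
                result_queue.append((batch, launch_on_gpu(batch)))  # asynchronous launch
            if last_batch is not None:
                prev_batch, prev_result = result_queue.popleft()
                process_results(prev_batch, prev_result)  # handle last iteration's output
            last_batch = batch

The real implementation, event_loop_overlap in the scheduler diff below, follows the same shape and is selected at startup by the new --enable-overlap-schedule flag.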
@@ -736,6 +736,10 @@ class ScheduleBatch:
         self.input_ids = self.output_ids
         self.seq_lens.add_(1)
         self.output_ids = None
+        if self.sampling_info.penalizer_orchestrator:
+            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                self.input_ids
+            )

         # Alloc mem
         bs = len(self.reqs)
...
@@ -20,6 +20,7 @@ import logging
 import os
 import time
 import warnings
+from collections import deque
 from types import SimpleNamespace
 from typing import List, Optional, Union
@@ -192,9 +193,20 @@ class Scheduler:
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
+
+        if self.server_args.enable_overlap_schedule:
+
+            def cache_finished_req(req):
+                free_delta = int(self.running_batch and req in self.cur_batch.reqs)
+                self.tree_cache.cache_finished_req(req, free_delta=free_delta)
+
+        else:
+            cache_finished_req = self.tree_cache.cache_finished_req
+        self.cache_finished_req = cache_finished_req

         # Init running status
         self.waiting_queue: List[Req] = []
         self.running_batch: Optional[ScheduleBatch] = None
+        self.cur_batch: Optional[ScheduleBatch] = None
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
         self.num_generated_tokens = 0
@@ -279,6 +291,32 @@ class Scheduler:
             self.last_batch = batch

+    @torch.inference_mode()
+    def event_loop_overlap(self):
+        result_queue = deque()
+
+        self.last_batch = None
+        self.running_batch = None
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+
+            batch = self.get_next_batch_to_run()
+            self.cur_batch = batch
+            if batch:
+                result = self.run_batch(batch)
+                result_queue.append((batch.copy(), result))
+
+            if self.last_batch:
+                tmp_batch, tmp_result = result_queue.popleft()
+                self.process_batch_result(tmp_batch, tmp_result)
+            elif batch is None:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
+
+            self.last_batch = batch
+
     def recv_requests(self):
         if self.tp_rank == 0:
             recv_reqs = []
@@ -705,11 +743,6 @@ class Scheduler:
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids = result
-            if batch.sampling_info.penalizer_orchestrator:
-                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                    next_token_ids
-                )
-
             if batch.return_logprob:
                 # Move logprobs to cpu
                 if logits_output.next_token_logprobs is not None:
@@ -742,7 +775,7 @@
                 req.check_finished()

                 if req.finished():
-                    self.tree_cache.cache_finished_req(req)
+                    self.cache_finished_req(req)
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     self.tree_cache.cache_unfinished_req(req)
@@ -771,7 +804,7 @@
             req.check_finished()

             if req.finished():
-                self.tree_cache.cache_finished_req(req)
+                self.cache_finished_req(req)
             else:
                 self.tree_cache.cache_unfinished_req(req)
@@ -779,10 +812,6 @@
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
-        if batch.sampling_info.penalizer_orchestrator:
-            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                next_token_ids
-            )
         self.num_generated_tokens += len(batch.reqs)

         # Move logprobs to cpu
@@ -796,6 +825,9 @@
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+            if self.server_args.enable_overlap_schedule and req.finished():
+                continue
+
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_token_id)
             req.check_finished()
@@ -806,7 +838,7 @@
             )

             if req.finished():
-                self.tree_cache.cache_finished_req(req)
+                self.cache_finished_req(req)

             if req.return_logprob:
                 req.output_token_logprobs.append(
@@ -1027,7 +1059,7 @@
             for req in self.running_batch.reqs:
                 if req.rid == recv_req.rid and not req.finished():
                     req.finished_reason = FINISH_ABORT()
-                    self.tree_cache.cache_finished_req(req)
+                    self.cache_finished_req(req)
                     break

     def update_weights(self, recv_req: UpdateWeightReqInput):
@@ -1072,7 +1104,10 @@ def run_scheduler_process(
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
         pipe_writer.send("ready")
-        scheduler.event_loop_normal()
+        if server_args.enable_overlap_schedule:
+            scheduler.event_loop_overlap()
+        else:
+            scheduler.event_loop_normal()
     except Exception:
         msg = get_exception_traceback()
         logger.error(msg)
...
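
Because results now arrive one iteration late, a request can hit its stop condition while the scheduler has already launched the next decode step for it and allocated a KV-cache slot for that step. The scheduler changes above appear to account for this in two places: the cache_finished_req wrapper installed in __init__ passes free_delta=1 when the finishing request is still part of the in-flight cur_batch, and process_batch_result_decode skips requests that are already finished so no extra token is appended. A toy trace of the lag (illustrative only, not SGLang code):

    # Results of decode step t are processed only after step t+1 has been launched,
    # so a finish detected at step t becomes visible one launch "too late".
    finished_at = 1  # hypothetical: the request meets its stop condition at step 1
    for step in range(3):
        print(f"launch decode step {step}")
        if step > 0:
            done = " -> request marked finished" if step - 1 == finished_at else ""
            print(f"  process results of step {step - 1}{done}")

By the time the finish from step 1 is processed, step 2 has already been launched with this request included, which is what free_delta and the finished-request skip compensate for.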
@@ -38,12 +38,16 @@ class ChunkCache(BasePrefixCache):
             max_prefix_len = len(key)
         return entry.value[:max_prefix_len], entry

-    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
+    def cache_finished_req(
+        self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
+    ):
         if token_ids is None:
-            token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+            token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
+        else:
+            token_id_len = len(token_ids)

         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : len(token_ids)
+            req.req_pool_idx, : token_id_len + free_delta
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool.free(kv_indices)
@@ -53,10 +57,12 @@ class ChunkCache(BasePrefixCache):
     def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
-            token_ids = req.fill_ids
+            token_id_len = len(req.fill_ids)
+        else:
+            token_id_len = len(token_ids)

         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : len(token_ids)
+            req.req_pool_idx, :token_id_len
         ]

         if req.rid not in self.entries:
...
@@ -97,22 +97,38 @@ class RadixCache(BasePrefixCache):
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)

-    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
+    def cache_finished_req(
+        self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
+    ):
         """Cache request when it finishes."""
+        if self.disable:
+            if token_ids is None:
+                token_ids_len = len(req.origin_input_ids) + len(req.output_ids) - 1
+            else:
+                token_ids_len = len(token_ids)
+
+            kv_indices = self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, : token_ids_len + free_delta
+            ]
+            self.token_to_kv_pool.free(kv_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+            return
+
         if token_ids is None:
             token_ids = (req.origin_input_ids + req.output_ids)[:-1]
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(token_ids)
         ]

-        if self.disable:
-            self.token_to_kv_pool.free(kv_indices)
-            self.req_to_token_pool.free(req.req_pool_idx)
-            return
-
         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
         self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+        if free_delta:
+            self.token_to_kv_pool.free(
+                self.req_to_token_pool.req_to_token[
+                    req.req_pool_idx, len(token_ids) : len(token_ids) + 1
+                ]
+            )

         # Remove req slot release the cache lock
         self.req_to_token_pool.free(req.req_pool_idx)
...
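
In both prefix-cache implementations above, only the token count decides how many req_to_token slots to release, and free_delta extends that count by one so the slot already reserved for the not-yet-processed decode step is also freed (the radix cache frees it explicitly as the single index at len(token_ids)). A small worked example of the slice bound, with hypothetical sizes:

    # Hypothetical finished request: 5 prompt tokens, 3 generated tokens.
    origin_input_len, output_len = 5, 3
    # Mirrors cache_finished_req above: the last output token is excluded from the count.
    token_id_len = origin_input_len + output_len - 1   # 7
    free_delta = 1  # request was still part of the in-flight batch under overlap scheduling

    print(list(range(token_id_len)))               # slots freed without free_delta: [0, ..., 6]
    print(list(range(token_id_len + free_delta)))  # with free_delta: [0, ..., 7], one extra slot

The extra index (7 here) presumably corresponds to the slot allocated for the decode step that was launched before the finish was observed.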
@@ -528,6 +528,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         kill_child_process(pid, including_parent=False)
         return

+    # print(f"{res.json()=}")
+
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")
...
@@ -113,6 +113,7 @@ class ServerArgs:
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
     disable_penalizer: bool = False
+    enable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32
@@ -572,6 +573,11 @@ class ServerArgs:
             action="store_true",
             help="Disable the logit penalizer (e.g., frequency and repetition penalty).",
         )
+        parser.add_argument(
+            "--enable-overlap-schedule",
+            action="store_true",
+            help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
+        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
...
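
The flag is opt-in and defaults to off. A minimal way to exercise it, mirroring the new test file below rather than documenting a separate API (the helper and constants come from sglang.test.test_utils, exactly as used in that test):

    from sglang.srt.utils import kill_child_process
    from sglang.test.test_utils import (
        DEFAULT_MODEL_NAME_FOR_TEST,
        DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        DEFAULT_URL_FOR_TEST,
        popen_launch_server,
    )

    # Launch a test server with the overlapped scheduler enabled.
    process = popen_launch_server(
        DEFAULT_MODEL_NAME_FOR_TEST,
        DEFAULT_URL_FOR_TEST,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=["--enable-overlap-schedule"],
    )
    try:
        # ... send requests to DEFAULT_URL_FOR_TEST as usual ...
        pass
    finally:
        kill_child_process(process.pid)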
@@ -584,6 +584,7 @@ def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):

 def configure_logger(server_args, prefix: str = ""):
     format = f"[%(asctime)s{prefix}] %(message)s"
+    # format = f"[%(asctime)s.%(msecs)03d{prefix}] %(message)s"
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format=format,
...
@@ -17,6 +17,7 @@ suites = {
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
         "test_openai_server.py",
+        "test_overlap_schedule.py",
         "test_pytorch_sampling_backend.py",
         "test_retract_decode.py",
         "test_server_args.py",
...
"""
Usage:
SGLANG_IS_IN_CI=true python3 -m unittest test_overlap_schedule.TestOverlapSchedule.test_radix_attention_chunked_prefill
SGLANG_IS_IN_CI=true python3 test_overlap_schedule.py
"""
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestOverlapSchedule(unittest.TestCase):
def run_mmlu(self, disable_radix_cache, chunked_prefill_size=32):
other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
if disable_radix_cache:
other_args += ["--disable-radix-cache"]
other_args += ["--enable-overlap-schedule"]
model = DEFAULT_MODEL_NAME_FOR_TEST
base_url = DEFAULT_URL_FOR_TEST
process = popen_launch_server(
model,
base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
args = SimpleNamespace(
base_url=base_url,
model=model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
try:
metrics = run_eval(args)
assert metrics["score"] >= 0.65
finally:
kill_child_process(process.pid)
def test_no_radix_attention_chunked_prefill(self):
self.run_mmlu(disable_radix_cache=True, chunked_prefill_size=32)
def test_no_radix_attention_no_chunked_prefill(self):
self.run_mmlu(disable_radix_cache=True, chunked_prefill_size=-1)
def test_radix_attention_chunked_prefill(self):
self.run_mmlu(disable_radix_cache=False, chunked_prefill_size=32)
def test_radix_attention_no_chunked_prefill(self):
self.run_mmlu(disable_radix_cache=False, chunked_prefill_size=-1)
if __name__ == "__main__":
unittest.main()
# @unittest.skip("did not support")