Unverified Commit e65b9f21 authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

[PD] Support decode overlap schedule (#5608)

parent 4dce1cc6
......@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
from __future__ import annotations
import logging
from collections import deque
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple
......@@ -475,6 +476,48 @@ class SchedulerDisaggregationDecodeMixin:
self.last_batch = batch
@torch.no_grad()
def event_loop_overlap_disagg_decode(self):
    """Overlap-scheduled event loop for the decode server in PD disaggregation.

    Runs one step ahead: the result of the previous batch is processed while
    the current batch is (conceptually) in flight, so CPU-side scheduling
    overlaps with batch execution. Batches in extend mode are special-cased:
    their output is streamed immediately as a "fake extend" (logprobs are
    handled on the prefill engine) and they never enter the result queue,
    so their results must also be skipped on the next iteration.
    """
    # FIFO of (batch_copy, result) pairs awaiting processing one iteration later.
    result_queue = deque()
    self.last_batch: Optional[ScheduleBatch] = None
    # last batch is modified in-place, so we need another variable to track
    # whether it was an extend batch (decides if a result was enqueued for it).
    self.last_batch_is_extend = False
    while True:
        recv_reqs = self.recv_requests()
        self.process_input_requests(recv_reqs)
        # Polling and allocating KV cache for requests arriving via the
        # disaggregation transfer/prealloc queues.
        self.process_decode_queue()
        batch = self.get_next_disagg_decode_batch_to_run()
        self.cur_batch = batch
        last_batch_is_extend = False
        if batch:
            # Generate fake extend output.
            if batch.forward_mode.is_extend():
                # Note: Logprobs should be handled on the prefill engine.
                self.stream_output(batch.reqs, False)
                last_batch_is_extend = True
            else:
                result = self.run_batch(batch)
                # Copy the batch: it is mutated in-place on later iterations.
                result_queue.append((batch.copy(), result))
        # Process the results of the previous batch but skip if the last
        # batch is extend (extend batches enqueue no result to pop).
        if self.last_batch and not self.last_batch_is_extend:
            tmp_batch, tmp_result = result_queue.popleft()
            self.process_batch_result(tmp_batch, tmp_result)
        if batch is None and (
            len(self.disagg_decode_transfer_queue.queue)
            + len(self.disagg_decode_prealloc_queue.queue)
            == 0
        ):
            # When the server is idle, do self-check and re-init some states
            self.check_memory()
            self.new_token_ratio = self.init_new_token_ratio
        self.last_batch = batch
        self.last_batch_is_extend = last_batch_is_extend
def get_next_disagg_decode_batch_to_run(
self: Scheduler,
) -> Optional[Tuple[ScheduleBatch, bool]]:
......
......@@ -2016,7 +2016,10 @@ def run_scheduler_process(
elif disaggregation_mode == DisaggregationMode.PREFILL:
scheduler.event_loop_normal_disagg_prefill()
elif disaggregation_mode == DisaggregationMode.DECODE:
scheduler.event_loop_normal_disagg_decode()
if scheduler.enable_overlap:
scheduler.event_loop_overlap_disagg_decode()
else:
scheduler.event_loop_normal_disagg_decode()
except Exception:
traceback = get_exception_traceback()
......
......@@ -387,14 +387,12 @@ class ServerArgs:
# PD disaggregation
if self.disaggregation_mode == "prefill":
self.disable_cuda_graph = True
logger.warning("KV cache is forced as chunk cache for decode server")
logger.warning("Cuda graph is disabled for prefill server")
self.disable_overlap_schedule = True
logger.warning("Overlap scheduler is disabled for prefill server")
elif self.disaggregation_mode == "decode":
self.disable_radix_cache = True
logger.warning("Cuda graph is disabled for prefill server")
self.disable_overlap_schedule = True
logger.warning("Overlap scheduler is disabled for decode server")
logger.warning("KV cache is forced as chunk cache for decode server")
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
"1" if self.enable_torch_compile else "0"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment