Unverified commit 40d9b8ac authored by Liangsheng Yin, committed by GitHub

Improve overlap scheduling (#5788)

parent f0365820
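
In overlap mode, the scheduler used to resolve a finished batch by id via `resolve_batch_result(bid)`, with a single worker-wide `launch_done` event. After this change, every `ScheduleBatch` carries its own `threading.Event`; the TP worker sets it once that batch's forward pass has been launched, and the scheduler waits on the event of the *currently launched* batch while consuming the previous batch's result. The following is a minimal, self-contained sketch of that producer/consumer pattern; `MiniOverlapWorker` and its methods are hypothetical stand-ins, not the SGLang API.

```python
import queue
import threading

class MiniOverlapWorker:
    """Toy stand-in for the overlap worker: forwards run on a background
    thread while the caller keeps scheduling the next batch."""

    def __init__(self) -> None:
        self.input_queue: queue.Queue = queue.Queue()
        self.output_queue: queue.Queue = queue.Queue()
        threading.Thread(target=self._forward_loop, daemon=True).start()

    def _forward_loop(self) -> None:
        while True:
            batch = self.input_queue.get()
            # Pretend this is the model forward; signal "launched" so the
            # scheduler may consume the *previous* batch's result.
            batch["launch_done"].set()
            self.output_queue.put([t + 1 for t in batch["tokens"]])

    def submit(self, tokens: list) -> threading.Event:
        batch = {"tokens": tokens, "launch_done": threading.Event()}
        self.input_queue.put(batch)
        return batch["launch_done"]

    def resolve_last_batch_result(self, launch_done=None):
        result = self.output_queue.get()  # oldest unconsumed batch (FIFO)
        if launch_done is not None:
            launch_done.wait()  # the current batch must be launched first
        return result

worker = MiniOverlapWorker()
worker.submit([1, 2, 3])
ev2 = worker.submit([4, 5, 6])  # overlap: queue batch 2 before resolving batch 1
print(worker.resolve_last_batch_result(ev2))  # -> [2, 3, 4]
```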
......@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
from __future__ import annotations
import logging
+ import threading
from collections import deque
from typing import TYPE_CHECKING, List, Optional
......@@ -256,7 +257,10 @@ class SchedulerDisaggregationPrefillMixin:
self.running_batch.batch_is_full = False
def process_batch_result_disagg_prefill(
- self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
+ self: Scheduler,
+ batch: ScheduleBatch,
+ result: GenerationBatchResult,
+ launch_done: Optional[threading.Event] = None,
) -> None:
"""
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
......@@ -280,7 +284,7 @@ class SchedulerDisaggregationPrefillMixin:
# Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
if self.enable_overlap:
# wait
- _, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+ _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
else:
next_token_ids = result.next_token_ids.tolist()
......
......@@ -35,6 +35,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
import copy
import dataclasses
import logging
+ import threading
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
import numpy as np
......@@ -724,6 +725,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# This is an optimization to reduce the overhead of the prefill check.
batch_is_full: bool = False
+ # Events
+ launch_done: Optional[threading.Event] = None
# Sampling info
sampling_info: SamplingBatchInfo = None
next_batch_sampling_info: SamplingBatchInfo = None
......@@ -1565,6 +1569,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
),
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
+ launch_done=self.launch_done,
)
def copy(self):
......@@ -1647,6 +1652,9 @@ class ModelWorkerBatch:
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None
+ # Overlap event
+ launch_done: Optional[threading.Event] = None
@triton.jit
def write_req_to_token_pool_triton(
......
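
As the header comment of this file notes (ScheduleBatch -> ModelWorkerBatch -> ForwardBatch), the new `launch_done` field rides along as the batch is converted between representations. A condensed sketch of that plumbing, using hypothetical class names:

```python
import dataclasses
import threading
from typing import Optional

@dataclasses.dataclass
class ModelWorkerBatchSketch:
    # Overlap event: set by the worker once this batch's forward is launched.
    launch_done: Optional[threading.Event] = None

@dataclasses.dataclass
class ScheduleBatchSketch:
    launch_done: Optional[threading.Event] = None

    def get_model_worker_batch(self) -> ModelWorkerBatchSketch:
        # Forward the same Event object, so the scheduler and the worker
        # synchronize on one shared per-batch handle.
        return ModelWorkerBatchSketch(launch_done=self.launch_done)

batch = ScheduleBatchSketch(launch_done=threading.Event())
assert batch.get_model_worker_batch().launch_done is batch.launch_done
```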
......@@ -645,6 +645,7 @@ class Scheduler(
self.cur_batch = batch
if batch:
+ batch.launch_done = threading.Event()
result = self.run_batch(batch)
self.result_queue.append((batch.copy(), result))
......@@ -656,7 +657,7 @@ class Scheduler(
forward_mode=ForwardMode.DUMMY_FIRST,
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
)
- self.process_batch_result(tmp_batch, None)
+ self.process_batch_result(tmp_batch, None, batch.launch_done)
if self.last_batch:
# Process the results of the last batch
......@@ -664,7 +665,10 @@ class Scheduler(
tmp_batch.next_batch_sampling_info = (
self.tp_worker.cur_sampling_info if batch else None
)
- self.process_batch_result(tmp_batch, tmp_result)
+ # NOTE: use the currently launched batch's launch_done event instead of the last batch's
+ self.process_batch_result(
+     tmp_batch, tmp_result, batch.launch_done if batch else None
+ )
elif batch is None:
# When the server is idle, do self-check and re-init some states
self.check_memory()
......@@ -1417,14 +1421,15 @@ class Scheduler(
self,
batch: ScheduleBatch,
result: Union[GenerationBatchResult, EmbeddingBatchResult],
+ launch_done: Optional[threading.Event] = None,
):
if batch.forward_mode.is_decode():
- self.process_batch_result_decode(batch, result)
+ self.process_batch_result_decode(batch, result, launch_done)
elif batch.forward_mode.is_extend():
- self.process_batch_result_prefill(batch, result)
+ self.process_batch_result_prefill(batch, result, launch_done)
elif batch.forward_mode.is_idle():
if self.enable_overlap:
- self.tp_worker.resolve_batch_result(result.bid)
+ self.tp_worker.resolve_last_batch_result(launch_done)
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
......
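
The loop above is where the overlap happens: batch N's CPU-side postprocessing runs while batch N+1's kernels are being launched, which is why `process_batch_result` receives the event of the batch *just launched* rather than the one being processed. A hedged condensation of that control flow, with duck-typed helper names on `scheduler`:

```python
import threading
from collections import deque

def event_loop_overlap_sketch(scheduler) -> None:
    """Hypothetical condensation of the overlap event loop."""
    result_queue: deque = deque()
    last_batch = None
    while True:
        batch = scheduler.get_next_batch_to_run()
        if batch:
            batch.launch_done = threading.Event()
            result = scheduler.run_batch(batch)  # returns before GPU work ends
            result_queue.append((batch.copy(), result))
        if last_batch:
            tmp_batch, tmp_result = result_queue.popleft()
            # Wait on the *current* batch's launch_done: finishing batch N
            # must not race ahead of batch N+1's launch.
            scheduler.process_batch_result(
                tmp_batch, tmp_result, batch.launch_done if batch else None
            )
        last_batch = batch
```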
from __future__ import annotations
+ import threading
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
......@@ -11,6 +12,7 @@ if TYPE_CHECKING:
EmbeddingBatchResult,
GenerationBatchResult,
ScheduleBatch,
+ Scheduler,
)
......@@ -21,9 +23,10 @@ class SchedulerOutputProcessorMixin:
"""
def process_batch_result_prefill(
- self,
+ self: Scheduler,
batch: ScheduleBatch,
result: Union[GenerationBatchResult, EmbeddingBatchResult],
+ launch_done: Optional[threading.Event] = None,
):
skip_stream_req = None
......@@ -43,7 +46,11 @@ class SchedulerOutputProcessorMixin:
)
if self.enable_overlap:
- logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+ logits_output, next_token_ids = (
+     self.tp_worker.resolve_last_batch_result(
+         launch_done,
+     )
+ )
else:
# Move next_token_ids and logprobs to cpu
next_token_ids = next_token_ids.tolist()
......@@ -175,9 +182,10 @@ class SchedulerOutputProcessorMixin:
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
def process_batch_result_decode(
- self,
+ self: Scheduler,
batch: ScheduleBatch,
result: GenerationBatchResult,
+ launch_done: Optional[threading.Event] = None,
):
logits_output, next_token_ids, bid = (
result.logits_output,
......@@ -187,7 +195,9 @@ class SchedulerOutputProcessorMixin:
self.num_generated_tokens += len(batch.reqs)
if self.enable_overlap:
- logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+ logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
+     launch_done
+ )
next_token_logprobs = logits_output.next_token_logprobs
elif batch.spec_algorithm.is_none():
# spec decoding handles output logprobs inside verify process.
......@@ -271,7 +281,7 @@ class SchedulerOutputProcessorMixin:
self.log_decode_stats()
def add_input_logprob_return_values(
- self,
+ self: Scheduler,
i: int,
req: Req,
output: LogitsProcessorOutput,
......@@ -405,7 +415,7 @@ class SchedulerOutputProcessorMixin:
assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
def add_logprob_return_values(
- self,
+ self: Scheduler,
i: int,
req: Req,
pt: int,
......@@ -436,7 +446,10 @@ class SchedulerOutputProcessorMixin:
return num_input_logprobs
def stream_output(
- self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+ self: Scheduler,
+ reqs: List[Req],
+ return_logprob: bool,
+ skip_req: Optional[Req] = None,
):
"""Stream the output to detokenizer."""
if self.is_generation:
......@@ -445,7 +458,10 @@ class SchedulerOutputProcessorMixin:
self.stream_output_embedding(reqs)
def stream_output_generation(
- self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+ self: Scheduler,
+ reqs: List[Req],
+ return_logprob: bool,
+ skip_req: Optional[Req] = None,
):
rids = []
finished_reasons: List[BaseFinishReason] = []
......@@ -593,7 +609,7 @@ class SchedulerOutputProcessorMixin:
)
)
- def stream_output_embedding(self, reqs: List[Req]):
+ def stream_output_embedding(self: Scheduler, reqs: List[Req]):
rids = []
finished_reasons: List[BaseFinishReason] = []
......
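
A side change in this file: the output-processing methods now annotate `self: Scheduler`. They live on a mixin but are only ever called on a `Scheduler` instance, and the annotation lets type checkers resolve attributes (e.g. `self.tp_worker`) that the mixin itself never defines. A minimal illustration of the pattern, with hypothetical names:

```python
from __future__ import annotations
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for type checking, avoiding a circular import at runtime.
    from scheduler import Scheduler  # hypothetical module path

class OutputProcessorMixinSketch:
    def stream_output_sketch(self: Scheduler, text: str) -> None:
        # send_to_detokenizer exists on Scheduler, not on the mixin; the
        # self annotation lets mypy/pyright verify this attribute access.
        self.send_to_detokenizer(text)
```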
......@@ -170,13 +170,13 @@ class TpModelWorker:
def forward_batch_generation(
self,
model_worker_batch: ModelWorkerBatch,
- launch_done: Optional[threading.Event] = None,
skip_sample: bool = False,
) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
- if launch_done:
-     launch_done.set()
+ if model_worker_batch.launch_done is not None:
+     model_worker_batch.launch_done.set()
if skip_sample:
next_token_ids = None
......
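
With the event stored on `ModelWorkerBatch`, `forward_batch_generation` no longer needs a `launch_done` parameter: it signals whatever event the batch itself carries. A condensed sketch under an assumed `model_runner` interface:

```python
def forward_batch_generation_sketch(model_runner, model_worker_batch,
                                    skip_sample: bool = False):
    # Launch the forward pass, then flag the launch as done so a waiting
    # scheduler thread may consume the previous batch's result.
    logits_output = model_runner.forward(model_worker_batch)
    if model_worker_batch.launch_done is not None:
        model_worker_batch.launch_done.set()
    next_token_ids = None if skip_sample else model_runner.sample(logits_output)
    return logits_output, next_token_ids
```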
......@@ -132,7 +132,6 @@ class TpModelWorkerClient:
batch_pt += 1
# Create event
- self.launch_done = threading.Event()
copy_done = torch.get_device_module(self.device).Event()
# Resolve future tokens in the input
......@@ -141,7 +140,7 @@ class TpModelWorkerClient:
# Run forward
logits_output, next_token_ids = self.worker.forward_batch_generation(
- model_worker_batch, self.launch_done
+ model_worker_batch
)
# Update the future token ids map
......@@ -168,10 +167,16 @@ class TpModelWorkerClient:
self.output_queue.put((copy_done, logits_output, next_token_ids))
- def resolve_batch_result(self, bid: int):
+ def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
"""
This function is called to resolve the last batch result and
wait for the current batch to be launched. Used in overlap mode.
"""
copy_done, logits_output, next_token_ids = self.output_queue.get()
+ if launch_done is not None:
+     launch_done.wait()
copy_done.synchronize()
- self.launch_done.wait()
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
......
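
Two details make the old batch-id argument unnecessary here: results come out of `output_queue` strictly in submission order, and each batch now owns its event, so the deleted worker-wide `self.launch_done` can no longer be overwritten by a second in-flight batch. A toy demonstration of the FIFO property:

```python
import queue
import threading

output_queue: queue.Queue = queue.Queue()
events = []
for bid in range(3):
    ev = threading.Event()
    events.append(ev)
    output_queue.put(("result", bid))
    ev.set()  # pretend each batch launched immediately

for expected in range(3):
    events[expected].wait()
    _, got = output_queue.get()
    # Submission order equals consumption order: no id lookup required.
    assert got == expected
```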