"vscode:/vscode.git/clone" did not exist on "9c71bcb0bb825cba5cfb29f1a49871d8e4cb9117"
Unverified commit 40d9b8ac, authored by Liangsheng Yin, committed by GitHub

Improve overlap scheduling (#5788)

parent f0365820
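
Summary of the change, as read from the diffs below: overlap scheduling previously synchronized through a single worker-owned `self.launch_done` event and resolved results via `resolve_batch_result(bid)`. This commit attaches a fresh `threading.Event` to each `ScheduleBatch` (propagated onto its `ModelWorkerBatch`), has the worker set that event once the forward pass is dispatched, and renames the resolver to `resolve_last_batch_result(launch_done)`, which pops the previous batch's output and waits on the event of the batch currently in flight. The toy below is a minimal, self-contained sketch of that handoff, not SGLang code; `worker_loop` and the dict-based batch are illustrative stand-ins.

```python
# Minimal sketch of the new handoff: launch batch N, then process batch
# N-1's result while N runs, gated on N's launch_done event.
import queue
import threading
from typing import Optional

input_queue: queue.Queue = queue.Queue()
output_queue: queue.Queue = queue.Queue()

def worker_loop() -> None:
    # Stand-in for the forward thread: signal launch, then publish a result.
    while (batch := input_queue.get()) is not None:
        batch["launch_done"].set()              # forward pass dispatched
        output_queue.put(sum(batch["tokens"]))  # stand-in for sampled tokens

def resolve_last_batch_result(launch_done: Optional[threading.Event]) -> int:
    # Pop the *previous* batch's result, but first wait until the *current*
    # batch has been launched -- mirroring the renamed worker API.
    result = output_queue.get()
    if launch_done is not None:
        launch_done.wait()
    return result

threading.Thread(target=worker_loop, daemon=True).start()

for step in range(3):
    launch_done = threading.Event()
    input_queue.put({"tokens": [step, step + 1], "launch_done": launch_done})
    if step > 0:  # overlap: handle the last batch while this one is in flight
        print(resolve_last_batch_result(launch_done))
print(resolve_last_batch_result(None))  # drain the final batch's result
input_queue.put(None)                   # stop the worker
```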
```diff
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
 from __future__ import annotations

 import logging
+import threading
 from collections import deque
 from typing import TYPE_CHECKING, List, Optional
@@ -256,7 +257,10 @@ class SchedulerDisaggregationPrefillMixin:
         self.running_batch.batch_is_full = False

     def process_batch_result_disagg_prefill(
-        self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
+        self: Scheduler,
+        batch: ScheduleBatch,
+        result: GenerationBatchResult,
+        launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
         Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
@@ -280,7 +284,7 @@ class SchedulerDisaggregationPrefillMixin:
         # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         if self.enable_overlap:
             # wait
-            _, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+            _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
         else:
             next_token_ids = result.next_token_ids.tolist()
```
```diff
@@ -35,6 +35,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 import copy
 import dataclasses
 import logging
+import threading
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

 import numpy as np
@@ -724,6 +725,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # This is an optimization to reduce the overhead of the prefill check.
     batch_is_full: bool = False

+    # Events
+    launch_done: Optional[threading.Event] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
@@ -1565,6 +1569,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 )
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
+            launch_done=self.launch_done,
         )

     def copy(self):
@@ -1647,6 +1652,9 @@ class ModelWorkerBatch:
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None

+    # Overlap event
+    launch_done: Optional[threading.Event] = None
+

 @triton.jit
 def write_req_to_token_pool_triton(
```
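
Note the `launch_done=self.launch_done` line in `get_model_worker_batch`'s return above: the `ScheduleBatch` hands its Event to the `ModelWorkerBatch` by reference, so the worker setting the event on its batch unblocks whoever waits on the scheduler's side. Below is a hedged sketch of that mechanism; `TinyBatch` and `to_worker_batch` are illustrative names, not SGLang classes.

```python
# The Event is shared by reference, so a .set() on the worker-side batch
# is visible through the scheduler-side batch as well.
import dataclasses
import threading
from typing import List, Optional

@dataclasses.dataclass
class TinyBatch:
    token_ids: List[int]
    # Defaults to None so non-overlap code paths can ignore it entirely.
    launch_done: Optional[threading.Event] = None

    def to_worker_batch(self) -> "TinyBatch":
        # Analogous to get_model_worker_batch(): copy the payload, but hand
        # over the *same* Event object.
        return TinyBatch(list(self.token_ids), launch_done=self.launch_done)

batch = TinyBatch([1, 2, 3], launch_done=threading.Event())
worker_batch = batch.to_worker_batch()
worker_batch.launch_done.set()      # worker signals on its copy...
assert batch.launch_done.is_set()   # ...and the scheduler's view sees it
```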
```diff
@@ -645,6 +645,7 @@ class Scheduler(
             self.cur_batch = batch

             if batch:
+                batch.launch_done = threading.Event()
                 result = self.run_batch(batch)
                 self.result_queue.append((batch.copy(), result))
@@ -656,7 +657,7 @@ class Scheduler(
                         forward_mode=ForwardMode.DUMMY_FIRST,
                         next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                     )
-                    self.process_batch_result(tmp_batch, None)
+                    self.process_batch_result(tmp_batch, None, batch.launch_done)

             if self.last_batch:
                 # Process the results of the last batch
@@ -664,7 +665,10 @@ class Scheduler(
                 tmp_batch.next_batch_sampling_info = (
                     self.tp_worker.cur_sampling_info if batch else None
                 )
-                self.process_batch_result(tmp_batch, tmp_result)
+                # NOTE: use the launch_done event of the currently launched
+                # batch, not the last batch's
+                self.process_batch_result(
+                    tmp_batch, tmp_result, batch.launch_done if batch else None
+                )
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
                 self.check_memory()
@@ -1417,14 +1421,15 @@ class Scheduler(
         self,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
+        launch_done: Optional[threading.Event] = None,
     ):
         if batch.forward_mode.is_decode():
-            self.process_batch_result_decode(batch, result)
+            self.process_batch_result_decode(batch, result, launch_done)
         elif batch.forward_mode.is_extend():
-            self.process_batch_result_prefill(batch, result)
+            self.process_batch_result_prefill(batch, result, launch_done)
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
-                self.tp_worker.resolve_batch_result(result.bid)
+                self.tp_worker.resolve_last_batch_result(launch_done)
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
                 self.current_stream.synchronize()
```
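
The NOTE in the event loop above is the crux of the scheduler-side change: when processing the last batch's result, the code passes the *current* batch's `launch_done`, not the last batch's. An event set back when the last batch launched is already set, so waiting on it enforces nothing; only the current batch's event can still be pending, and waiting on it is what orders CPU-side result processing after the next forward has been dispatched. A tiny runnable illustration with plain threading (no SGLang code):

```python
# Waiting on an already-set Event returns immediately; waiting on the
# current batch's Event actually synchronizes with the new launch.
import threading
import time

last_launch = threading.Event()
last_launch.set()                      # set back when the last batch started
current_launch = threading.Event()

def launch_current_batch() -> None:
    time.sleep(0.05)                   # pretend dispatch takes a moment
    current_launch.set()

threading.Thread(target=launch_current_batch).start()

t0 = time.perf_counter()
last_launch.wait()
print(f"wait on last batch:    {time.perf_counter() - t0:.3f}s")  # ~0.000s
t0 = time.perf_counter()
current_launch.wait()
print(f"wait on current batch: {time.perf_counter() - t0:.3f}s")  # ~0.050s
```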
```diff
 from __future__ import annotations

+import threading
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -11,6 +12,7 @@ if TYPE_CHECKING:
         EmbeddingBatchResult,
         GenerationBatchResult,
         ScheduleBatch,
+        Scheduler,
     )
@@ -21,9 +23,10 @@ class SchedulerOutputProcessorMixin:
     """

     def process_batch_result_prefill(
-        self,
+        self: Scheduler,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
+        launch_done: Optional[threading.Event] = None,
     ):
         skip_stream_req = None
@@ -43,7 +46,11 @@ class SchedulerOutputProcessorMixin:
             )

             if self.enable_overlap:
-                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+                logits_output, next_token_ids = (
+                    self.tp_worker.resolve_last_batch_result(
+                        launch_done,
+                    )
+                )
             else:
                 # Move next_token_ids and logprobs to cpu
                 next_token_ids = next_token_ids.tolist()
@@ -175,9 +182,10 @@ class SchedulerOutputProcessorMixin:
         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

     def process_batch_result_decode(
-        self,
+        self: Scheduler,
         batch: ScheduleBatch,
         result: GenerationBatchResult,
+        launch_done: Optional[threading.Event] = None,
     ):
         logits_output, next_token_ids, bid = (
             result.logits_output,
@@ -187,7 +195,9 @@ class SchedulerOutputProcessorMixin:
         self.num_generated_tokens += len(batch.reqs)

         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+            logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
+                launch_done
+            )
             next_token_logprobs = logits_output.next_token_logprobs
         elif batch.spec_algorithm.is_none():
             # spec decoding handles output logprobs inside verify process.
@@ -271,7 +281,7 @@ class SchedulerOutputProcessorMixin:
         self.log_decode_stats()

     def add_input_logprob_return_values(
-        self,
+        self: Scheduler,
         i: int,
         req: Req,
         output: LogitsProcessorOutput,
@@ -405,7 +415,7 @@ class SchedulerOutputProcessorMixin:
         assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len

     def add_logprob_return_values(
-        self,
+        self: Scheduler,
         i: int,
         req: Req,
         pt: int,
@@ -436,7 +446,10 @@ class SchedulerOutputProcessorMixin:
         return num_input_logprobs

     def stream_output(
-        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+        self: Scheduler,
+        reqs: List[Req],
+        return_logprob: bool,
+        skip_req: Optional[Req] = None,
     ):
         """Stream the output to detokenizer."""
         if self.is_generation:
@@ -445,7 +458,10 @@ class SchedulerOutputProcessorMixin:
             self.stream_output_embedding(reqs)

     def stream_output_generation(
-        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+        self: Scheduler,
+        reqs: List[Req],
+        return_logprob: bool,
+        skip_req: Optional[Req] = None,
     ):
         rids = []
         finished_reasons: List[BaseFinishReason] = []
@@ -593,7 +609,7 @@ class SchedulerOutputProcessorMixin:
             )
         )

-    def stream_output_embedding(self, reqs: List[Req]):
+    def stream_output_embedding(self: Scheduler, reqs: List[Req]):
         rids = []
         finished_reasons: List[BaseFinishReason] = []
```
```diff
@@ -170,13 +170,13 @@ class TpModelWorker:
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
-        launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
     ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-        if launch_done:
-            launch_done.set()
+        if model_worker_batch.launch_done is not None:
+            model_worker_batch.launch_done.set()

         if skip_sample:
             next_token_ids = None
```
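
In the hunk above, `forward_batch_generation` drops its `launch_done` parameter and instead reads the event off `model_worker_batch` itself: the event rides with the batch, so overlap and non-overlap callers share one signature, and the worker signals only when someone attached an event. A hedged sketch of that pattern (`Batch` and the function body here are stand-ins, not the real `TpModelWorker`):

```python
# The event travels on the batch; callers that don't use overlap simply
# leave launch_done as None and nothing is signaled.
import threading
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Batch:
    tokens: List[int]
    launch_done: Optional[threading.Event] = None

def forward_batch_generation(batch: Batch) -> int:
    logits = sum(batch.tokens)          # stand-in for model_runner.forward
    if batch.launch_done is not None:   # signal launch only if requested
        batch.launch_done.set()
    return logits

print(forward_batch_generation(Batch([1, 2])))   # non-overlap path: no event
overlapped = Batch([3, 4], launch_done=threading.Event())
print(forward_batch_generation(overlapped), overlapped.launch_done.is_set())
```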
```diff
@@ -132,7 +132,6 @@ class TpModelWorkerClient:
             batch_pt += 1

         # Create event
-        self.launch_done = threading.Event()
         copy_done = torch.get_device_module(self.device).Event()

         # Resolve future tokens in the input
@@ -141,7 +140,7 @@ class TpModelWorkerClient:
         # Run forward
         logits_output, next_token_ids = self.worker.forward_batch_generation(
-            model_worker_batch, self.launch_done
+            model_worker_batch
         )

         # Update the future token ids map
@@ -168,10 +167,16 @@ class TpModelWorkerClient:
         self.output_queue.put((copy_done, logits_output, next_token_ids))

-    def resolve_batch_result(self, bid: int):
+    def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
+        """
+        This function is called to resolve the last batch result and
+        wait for the current batch to be launched. Used in overlap mode.
+        """
         copy_done, logits_output, next_token_ids = self.output_queue.get()
+        if launch_done is not None:
+            launch_done.wait()
         copy_done.synchronize()
-        self.launch_done.wait()

         if logits_output.next_token_logprobs is not None:
             logits_output.next_token_logprobs = (
```
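
Two details stand out in this last file: the worker-global `self.launch_done`, recreated on every forward, is deleted in favor of the caller-supplied per-batch event (presumably so each wait is tied unambiguously to one specific batch), and the launch wait now happens *before* `copy_done.synchronize()` rather than after. A distilled, runnable version of the new ordering, with a `threading.Event` standing in for the device-side copy event and simplified types:

```python
# New resolve ordering: 1) pop the oldest pending result, 2) wait for the
# current batch's launch, 3) only then block on the device-to-host copy.
import queue
import threading
from typing import Optional, Tuple

output_queue: queue.Queue = queue.Queue()

def resolve_last_batch_result(
    launch_done: Optional[threading.Event] = None,
) -> Tuple[object, list]:
    copy_done, logits_output, next_token_ids = output_queue.get()
    if launch_done is not None:
        launch_done.wait()    # the current batch's kernels are now in flight
    copy_done.wait()          # stand-in for torch device Event.synchronize()
    return logits_output, next_token_ids

# Usage: enqueue one finished batch, then resolve it.
copy_done = threading.Event()
copy_done.set()
launch_done = threading.Event()
launch_done.set()
output_queue.put((copy_done, {"next_token_logprobs": None}, [7, 8]))
print(resolve_last_batch_result(launch_done))
```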