Unverified commit bfadb5ea authored by Liangsheng Yin, committed by GitHub

Adjust overlap event loop (#11507)

parent 9cc1e065
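
This change replaces the `delay_sample_launch` flag (previously set on `ModelWorkerBatch` and mirrored on `GenerationBatchResult`) with a `delay_sample_func` closure: when the overlap scheduler is enabled and the batch carries grammar constraints, `TpModelWorker.forward_batch_generation` defers sampling by storing a closure on the result, and each scheduler event loop now calls `launch_batch_sample_if_needed(batch_result)` on the current batch's result after processing the previous batch, instead of calling `launch_last_batch_sample_if_needed()` against the result queue at the top of the loop. Below is a minimal, self-contained sketch of that closure pattern; it is an illustration only, and the batch/result class and the toy `forward_batch` helper are simplified stand-ins for the real SGLang types, not the actual implementation.

from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class BatchResult:
    # Simplified stand-in for GenerationBatchResult (illustration only).
    next_token_ids: Optional[List[int]] = None
    delay_sample_func: Optional[Callable[[], "BatchResult"]] = None


def forward_batch(delay_sample: bool) -> BatchResult:
    # Stand-in for TpModelWorker.forward_batch_generation: when sampling must
    # be deferred (e.g. grammar-constrained batches under overlap), capture the
    # sampling step in a closure instead of running it immediately.
    result = BatchResult()
    if delay_sample:
        def sample_batch_func() -> BatchResult:
            result.next_token_ids = [42]  # placeholder for model_runner.sample(...)
            return result

        result.delay_sample_func = sample_batch_func
    else:
        result.next_token_ids = [7]  # sampled eagerly
    return result


def launch_batch_sample_if_needed(result: Optional[BatchResult]) -> None:
    # Stand-in for Scheduler.launch_batch_sample_if_needed: run the deferred
    # sampling closure, which completes the same result object in place.
    if result is None or result.delay_sample_func is None:
        return
    finished = result.delay_sample_func()
    assert finished is result


# The event loop runs the batch, processes the previous batch's results, then
# launches the (possibly delayed) sample for the current batch.
res = forward_batch(delay_sample=True)
launch_batch_sample_if_needed(res)
print(res.next_token_ids)  # -> [42]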
@@ -752,7 +752,6 @@ class SchedulerDisaggregationDecodeMixin:
         self.last_batch_in_queue = False  # last batch is modified in-place, so we need another variable to track if it's extend

         while True:
-            self.launch_last_batch_sample_if_needed()
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
@@ -764,6 +763,7 @@ class SchedulerDisaggregationDecodeMixin:
             prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

+            batch_result = None
             if batch:
                 # Generate fake extend output.
                 if batch.forward_mode.is_extend():
@@ -772,25 +772,25 @@ class SchedulerDisaggregationDecodeMixin:
                         batch.reqs, any(req.return_logprob for req in batch.reqs)
                     )
                     if prepare_mlp_sync_flag:
-                        batch_, result = self._prepare_idle_batch_and_run(
+                        batch_, batch_result = self._prepare_idle_batch_and_run(
                             None, delay_process=True
                         )
                         if batch_:
-                            self.result_queue.append((batch_.copy(), result))
+                            self.result_queue.append((batch_.copy(), batch_result))
                             last_batch_in_queue = True
                 else:
                     if prepare_mlp_sync_flag:
                         self.prepare_mlp_sync_batch(batch)
-                    result = self.run_batch(batch)
-                    self.result_queue.append((batch.copy(), result))
+                    batch_result = self.run_batch(batch)
+                    self.result_queue.append((batch.copy(), batch_result))
                     last_batch_in_queue = True
             elif prepare_mlp_sync_flag:
-                batch, result = self._prepare_idle_batch_and_run(
+                batch, batch_result = self._prepare_idle_batch_and_run(
                     None, delay_process=True
                 )
                 if batch:
-                    self.result_queue.append((batch.copy(), result))
+                    self.result_queue.append((batch.copy(), batch_result))
                     last_batch_in_queue = True

             # Process the results of the previous batch but skip if the last batch is extend
@@ -798,6 +798,8 @@ class SchedulerDisaggregationDecodeMixin:
                 tmp_batch, tmp_result = self.result_queue.popleft()
                 self.process_batch_result(tmp_batch, tmp_result)

+            self.launch_batch_sample_if_needed(batch_result)
+
             queue_size = (
                 len(self.waiting_queue)
                 + len(self.disagg_decode_transfer_queue.queue)
@@ -321,8 +321,6 @@ class SchedulerDisaggregationPrefillMixin:
         self.result_queue = deque()

         while True:
-            self.launch_last_batch_sample_if_needed()
-
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
             self.waiting_queue.extend(
@@ -334,9 +332,11 @@ class SchedulerDisaggregationPrefillMixin:
             if require_mlp_sync(self.server_args):
                 batch = self.prepare_mlp_sync_batch(batch)
             self.cur_batch = batch

+            batch_result = None
             if batch:
-                result = self.run_batch(batch)
-                self.result_queue.append((batch.copy(), result))
+                batch_result = self.run_batch(batch)
+                self.result_queue.append((batch.copy(), batch_result))

             if self.last_batch:
                 tmp_batch, tmp_result = self.result_queue.popleft()
@@ -345,6 +345,8 @@ class SchedulerDisaggregationPrefillMixin:
             if len(self.disagg_prefill_inflight_queue) > 0:
                 self.process_disagg_prefill_inflight_queue()

+            self.launch_batch_sample_if_needed(batch_result)
+
             if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
                 self.self_check_during_idle()
@@ -1907,8 +1907,5 @@ class ModelWorkerBatch:
     capture_hidden_mode: CaptureHiddenMode = None
     hicache_consumer_index: int = -1

-    # Overlap scheduler related
-    delay_sample_launch: bool = False
-
     # Whether this batch is prefill-only (no token generation needed)
     is_prefill_only: bool = False
@@ -148,7 +148,7 @@ from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
 from sglang.srt.speculative.eagle_info import EagleDraftInput
@@ -212,8 +212,7 @@ class GenerationBatchResult:
     # For overlap scheduling
     copy_done: Optional[torch.cuda.Event] = None
-    delay_sample_launch: bool = False
-    forward_batch: Optional[ForwardBatch] = None
+    delay_sample_func: Optional[callable] = None
     future_indices: Optional[FutureIndices] = None

     # FIXME(lsyin): maybe move to <BetterPlace> ?
@@ -1036,17 +1035,16 @@ class Scheduler(
         self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()

         while True:
-            self.launch_last_batch_sample_if_needed()
-
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)

             batch = self.get_next_batch_to_run()
             self.cur_batch = batch

+            batch_result = None
             if batch:
-                result = self.run_batch(batch)
-                self.result_queue.append((batch.copy(), result))
+                batch_result = self.run_batch(batch)
+                self.result_queue.append((batch.copy(), batch_result))

             if self.last_batch:
                 # Process the results of the last batch
@@ -1056,6 +1054,7 @@ class Scheduler(
                 # When the server is idle, do self-check and re-init some states
                 self.self_check_during_idle()

+            self.launch_batch_sample_if_needed(batch_result)
             self.last_batch = batch

     @DynamicGradMode()
@@ -2207,8 +2206,6 @@ class Scheduler(
             with self.forward_stream_ctx:
                 self.forward_stream.wait_stream(self.default_stream)
                 self.future_map.resolve_future(model_worker_batch)
-                if batch.sampling_info.grammars is not None:
-                    model_worker_batch.delay_sample_launch = True
                 batch_result = self.model_worker.forward_batch_generation(
                     model_worker_batch
                 )
@@ -2216,7 +2213,7 @@ class Scheduler(
                 batch_result.copy_done = torch.get_device_module(
                     self.device
                 ).Event()
-                if not model_worker_batch.delay_sample_launch:
+                if batch_result.delay_sample_func is None:
                     self.future_map.store_to_map(future_indices, batch_result)
                     batch_result.copy_to_cpu()
                 else:
@@ -2280,29 +2277,20 @@ class Scheduler(
             ret = EmbeddingBatchResult(embeddings=embeddings)
             return ret

-    def launch_last_batch_sample_if_needed(
-        self,
-    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
-        if len(self.result_queue) == 0:
-            return
-
-        tmp_batch, tmp_result = self.result_queue.popleft()
-        tmp_result: GenerationBatchResult
-        if not tmp_result.delay_sample_launch:
-            self.result_queue.appendleft((tmp_batch, tmp_result))
+    def launch_batch_sample_if_needed(
+        self, batch_result: GenerationBatchResult
+    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
+        # TODO(lsyin): make the delayed sample a default behavior after
+        # unifying the forward_batch_generation interface (related to spec V2).
+        if batch_result is None or batch_result.delay_sample_func is None:
             return

         with self.forward_stream_ctx:
             self.forward_stream.wait_stream(self.default_stream)
-            tmp_result.next_token_ids = self.model_worker.model_runner.sample(
-                tmp_result.logits_output,
-                tmp_result.forward_batch,
-            )
-            future_indices = tmp_result.future_indices
-            self.future_map.store_to_map(future_indices, tmp_result)
-            tmp_result.copy_to_cpu()
-            self.result_queue.appendleft((tmp_batch, tmp_result))
+            _batch_result = batch_result.delay_sample_func()
+            assert _batch_result is batch_result
+            self.future_map.store_to_map(batch_result.future_indices, batch_result)
+            batch_result.copy_to_cpu()

     def process_batch_result(
         self,
@@ -168,6 +168,7 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)

+        self.enable_overlap = not server_args.disable_overlap_schedule
         self.hicache_layer_transfer_counter = None

     def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
@@ -266,9 +267,18 @@ class TpModelWorker:
             # Skip sampling and return logits for target forward
             return batch_result

-        if model_worker_batch.delay_sample_launch:
-            batch_result.delay_sample_launch = True
-            batch_result.forward_batch = forward_batch
+        if (
+            self.enable_overlap
+            and model_worker_batch.sampling_info.grammars is not None
+        ):
+
+            def sample_batch_func():
+                batch_result.next_token_ids = self.model_runner.sample(
+                    logits_output, forward_batch
+                )
+                return batch_result
+
+            batch_result.delay_sample_func = sample_batch_func
             return batch_result

         if model_worker_batch.is_prefill_only: