Unverified Commit 1519a89c, authored by Liangsheng Yin, committed by GitHub
parent 24bc3fb0
@@ -747,11 +747,13 @@ class SchedulerDisaggregationDecodeMixin:
     @torch.no_grad()
     def event_loop_overlap_disagg_decode(self: Scheduler):
-        result_queue = deque()
+        self.result_queue = deque()
         self.last_batch: Optional[ScheduleBatch] = None
         self.last_batch_in_queue = False  # last batch is modified in-place, so we need another variable to track if it's extend

         while True:
+            self.launch_last_batch_sample_if_needed()
+
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)

             # polling and allocating kv cache
@@ -774,13 +776,13 @@ class SchedulerDisaggregationDecodeMixin:
                            None, delay_process=True
                        )
                        if batch_:
-                            result_queue.append((batch_.copy(), result))
+                            self.result_queue.append((batch_.copy(), result))
                            last_batch_in_queue = True
                else:
                    if prepare_mlp_sync_flag:
                        self.prepare_mlp_sync_batch(batch)
                    result = self.run_batch(batch)
-                    result_queue.append((batch.copy(), result))
+                    self.result_queue.append((batch.copy(), result))
                    last_batch_in_queue = True

                    if (self.last_batch is None) or (not self.last_batch_in_queue):
                        # Create a dummy first batch to start the pipeline for overlap schedule.
@@ -798,12 +800,12 @@ class SchedulerDisaggregationDecodeMixin:
                        None, delay_process=True
                    )
                    if batch:
-                        result_queue.append((batch.copy(), result))
+                        self.result_queue.append((batch.copy(), result))
                        last_batch_in_queue = True

            # Process the results of the previous batch but skip if the last batch is extend
            if self.last_batch and self.last_batch_in_queue:
-                tmp_batch, tmp_result = result_queue.popleft()
+                tmp_batch, tmp_result = self.result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
......
@@ -321,6 +321,8 @@ class SchedulerDisaggregationPrefillMixin:
         self.result_queue = deque()

         while True:
+            self.launch_last_batch_sample_if_needed()
+
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
             self.waiting_queue.extend(
@@ -368,7 +370,6 @@ class SchedulerDisaggregationPrefillMixin:
         self: Scheduler,
         batch: ScheduleBatch,
         result: GenerationBatchResult,
-        launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
         Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
@@ -379,31 +380,30 @@ class SchedulerDisaggregationPrefillMixin:
            next_token_ids,
            extend_input_len_per_req,
            extend_logprob_start_len_per_req,
+            copy_done,
        ) = (
            result.logits_output,
            result.next_token_ids,
            result.extend_input_len_per_req,
            result.extend_logprob_start_len_per_req,
+            result.copy_done,
        )

+        if copy_done is not None:
+            copy_done.synchronize()
+
        logprob_pt = 0
        # Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
-        if self.enable_overlap:
-            # wait
-            logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
-                launch_done
-            )
-        else:
-            next_token_ids = result.next_token_ids.tolist()
-            if batch.return_logprob:
-                if logits_output.next_token_logprobs is not None:
-                    logits_output.next_token_logprobs = (
-                        logits_output.next_token_logprobs.tolist()
-                    )
-                if logits_output.input_token_logprobs is not None:
-                    logits_output.input_token_logprobs = tuple(
-                        logits_output.input_token_logprobs.tolist()
-                    )
+        next_token_ids = result.next_token_ids.tolist()
+        if batch.return_logprob:
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.tolist()
+                )
+            if logits_output.input_token_logprobs is not None:
+                logits_output.input_token_logprobs = tuple(
+                    logits_output.input_token_logprobs.tolist()
+                )

        hidden_state_offset = 0
        for i, (req, next_token_id) in enumerate(
......
@@ -37,8 +37,7 @@ class FutureMap:
         return cur_future_ct

     def resolve_future(self, model_worker_batch: ModelWorkerBatch):
-        input_ids = model_worker_batch.input_ids
-        _resolve_future_token_ids(input_ids, self.token_ids_buf)
+        _resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf)

     def update_next_future(self, future_ct: int, bs: int):
         return torch.arange(
......
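Reviewer note: `resolve_future` is the half of the overlap trick that consumes "future" token ids. The scheduler can build the next batch before the previous one has finished sampling, so not-yet-known tokens are placeholders that index into a shared buffer; resolution rewrites them in place once the real tokens exist. A minimal sketch of the idea with made-up names and sizes (not the actual `_resolve_future_token_ids` implementation):

```python
import torch

# Toy illustration: future slots are encoded as negative ids, and resolution
# swaps them for the real sampled tokens in place.
def resolve_future_token_ids(input_ids: torch.Tensor, token_ids_buf: torch.Tensor):
    # input_ids < 0 marks "token not sampled yet"; -(id) indexes the future buffer.
    is_future = input_ids < 0
    input_ids[is_future] = token_ids_buf[-input_ids[is_future]]

token_ids_buf = torch.zeros(8, dtype=torch.long)
token_ids_buf[1:3] = torch.tensor([42, 7])   # tokens produced by the previous forward
input_ids = torch.tensor([11, -1, -2])       # last two tokens are still "futures"
resolve_future_token_ids(input_ids, token_ids_buf)
print(input_ids)  # tensor([11, 42,  7])
```

The buffer only has to cover in-flight requests, which is why the scheduler below constructs it as `FutureMap(self.max_running_requests, self.device)`.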
@@ -886,9 +886,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # This is an optimization to reduce the overhead of the prefill check.
     batch_is_full: bool = False

-    # Events
-    launch_done: Optional[threading.Event] = None
-
     # For chunked prefill in PP
     chunked_req: Optional[Req] = None
@@ -1877,7 +1874,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 )
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
-            launch_done=self.launch_done,
             is_prefill_only=self.is_prefill_only,
         )
@@ -2018,8 +2014,8 @@ class ModelWorkerBatch:
     capture_hidden_mode: CaptureHiddenMode = None
     hicache_consumer_index: int = -1

-    # Overlap event
-    launch_done: Optional[threading.Event] = None
+    # Overlap scheduler related
+    delay_sample_launch: bool = False

     # Whether this batch is prefill-only (no token generation needed)
     is_prefill_only: bool = False
......
@@ -25,12 +25,14 @@ from concurrent import futures
 from dataclasses import dataclass
 from http import HTTPStatus
 from types import SimpleNamespace
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Deque, Dict, List, Optional, Tuple, Union

 import psutil
 import setproctitle
 import torch
 import zmq
+from torch.cuda import Stream as CudaStream
+from torch.cuda import StreamContext as CudaStreamContext
 from torch.distributed import barrier

 from sglang.global_config import global_config
@@ -112,8 +114,10 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.mm_utils import init_embedding_cache
+from sglang.srt.managers.overlap_utils import FutureMap
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
+    ModelWorkerBatch,
     MultimodalInputs,
     Req,
     RequestStage,
@@ -139,15 +143,13 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
     SchedulerUpdateWeightsMixin,
 )
 from sglang.srt.managers.session_controller import Session
-from sglang.srt.managers.tp_worker import TpModelWorker
-from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.managers.utils import validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatchOutput,
+    ForwardBatch,
     ForwardMode,
     PPProxyTensors,
 )
@@ -201,40 +203,48 @@ GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
 @dataclass
 class GenerationBatchResult:
-    logits_output: Optional[LogitsProcessorOutput]
-    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
-    next_token_ids: Optional[List[int]]
-    can_run_cuda_graph: bool
+    logits_output: Optional[LogitsProcessorOutput] = None
+    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
+    next_token_ids: Optional[torch.Tensor] = None
+    num_accepted_tokens: Optional[int] = None
+    can_run_cuda_graph: bool = False

     # For output processing
-    extend_input_len_per_req: List[int]
-    extend_logprob_start_len_per_req: List[int]
+    extend_input_len_per_req: Optional[List[int]] = None
+    extend_logprob_start_len_per_req: Optional[List[int]] = None

-    @classmethod
-    def from_forward_batch_output(
-        cls,
-        forward_batch_output: ForwardBatchOutput,
-        extend_input_len_per_req: List[int],
-        extend_logprob_start_len_per_req: List[int],
-    ):
-        # TODO(lsyin): remove this workaround logic and try to unify output classes
-
-        return cls(
-            logits_output=forward_batch_output.logits_output,
-            pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
-            next_token_ids=forward_batch_output.next_token_ids,
-            extend_input_len_per_req=extend_input_len_per_req,
-            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
-            can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
-        )
+    # For overlap scheduling
+    copy_done: Optional[torch.cuda.Event] = None
+    delay_sample_launch: bool = False
+    forward_batch: Optional[ForwardBatch] = None
+    future_map_ct: Optional[int] = None
+
+    def copy_to_cpu(self, return_logprob: bool = False):
+        """Copy tensors to CPU in overlap scheduling.
+        Only the tensors which are needed for processing results are copied,
+        e.g., next_token_ids, logits outputs
+        """
+        if return_logprob:
+            if self.logits_output.next_token_logits is not None:
+                self.logits_output.next_token_logits = (
+                    self.logits_output.next_token_logits.to("cpu", non_blocking=True)
+                )
+            if self.logits_output.input_token_logprobs is not None:
+                self.logits_output.input_token_logprobs = (
+                    self.logits_output.input_token_logprobs.to("cpu", non_blocking=True)
+                )
+        if self.logits_output.hidden_states is not None:
+            self.logits_output.hidden_states = self.logits_output.hidden_states.to(
+                "cpu", non_blocking=True
+            )
+        self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
+        self.copy_done.record()

     @classmethod
     def from_pp_proxy(
         cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
     ):
-        # TODO(lsyin): also simplify this logic
-        # Current PP implementation in scheduler is not compatible with ForwardBatchOutput
-        # Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
+        # TODO(lsyin): refactor PP and avoid using dict
         proxy_dict = next_pp_outputs.tensors
         return cls(
             logits_output=logits_output,
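Reviewer note: `copy_to_cpu()` plus the `copy_done` event replaces the old `launch_done` threading.Event handshake. The device-to-host copies are enqueued without blocking, and output processing only synchronizes on the event when it actually reads the values. A self-contained sketch of that pattern with dummy tensors (CUDA assumed; truly asynchronous D2H copies also need pinned host memory):

```python
import torch

# Sketch only: enqueue non-blocking device-to-host copies, record an event on the
# current stream, and wait on the event only when the CPU needs the values.
if torch.cuda.is_available():
    device_ids = torch.randint(0, 100, (8,), device="cuda")

    copy_done = torch.cuda.Event()
    host_ids = device_ids.to("cpu", non_blocking=True)  # async D2H copy
    copy_done.record()  # marks the point where the copy was enqueued

    # ... the scheduler would launch the next batch here ...

    copy_done.synchronize()  # block only when the result is actually consumed
    print(host_ids.tolist())
```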
@@ -388,12 +398,10 @@ class Scheduler(
             logger.info("Overlap scheduler is disabled for embedding models.")

         # Launch a tensor parallel worker
-        if self.enable_overlap:
-            TpWorkerClass = TpModelWorkerClient
-        else:
-            TpWorkerClass = TpModelWorker
+        from sglang.srt.managers.tp_worker import TpModelWorker

-        self.tp_worker = TpWorkerClass(
+        self.tp_worker = TpModelWorker(
             server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
@@ -525,9 +533,11 @@ class Scheduler(
         self.kv_transfer_speed_gb_s: float = 0.0
         self.kv_transfer_latency_ms: float = 0.0
         self.sessions: Dict[str, Session] = {}
-        self.current_stream = torch.get_device_module(self.device).current_stream()
+        self.default_stream: CudaStream = torch.get_device_module(
+            self.device
+        ).current_stream()
         if self.device == "cpu":
-            self.current_stream.synchronize = lambda: None  # No-op for CPU
+            self.default_stream.synchronize = lambda: None  # No-op for CPU
         self.forward_sleep_time = None

         # Init chunked prefill
@@ -618,6 +628,9 @@ class Scheduler(
         # Init prefill kv split size when deterministic inference is enabled with various attention backends
         self.init_deterministic_inference_config()

+        # Init overlap
+        self.init_overlap()
+
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -932,6 +945,32 @@ class Scheduler(
             # The prefill requests that are in the middle of kv sending
             self.disagg_prefill_inflight_queue: List[Req] = []

+    def init_overlap(self):
+        if not self.enable_overlap:
+            return
+
+        self.forward_stream: CudaStream = torch.get_device_module(self.device).Stream()
+        self.forward_stream_ctx: CudaStreamContext = torch.get_device_module(
+            self.device
+        ).stream(self.forward_stream)
+        self.copy_stream: CudaStream = torch.get_device_module(self.device).Stream()
+        self.copy_stream_ctx: CudaStreamContext = torch.get_device_module(
+            self.device
+        ).stream(self.copy_stream)
+
+        self.future_map = FutureMap(self.max_running_requests, self.device)
+        self.batch_record_buf = [None] * 2
+        self.batch_record_ct = 0
+
+    def record_batch_in_overlap(self, model_worker_batch: ModelWorkerBatch):
+        # FIXME(lsyin): hacky way to keep a reference to avoid GPU tensors being freed by torch GC
+        # NOTE: More Reliable: record all tensors into the forward stream
+        # NOTE: - for all future tensors, we shall always read from future map
+        #       - for all non-future tensors (produced only by schedule stream),
+        #         we shall keep its reference not being release during all the forwarding pass
+        self.batch_record_ct = (self.batch_record_ct + 1) % 2
+        self.batch_record_buf[self.batch_record_ct] = model_worker_batch
+
     def init_moe_config(self):
         if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
             initialize_moe_config(self.server_args)
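Reviewer note: `init_overlap()` introduces a dedicated forward stream and a reusable stream context instead of the old background worker thread. A rough sketch of how that pair is meant to be used, with throwaway tensors (plain `torch.cuda` API, CUDA assumed):

```python
import torch

# A dedicated side stream lets batch N's GPU kernels run while the CPU keeps
# scheduling batch N+1; wait_stream() expresses the cross-stream dependencies.
if torch.cuda.is_available():
    default_stream = torch.cuda.current_stream()
    forward_stream = torch.cuda.Stream()
    forward_stream_ctx = torch.cuda.stream(forward_stream)

    x = torch.randn(1024, 1024, device="cuda")
    with forward_stream_ctx:
        # Make the side stream wait for work already queued on the default stream.
        forward_stream.wait_stream(default_stream)
        y = x @ x  # launched on forward_stream, asynchronously w.r.t. the CPU
    # Before the default stream reuses y, it must wait for forward_stream.
    default_stream.wait_stream(forward_stream)
    print(y.shape)
```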
@@ -958,9 +997,11 @@ class Scheduler(
     @DynamicGradMode()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
-        self.result_queue = deque()
+        self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()

         while True:
+            self.launch_last_batch_sample_if_needed()
+
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
@@ -968,7 +1009,6 @@ class Scheduler(
             self.cur_batch = batch

             if batch:
-                batch.launch_done = threading.Event()
                 result = self.run_batch(batch)
                 self.result_queue.append((batch.copy(), result))
@@ -980,7 +1020,7 @@ class Scheduler(
                         forward_mode=ForwardMode.DUMMY_FIRST,
                         next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                     )
-                    self.process_batch_result(tmp_batch, None, batch.launch_done)
+                    self.process_batch_result(tmp_batch, None)

             if self.last_batch:
                 # Process the results of the last batch
@@ -988,10 +1028,7 @@ class Scheduler(
                 tmp_batch.next_batch_sampling_info = (
                     self.tp_worker.cur_sampling_info if batch else None
                 )
-                # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
-                self.process_batch_result(
-                    tmp_batch, tmp_result, batch.launch_done if batch else None
-                )
+                self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
                 self.self_check_during_idle()
@@ -2056,18 +2093,62 @@ class Scheduler(
             # FIXME(lsyin): remove this if and finally unify the abstraction
             batch_or_worker_batch = batch.get_model_worker_batch()

-            forward_batch_output = self.model_worker.forward_batch_generation(
-                batch_or_worker_batch
-            )
+            if self.enable_overlap:
+                # FIXME: remove this assert
+                assert isinstance(batch_or_worker_batch, ModelWorkerBatch)
+                model_worker_batch = batch_or_worker_batch
+                self.record_batch_in_overlap(model_worker_batch)
+
+                # Sampling info will be modified during forward
+                model_worker_batch.sampling_info = self.tp_worker.cur_sampling_info = (
+                    model_worker_batch.sampling_info.copy_for_forward()
+                )
+
+                bs = len(model_worker_batch.seq_lens)
+                cur_future_map_ct = self.future_map.update_ct(bs)
+
+                with self.forward_stream_ctx:
+                    self.forward_stream.wait_stream(self.default_stream)
+                    self.future_map.resolve_future(model_worker_batch)
+                    if batch.sampling_info.grammars is not None:
+                        model_worker_batch.delay_sample_launch = True
+                    batch_result = self.model_worker.forward_batch_generation(
+                        batch_or_worker_batch
+                    )
+                    # FIXME(lsyin): maybe move this to forward_batch_generation
+                    batch_result.copy_done = torch.get_device_module(
+                        self.device
+                    ).Event()
+                    if not model_worker_batch.delay_sample_launch:
+                        self.future_map.store_to_map(
+                            cur_future_map_ct, bs, batch_result.next_token_ids
+                        )
+                        batch_result.copy_to_cpu()
+                    else:
+                        batch_result.future_map_ct = cur_future_map_ct
+
+                # FIXME(lsyin): move this assignment elsewhere
+                maybe_future_next_token_ids = self.future_map.update_next_future(
+                    cur_future_map_ct, bs
+                )
+            else:
+                batch_result = self.model_worker.forward_batch_generation(
+                    batch_or_worker_batch
+                )
+                maybe_future_next_token_ids = batch_result.next_token_ids
+                copy_done = None

             if not self.spec_algorithm.is_none():
                 # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
-                self.udpate_spec_metrics(
-                    batch.batch_size(), forward_batch_output.num_accepted_tokens
+                self.update_spec_metrics(
+                    batch.batch_size(), batch_result.num_accepted_tokens
                 )

-            # update batch's output ids
-            batch.output_ids = forward_batch_output.next_token_ids
+            # NOTE: maybe_future_next_token_ids is used in ScheduleBatch,
+            # which can probably be replaced by future_indices later [TODO(lsyin)].
+            # we shall still keep the original outputs, e.g. next_token_ids
+            # in the GenerationBatchOutput for processing after copy_done.
+            batch.output_ids = maybe_future_next_token_ids

             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
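Reviewer note: putting the pieces together, the overlap branch of `run_batch` launches batch N's forward and copies on the side stream, while the CPU processes batch N-1's results from the deque. A toy end-to-end version of that loop (hypothetical `toy_overlap_loop`, a stand-in matmul instead of a model forward, CUDA assumed):

```python
import collections
import torch

# Launch batch N on a side stream, then post-process batch N-1 on the CPU while
# batch N's kernels may still be running.
def toy_overlap_loop(batches, forward_stream, default_stream):
    result_queue = collections.deque()
    outputs = []
    for x in batches:
        with torch.cuda.stream(forward_stream):
            forward_stream.wait_stream(default_stream)
            y = (x @ x).sum(dim=-1)              # stand-in for forward + sample
            copy_done = torch.cuda.Event()
            y_cpu = y.to("cpu", non_blocking=True)
            copy_done.record()
        result_queue.append((y_cpu, copy_done))

        if len(result_queue) > 1:                # process the *previous* batch
            prev_cpu, prev_done = result_queue.popleft()
            prev_done.synchronize()
            outputs.append(prev_cpu.tolist())

    while result_queue:                          # drain the tail
        prev_cpu, prev_done = result_queue.popleft()
        prev_done.synchronize()
        outputs.append(prev_cpu.tolist())
    return outputs

if torch.cuda.is_available():
    batches = [torch.randn(4, 4, device="cuda") for _ in range(3)]
    print(toy_overlap_loop(batches, torch.cuda.Stream(), torch.cuda.current_stream()))
```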
@@ -2084,36 +2165,60 @@ class Scheduler(
             else:
                 extend_logprob_start_len_per_req = None

-            return GenerationBatchResult.from_forward_batch_output(
-                forward_batch_output=forward_batch_output,
-                extend_input_len_per_req=extend_input_len_per_req,
-                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
-            )
+            batch_result.extend_input_len_per_req = extend_input_len_per_req
+            batch_result.extend_logprob_start_len_per_req = (
+                extend_logprob_start_len_per_req
+            )
+
+            return batch_result
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
             ret = EmbeddingBatchResult(embeddings=embeddings)
             return ret

+    def launch_last_batch_sample_if_needed(
+        self,
+    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
+        if len(self.result_queue) == 0:
+            return
+
+        tmp_batch, tmp_result = self.result_queue.popleft()
+
+        tmp_result: GenerationBatchResult
+        if not tmp_result.delay_sample_launch:
+            self.result_queue.appendleft((tmp_batch, tmp_result))
+            return
+
+        with self.forward_stream_ctx:
+            self.forward_stream.wait_stream(self.default_stream)
+            tmp_result.next_token_ids = self.model_worker.model_runner.sample(
+                tmp_result.logits_output,
+                tmp_result.forward_batch,
+            )
+            ct, bs = tmp_result.future_map_ct, len(tmp_batch.reqs)
+            self.future_map.store_to_map(ct, bs, tmp_result.next_token_ids)
+            tmp_result.copy_to_cpu()
+
+        self.result_queue.appendleft((tmp_batch, tmp_result))
+
     def process_batch_result(
         self,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
-        launch_done: Optional[threading.Event] = None,
     ):
         if batch.forward_mode.is_decode():
-            self.process_batch_result_decode(batch, result, launch_done)
+            self.process_batch_result_decode(batch, result)
             if self.enable_trace:
                 trace_slice_batch("decode loop", batch.reqs)

         elif batch.forward_mode.is_extend():
-            self.process_batch_result_prefill(batch, result, launch_done)
+            self.process_batch_result_prefill(batch, result)
             if self.enable_trace:
                 trace_slice_batch("prefill", batch.reqs)

         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
-                self.tp_worker.resolve_last_batch_result(launch_done)
+                if result.copy_done is not None:
+                    result.copy_done.synchronize()
                 self.set_next_batch_sampling_info_done(batch)

         elif batch.forward_mode.is_dummy_first():
             self.set_next_batch_sampling_info_done(batch)
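Reviewer note: `launch_last_batch_sample_if_needed` exists because grammar-constrained batches set `delay_sample_launch`: their vocab mask is only ready after the forward has been queued, so sampling is launched at the top of the next loop iteration instead. A schematic toy version of that flow (made-up `ToyResult`, argmax standing in for the real sampler):

```python
import collections
import dataclasses
from typing import Optional

import torch

# When sampling cannot be launched together with the forward, the result keeps
# the logits and is completed at the start of the next scheduler iteration.
@dataclasses.dataclass
class ToyResult:
    logits: torch.Tensor
    next_token_ids: Optional[torch.Tensor] = None
    delay_sample_launch: bool = False

def launch_last_sample_if_needed(result_queue: collections.deque):
    if not result_queue:
        return
    result = result_queue.popleft()
    if result.delay_sample_launch and result.next_token_ids is None:
        result.next_token_ids = result.logits.argmax(dim=-1)  # stand-in for sample()
    result_queue.appendleft(result)

result_queue = collections.deque()
result_queue.append(ToyResult(logits=torch.randn(2, 16), delay_sample_launch=True))
launch_last_sample_if_needed(result_queue)   # next loop iteration completes sampling
print(result_queue[0].next_token_ids)
```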
@@ -2330,7 +2435,7 @@ class Scheduler(
         if batch.next_batch_sampling_info:
             if batch.next_batch_sampling_info.grammars is not None:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.current_stream.synchronize()
+                self.default_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

     def watchdog_thread(self):
......
@@ -69,7 +69,7 @@ class SchedulerMetricsMixin:
             kv_events_config, self.attn_dp_rank
         )

-    def udpate_spec_metrics(self, bs: int, num_accepted_tokens: int):
+    def update_spec_metrics(self, bs: int, num_accepted_tokens: int):
         self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
         self.spec_num_total_forward_ct += bs
         self.num_generated_tokens += num_accepted_tokens
......
@@ -39,7 +39,6 @@ class SchedulerOutputProcessorMixin:
         self: Scheduler,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
-        launch_done: Optional[threading.Event] = None,
     ):
         skip_stream_req = None
@@ -49,29 +48,29 @@ class SchedulerOutputProcessorMixin:
                 next_token_ids,
                 extend_input_len_per_req,
                 extend_logprob_start_len_per_req,
+                copy_done,
             ) = (
                 result.logits_output,
                 result.next_token_ids,
                 result.extend_input_len_per_req,
                 result.extend_logprob_start_len_per_req,
+                result.copy_done,
             )

-            if self.enable_overlap:
-                logits_output, next_token_ids, _ = (
-                    self.tp_worker.resolve_last_batch_result(launch_done)
-                )
-            else:
-                # Move next_token_ids and logprobs to cpu
-                next_token_ids = next_token_ids.tolist()
-                if batch.return_logprob:
-                    if logits_output.next_token_logprobs is not None:
-                        logits_output.next_token_logprobs = (
-                            logits_output.next_token_logprobs.tolist()
-                        )
-                    if logits_output.input_token_logprobs is not None:
-                        logits_output.input_token_logprobs = tuple(
-                            logits_output.input_token_logprobs.tolist()
-                        )
+            if copy_done is not None:
+                copy_done.synchronize()
+
+            # Move next_token_ids and logprobs to cpu
+            next_token_ids = next_token_ids.tolist()
+            if batch.return_logprob:
+                if logits_output.next_token_logprobs is not None:
+                    logits_output.next_token_logprobs = (
+                        logits_output.next_token_logprobs.tolist()
+                    )
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = tuple(
+                        logits_output.input_token_logprobs.tolist()
+                    )

             hidden_state_offset = 0
@@ -204,22 +203,19 @@ class SchedulerOutputProcessorMixin:
         self: Scheduler,
         batch: ScheduleBatch,
         result: GenerationBatchResult,
-        launch_done: Optional[threading.Event] = None,
     ):
-        logits_output, next_token_ids, can_run_cuda_graph = (
+        logits_output, next_token_ids, can_run_cuda_graph, copy_done = (
             result.logits_output,
             result.next_token_ids,
             result.can_run_cuda_graph,
+            result.copy_done,
         )
         self.num_generated_tokens += len(batch.reqs)

-        if self.enable_overlap:
-            logits_output, next_token_ids, can_run_cuda_graph = (
-                self.tp_worker.resolve_last_batch_result(launch_done)
-            )
-            next_token_logprobs = logits_output.next_token_logprobs
-        elif batch.spec_algorithm.is_none():
-            # spec decoding handles output logprobs inside verify process.
+        if copy_done is not None:
+            copy_done.synchronize()
+
+        if batch.spec_algorithm.is_none():
             next_token_ids = next_token_ids.tolist()
             if batch.return_logprob:
                 next_token_logprobs = logits_output.next_token_logprobs.tolist()
......
@@ -15,14 +15,12 @@
 from __future__ import annotations

 import logging
-import threading
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional

 import torch

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_pp_group, get_world_group
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     DestroyWeightsUpdateGroupReqInput,
     GetWeightsByNameReqInput,
@@ -36,13 +34,10 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
+from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatch,
-    ForwardBatchOutput,
-    PPProxyTensors,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
@@ -236,9 +231,8 @@ class TpModelWorker:
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
-        launch_done: Optional[threading.Event] = None,
         is_verify: bool = False,
-    ) -> ForwardBatchOutput:
+    ) -> GenerationBatchResult:
         # update the consumer index of hicache to the running batch
         self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
@@ -256,32 +250,43 @@ class TpModelWorker:
             logits_output, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch, pp_proxy_tensors=pp_proxy_tensors
             )
-            if launch_done is not None:
-                launch_done.set()
-
-            skip_sample = is_verify or model_worker_batch.is_prefill_only
-            next_token_ids = None
-
-            if not skip_sample:
-                next_token_ids = self.model_runner.sample(logits_output, forward_batch)
-            elif model_worker_batch.return_logprob and not is_verify:
-                # NOTE: Compute logprobs without full sampling
-                self.model_runner.compute_logprobs_only(
-                    logits_output, model_worker_batch
-                )
-
-            return ForwardBatchOutput(
+            batch_result = GenerationBatchResult(
                 logits_output=logits_output,
-                next_token_ids=next_token_ids,
                 can_run_cuda_graph=can_run_cuda_graph,
             )
+
+            if is_verify:
+                # Skip sampling and return logits for target forward
+                return batch_result
+
+            if model_worker_batch.delay_sample_launch:
+                batch_result.delay_sample_launch = True
+                batch_result.forward_batch = forward_batch
+                return batch_result
+
+            if model_worker_batch.is_prefill_only:
+                # For prefill-only requests, create dummy token IDs on CPU
+                batch_result.next_token_ids = torch.zeros_like(
+                    model_worker_batch.input_ids, dtype=torch.long
+                )
+                if model_worker_batch.return_logprob:
+                    # NOTE: Compute logprobs without full sampling
+                    self.model_runner.compute_logprobs_only(
+                        logits_output, model_worker_batch
+                    )
+            else:
+                batch_result.next_token_ids = self.model_runner.sample(
+                    logits_output, forward_batch
+                )
+
+            return batch_result
         else:
             pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return ForwardBatchOutput(
-                pp_proxy_tensors=pp_proxy_tensors,
+            return GenerationBatchResult(
+                pp_hidden_states_proxy_tensors=pp_proxy_tensors,
                 can_run_cuda_graph=can_run_cuda_graph,
             )
......
@@ -232,12 +232,8 @@ class TpModelWorkerClient:
         self, model_worker_batch: ModelWorkerBatch
     ) -> ForwardBatchOutput:
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
-        sampling_info = model_worker_batch.sampling_info
-        sampling_info.update_penalties()
-        model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
-            sampling_info,
-            sampling_info_done=threading.Event(),
-            penalizer_orchestrator=None,
+        model_worker_batch.sampling_info = self.cur_sampling_info = (
+            model_worker_batch.sampling_info.copy_for_forward()
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
......
@@ -902,17 +902,6 @@ class ForwardBatch:
         return self.tbo_split_seq_index is not None


-@dataclass
-class ForwardBatchOutput:
-    # FIXME(lsyin): unify the forward batch output between different spec and parallelism
-    # need to be more organized
-    logits_output: Optional[torch.Tensor] = None
-    next_token_ids: Optional[torch.Tensor] = None
-    num_accepted_tokens: Optional[int] = None
-    pp_proxy_tensors: Optional[PPProxyTensors] = None
-    can_run_cuda_graph: bool = False
-
-
 def enable_num_token_non_padded(server_args):
     return get_moe_expert_parallel_world_size() > 1
......
@@ -370,6 +370,15 @@ class SamplingBatchInfo:
         self.need_top_k_sampling |= other.need_top_k_sampling
         self.need_min_p_sampling |= other.need_min_p_sampling

+    def copy_for_forward(self):
+        # Accumulate the penalty into a pre-allocated buffer to get rid of the dependency of `penalizer_orchestrator` later
+        self.update_penalties()
+        return dataclasses.replace(
+            self,
+            sampling_info_done=threading.Event(),
+            penalizer_orchestrator=None,
+        )
+

 def merge_bias_tensor(
     lhs: Optional[torch.Tensor],
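Reviewer note: `copy_for_forward()` centralizes what the scheduler and the worker thread previously did by hand: fold penalties into the pre-allocated buffer, then hand the forward pass a shallow snapshot with a fresh event and no penalizer, so in-place updates for the next batch cannot race with the running forward. A toy illustration of why the snapshot matters (hypothetical `ToySamplingInfo`; note that `dataclasses.replace` is shallow, so tensor fields would still be shared):

```python
import dataclasses
import threading
from typing import Optional

# Snapshot the sampling state the forward needs, with a fresh event and without
# the orchestrator, while the scheduler keeps mutating the live object.
@dataclasses.dataclass
class ToySamplingInfo:
    temperature: float
    penalizer_orchestrator: Optional[object] = None
    sampling_info_done: Optional[threading.Event] = None

    def copy_for_forward(self):
        return dataclasses.replace(
            self,
            sampling_info_done=threading.Event(),
            penalizer_orchestrator=None,
        )

live = ToySamplingInfo(temperature=0.7, penalizer_orchestrator=object())
snapshot = live.copy_for_forward()
live.temperature = 1.0  # scheduler mutates the live object for the next batch...
print(snapshot.temperature, snapshot.penalizer_orchestrator)  # 0.7 None
```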
......
@@ -19,11 +19,11 @@ from sglang.srt.managers.schedule_batch import (
     get_last_loc,
     global_server_args_dict,
 )
+from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
-    ForwardBatchOutput,
     ForwardMode,
 )
 from sglang.srt.server_args import ServerArgs
@@ -429,7 +429,7 @@ class EAGLEWorker(TpModelWorker):
     def draft_model_runner(self):
         return self.model_runner

-    def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
+    def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult:
         """Run speculative decoding forward.

         NOTE: Many states of batch is modified as you go through. It is not guaranteed that
@@ -449,7 +449,7 @@ class EAGLEWorker(TpModelWorker):
             self.forward_draft_extend(
                 batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
             )
-            return ForwardBatchOutput(
+            return GenerationBatchResult(
                 logits_output=logits_output,
                 next_token_ids=next_token_ids,
                 num_accepted_tokens=0,
@@ -472,7 +472,7 @@ class EAGLEWorker(TpModelWorker):
                 # decode is not finished
                 self.forward_draft_extend_after_decode(batch)

-            return ForwardBatchOutput(
+            return GenerationBatchResult(
                 logits_output=logits_output,
                 next_token_ids=verify_output.verified_id,
                 num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
@@ -513,12 +513,10 @@ class EAGLEWorker(TpModelWorker):
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        forward_batch_output = self.target_worker.forward_batch_generation(
-            model_worker_batch
-        )
+        batch_result = self.target_worker.forward_batch_generation(model_worker_batch)
         logits_output, next_token_ids = (
-            forward_batch_output.logits_output,
-            forward_batch_output.next_token_ids,
+            batch_result.logits_output,
+            batch_result.next_token_ids,
         )
         return (
             logits_output,
@@ -822,12 +820,12 @@ class EAGLEWorker(TpModelWorker):
         ).cpu()

         # Forward
-        forward_batch_output = self.target_worker.forward_batch_generation(
+        batch_result = self.target_worker.forward_batch_generation(
             model_worker_batch, is_verify=True
         )
         logits_output, can_run_cuda_graph = (
-            forward_batch_output.logits_output,
-            forward_batch_output.can_run_cuda_graph,
+            batch_result.logits_output,
+            batch_result.can_run_cuda_graph,
         )

         vocab_mask = None
......
@@ -6,8 +6,9 @@ import torch
 from sgl_kernel.speculative import reconstruct_indices_from_tree_mask

 from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.managers.tp_worker import TpModelWorker
-from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
 from sglang.srt.speculative.ngram_utils import NgramVerifyInput
@@ -207,18 +208,18 @@ class NGRAMWorker:
             batch_tokens.append(put_ids)
         self.ngram_cache.batch_put(batch_tokens)

-    def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
+    def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult:
         self._prepare_for_speculative_decoding(batch)
         model_worker_batch = batch.get_model_worker_batch()
         num_accepted_tokens = 0

         if model_worker_batch.forward_mode.is_target_verify():
-            forward_batch_output = self.target_worker.forward_batch_generation(
+            batch_result = self.target_worker.forward_batch_generation(
                 model_worker_batch, is_verify=True
             )
             logits_output, can_run_cuda_graph = (
-                forward_batch_output.logits_output,
-                forward_batch_output.can_run_cuda_graph,
+                batch_result.logits_output,
+                batch_result.can_run_cuda_graph,
             )
             verify_input = model_worker_batch.spec_info
             logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
@@ -228,16 +229,16 @@ class NGRAMWorker:
             batch.forward_mode = ForwardMode.DECODE
         else:
-            forward_batch_output = self.target_worker.forward_batch_generation(
+            batch_result = self.target_worker.forward_batch_generation(
                 model_worker_batch
             )
             logits_output, next_token_ids, can_run_cuda_graph = (
-                forward_batch_output.logits_output,
-                forward_batch_output.next_token_ids,
-                forward_batch_output.can_run_cuda_graph,
+                batch_result.logits_output,
+                batch_result.next_token_ids,
+                batch_result.can_run_cuda_graph,
             )

-        return ForwardBatchOutput(
+        return GenerationBatchResult(
             logits_output=logits_output,
             next_token_ids=next_token_ids,
             num_accepted_tokens=num_accepted_tokens,
......
@@ -1160,7 +1160,7 @@ def run_bench_offline_throughput(model, other_args):
         *[str(x) for x in other_args],
     ]

-    print(f"{command=}")
+    print(f"command={' '.join(command)}")
     process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

     try:
......