Unverified commit 1519a89c authored by Liangsheng Yin, committed by GitHub
parent 24bc3fb0
......@@ -747,11 +747,13 @@ class SchedulerDisaggregationDecodeMixin:
@torch.no_grad()
def event_loop_overlap_disagg_decode(self: Scheduler):
result_queue = deque()
self.result_queue = deque()
self.last_batch: Optional[ScheduleBatch] = None
self.last_batch_in_queue = False # last batch is modified in-place, so we need another variable to track if it's extend
while True:
self.launch_last_batch_sample_if_needed()
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
# polling and allocating kv cache
......@@ -774,13 +776,13 @@ class SchedulerDisaggregationDecodeMixin:
None, delay_process=True
)
if batch_:
result_queue.append((batch_.copy(), result))
self.result_queue.append((batch_.copy(), result))
last_batch_in_queue = True
else:
if prepare_mlp_sync_flag:
self.prepare_mlp_sync_batch(batch)
result = self.run_batch(batch)
result_queue.append((batch.copy(), result))
self.result_queue.append((batch.copy(), result))
if (self.last_batch is None) or (not self.last_batch_in_queue):
# Create a dummy first batch to start the pipeline for overlap schedule.
......@@ -798,12 +800,12 @@ class SchedulerDisaggregationDecodeMixin:
None, delay_process=True
)
if batch:
result_queue.append((batch.copy(), result))
self.result_queue.append((batch.copy(), result))
last_batch_in_queue = True
# Process the results of the previous batch but skip if the last batch is extend
if self.last_batch and self.last_batch_in_queue:
tmp_batch, tmp_result = result_queue.popleft()
tmp_batch, tmp_result = self.result_queue.popleft()
tmp_batch.next_batch_sampling_info = (
self.tp_worker.cur_sampling_info if batch else None
)
......
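The decode loop above follows a one-step overlap pattern: launch the current batch, queue its handle, and process the previous batch's result while the new launch is in flight. A minimal, self-contained sketch of that pattern, using toy stand-ins (a thread pool in place of the GPU, hypothetical run_batch/process_result helpers) rather than sglang APIs:

import time
from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor

_pool = ThreadPoolExecutor(max_workers=1)  # stands in for the GPU

def run_batch(batch_id: int) -> Future:
    # Asynchronous launch: return a handle immediately, like enqueuing GPU work.
    def gpu_work() -> int:
        time.sleep(0.01)
        return batch_id * 10
    return _pool.submit(gpu_work)

def process_result(batch_id: int, handle: Future) -> None:
    # CPU-side post-processing blocks only when the value is actually needed.
    print(f"batch {batch_id} -> token {handle.result()}")

def event_loop_overlap(num_batches: int = 4) -> None:
    result_queue = deque()  # (batch_id, handle) pairs, like self.result_queue
    last_batch = None
    for batch_id in range(num_batches):
        handle = run_batch(batch_id)        # launch the current batch
        result_queue.append((batch_id, handle))
        if last_batch is not None:          # overlap: finish the previous batch
            process_result(*result_queue.popleft())
        last_batch = batch_id
    while result_queue:                     # drain the tail after the loop ends
        process_result(*result_queue.popleft())

event_loop_overlap()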
......@@ -321,6 +321,8 @@ class SchedulerDisaggregationPrefillMixin:
self.result_queue = deque()
while True:
self.launch_last_batch_sample_if_needed()
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
self.waiting_queue.extend(
......@@ -368,7 +370,6 @@ class SchedulerDisaggregationPrefillMixin:
self: Scheduler,
batch: ScheduleBatch,
result: GenerationBatchResult,
launch_done: Optional[threading.Event] = None,
) -> None:
"""
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
......@@ -379,31 +380,30 @@ class SchedulerDisaggregationPrefillMixin:
next_token_ids,
extend_input_len_per_req,
extend_logprob_start_len_per_req,
copy_done,
) = (
result.logits_output,
result.next_token_ids,
result.extend_input_len_per_req,
result.extend_logprob_start_len_per_req,
result.copy_done,
)
if copy_done is not None:
copy_done.synchronize()
logprob_pt = 0
# Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
if self.enable_overlap:
# wait
logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
launch_done
)
else:
next_token_ids = result.next_token_ids.tolist()
if batch.return_logprob:
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)
next_token_ids = result.next_token_ids.tolist()
if batch.return_logprob:
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)
hidden_state_offset = 0
for i, (req, next_token_id) in enumerate(
......
......@@ -37,8 +37,7 @@ class FutureMap:
return cur_future_ct
def resolve_future(self, model_worker_batch: ModelWorkerBatch):
input_ids = model_worker_batch.input_ids
_resolve_future_token_ids(input_ids, self.token_ids_buf)
_resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf)
def update_next_future(self, future_ct: int, bs: int):
return torch.arange(
......
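FutureMap lets the scheduler build the next batch's input_ids before the previous batch has actually sampled its tokens: those positions carry placeholder indices into a shared buffer, and _resolve_future_token_ids patches them in place once the real ids exist. A toy sketch of that idea, assuming a negative-index placeholder convention (the real encoding may differ):

import torch

class ToyFutureMap:
    def __init__(self, capacity: int):
        self.token_ids_buf = torch.zeros(capacity, dtype=torch.long)

    def store(self, slots: torch.Tensor, token_ids: torch.Tensor) -> None:
        # Producer side: sampled token ids land in their reserved slots.
        self.token_ids_buf[slots] = token_ids

    def resolve(self, input_ids: torch.Tensor) -> None:
        # Consumer side: rewrite placeholders in place, like resolve_future().
        mask = input_ids < 0                         # assumed placeholder convention
        slots = -(input_ids[mask] + 1)               # placeholder -> buffer slot
        input_ids[mask] = self.token_ids_buf[slots]

fm = ToyFutureMap(capacity=8)
fm.store(torch.tensor([0, 1]), torch.tensor([101, 202]))
ids = torch.tensor([7, -1, -2, 9])                   # -1 and -2 are future tokens
fm.resolve(ids)
print(ids.tolist())                                  # [7, 101, 202, 9]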
......@@ -886,9 +886,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# This is an optimization to reduce the overhead of the prefill check.
batch_is_full: bool = False
# Events
launch_done: Optional[threading.Event] = None
# For chunked prefill in PP
chunked_req: Optional[Req] = None
......@@ -1877,7 +1874,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
),
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
launch_done=self.launch_done,
is_prefill_only=self.is_prefill_only,
)
......@@ -2018,8 +2014,8 @@ class ModelWorkerBatch:
capture_hidden_mode: CaptureHiddenMode = None
hicache_consumer_index: int = -1
# Overlap event
launch_done: Optional[threading.Event] = None
# Overlap scheduler related
delay_sample_launch: bool = False
# Whether this batch is prefill-only (no token generation needed)
is_prefill_only: bool = False
......
......@@ -25,12 +25,14 @@ from concurrent import futures
from dataclasses import dataclass
from http import HTTPStatus
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union
from typing import Deque, Dict, List, Optional, Tuple, Union
import psutil
import setproctitle
import torch
import zmq
from torch.cuda import Stream as CudaStream
from torch.cuda import StreamContext as CudaStreamContext
from torch.distributed import barrier
from sglang.global_config import global_config
......@@ -112,8 +114,10 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.mm_utils import init_embedding_cache
from sglang.srt.managers.overlap_utils import FutureMap
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
ModelWorkerBatch,
MultimodalInputs,
Req,
RequestStage,
......@@ -139,15 +143,13 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
SchedulerUpdateWeightsMixin,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatchOutput,
ForwardBatch,
ForwardMode,
PPProxyTensors,
)
......@@ -201,40 +203,48 @@ GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
@dataclass
class GenerationBatchResult:
logits_output: Optional[LogitsProcessorOutput]
pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
next_token_ids: Optional[List[int]]
can_run_cuda_graph: bool
logits_output: Optional[LogitsProcessorOutput] = None
pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
next_token_ids: Optional[torch.Tensor] = None
num_accepted_tokens: Optional[int] = None
can_run_cuda_graph: bool = False
# For output processing
extend_input_len_per_req: List[int]
extend_logprob_start_len_per_req: List[int]
@classmethod
def from_forward_batch_output(
cls,
forward_batch_output: ForwardBatchOutput,
extend_input_len_per_req: List[int],
extend_logprob_start_len_per_req: List[int],
):
# TODO(lsyin): remove this workaround logic and try to unify output classes
return cls(
logits_output=forward_batch_output.logits_output,
pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
next_token_ids=forward_batch_output.next_token_ids,
extend_input_len_per_req=extend_input_len_per_req,
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
)
extend_input_len_per_req: Optional[List[int]] = None
extend_logprob_start_len_per_req: Optional[List[int]] = None
# For overlap scheduling
copy_done: Optional[torch.cuda.Event] = None
delay_sample_launch: bool = False
forward_batch: Optional[ForwardBatch] = None
future_map_ct: Optional[int] = None
def copy_to_cpu(self, return_logprob: bool = False):
"""Copy tensors to CPU in overlap scheduling.
Only the tensors which are needed for processing results are copied,
e.g., next_token_ids, logits outputs
"""
if return_logprob:
if self.logits_output.next_token_logits is not None:
self.logits_output.next_token_logits = (
self.logits_output.next_token_logits.to("cpu", non_blocking=True)
)
if self.logits_output.input_token_logprobs is not None:
self.logits_output.input_token_logprobs = (
self.logits_output.input_token_logprobs.to("cpu", non_blocking=True)
)
if self.logits_output.hidden_states is not None:
self.logits_output.hidden_states = self.logits_output.hidden_states.to(
"cpu", non_blocking=True
)
self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
self.copy_done.record()
@classmethod
def from_pp_proxy(
cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
):
# TODO(lsyin): also simplify this logic
# Current PP implementation in scheduler is not compatible with ForwardBatchOutput
# Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
# TODO(lsyin): refactor PP and avoid using dict
proxy_dict = next_pp_outputs.tensors
return cls(
logits_output=logits_output,
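copy_to_cpu above pairs non-blocking device-to-host copies with a copy_done event, so the scheduler defers the blocking wait (copy_done.synchronize()) until it actually reads the CPU values. A minimal sketch of that event pattern (CUDA required; the helper name is illustrative):

import torch

def launch_result_copy(next_token_ids_gpu: torch.Tensor):
    copy_done = torch.cuda.Event()
    # Enqueue the device-to-host copy on the current stream without blocking the CPU.
    next_token_ids_cpu = next_token_ids_gpu.to("cpu", non_blocking=True)
    copy_done.record()       # records a point after the copy in the stream
    return next_token_ids_cpu, copy_done

if torch.cuda.is_available():
    ids_gpu = torch.randint(0, 32000, (8,), device="cuda")
    ids_cpu, copy_done = launch_result_copy(ids_gpu)
    # ... other scheduling work can run here ...
    copy_done.synchronize()  # block only when the CPU values are needed
    print(ids_cpu.tolist())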
......@@ -388,12 +398,10 @@ class Scheduler(
logger.info("Overlap scheduler is disabled for embedding models.")
# Launch a tensor parallel worker
if self.enable_overlap:
TpWorkerClass = TpModelWorkerClient
else:
TpWorkerClass = TpModelWorker
self.tp_worker = TpWorkerClass(
from sglang.srt.managers.tp_worker import TpModelWorker
self.tp_worker = TpModelWorker(
server_args=server_args,
gpu_id=gpu_id,
tp_rank=tp_rank,
......@@ -525,9 +533,11 @@ class Scheduler(
self.kv_transfer_speed_gb_s: float = 0.0
self.kv_transfer_latency_ms: float = 0.0
self.sessions: Dict[str, Session] = {}
self.current_stream = torch.get_device_module(self.device).current_stream()
self.default_stream: CudaStream = torch.get_device_module(
self.device
).current_stream()
if self.device == "cpu":
self.current_stream.synchronize = lambda: None # No-op for CPU
self.default_stream.synchronize = lambda: None # No-op for CPU
self.forward_sleep_time = None
# Init chunked prefill
......@@ -618,6 +628,9 @@ class Scheduler(
# Init prefill kv split size when deterministic inference is enabled with various attention backends
self.init_deterministic_inference_config()
# Init overlap
self.init_overlap()
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
......@@ -932,6 +945,32 @@ class Scheduler(
# The prefill requests that are in the middle of kv sending
self.disagg_prefill_inflight_queue: List[Req] = []
def init_overlap(self):
if not self.enable_overlap:
return
self.forward_stream: CudaStream = torch.get_device_module(self.device).Stream()
self.forward_stream_ctx: CudaStreamContext = torch.get_device_module(
self.device
).stream(self.forward_stream)
self.copy_stream: CudaStream = torch.get_device_module(self.device).Stream()
self.copy_stream_ctx: CudaStreamContext = torch.get_device_module(
self.device
).stream(self.copy_stream)
self.future_map = FutureMap(self.max_running_requests, self.device)
self.batch_record_buf = [None] * 2
self.batch_record_ct = 0
def record_batch_in_overlap(self, model_worker_batch: ModelWorkerBatch):
# FIXME(lsyin): hacky way to keep a reference to avoid GPU tensors being freed by torch GC
# NOTE: More Reliable: record all tensors into the forward stream
# NOTE: - for all future tensors, we shall always read from future map
# - for all non-future tensors (produced only by schedule stream),
# we shall keep its reference not being release during all the forwarding pass
self.batch_record_ct = (self.batch_record_ct + 1) % 2
self.batch_record_buf[self.batch_record_ct] = model_worker_batch
def init_moe_config(self):
if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
initialize_moe_config(self.server_args)
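init_overlap above gives the forward pass its own stream and orders it after the default (scheduling) stream with wait_stream; that ordering is what lets CPU scheduling and the GPU forward overlap without racing on freshly prepared inputs. A stand-alone sketch of the same stream discipline (CUDA required; names are illustrative):

import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    default_stream = torch.cuda.current_stream(device)  # scheduling stream
    forward_stream = torch.cuda.Stream(device)           # dedicated forward stream

    x = torch.randn(1024, 1024, device=device)           # prepared on the default stream

    with torch.cuda.stream(forward_stream):
        forward_stream.wait_stream(default_stream)        # inputs ready before the forward
        y = x @ x                                         # forward work off the default stream

    default_stream.wait_stream(forward_stream)            # order the read-back
    print(float(y.sum()))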
......@@ -958,9 +997,11 @@ class Scheduler(
@DynamicGradMode()
def event_loop_overlap(self):
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
self.result_queue = deque()
self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()
while True:
self.launch_last_batch_sample_if_needed()
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
......@@ -968,7 +1009,6 @@ class Scheduler(
self.cur_batch = batch
if batch:
batch.launch_done = threading.Event()
result = self.run_batch(batch)
self.result_queue.append((batch.copy(), result))
......@@ -980,7 +1020,7 @@ class Scheduler(
forward_mode=ForwardMode.DUMMY_FIRST,
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
)
self.process_batch_result(tmp_batch, None, batch.launch_done)
self.process_batch_result(tmp_batch, None)
if self.last_batch:
# Process the results of the last batch
......@@ -988,10 +1028,7 @@ class Scheduler(
tmp_batch.next_batch_sampling_info = (
self.tp_worker.cur_sampling_info if batch else None
)
# NOTE: we should use current launched batch's launch_done event Instead of the last batch's
self.process_batch_result(
tmp_batch, tmp_result, batch.launch_done if batch else None
)
self.process_batch_result(tmp_batch, tmp_result)
elif batch is None:
# When the server is idle, do self-check and re-init some states
self.self_check_during_idle()
......@@ -2056,18 +2093,62 @@ class Scheduler(
# FIXME(lsyin): remove this if and finally unify the abstraction
batch_or_worker_batch = batch.get_model_worker_batch()
forward_batch_output = self.model_worker.forward_batch_generation(
batch_or_worker_batch
)
if self.enable_overlap:
# FIXME: remove this assert
assert isinstance(batch_or_worker_batch, ModelWorkerBatch)
model_worker_batch = batch_or_worker_batch
self.record_batch_in_overlap(model_worker_batch)
# Sampling info will be modified during forward
model_worker_batch.sampling_info = self.tp_worker.cur_sampling_info = (
model_worker_batch.sampling_info.copy_for_forward()
)
bs = len(model_worker_batch.seq_lens)
cur_future_map_ct = self.future_map.update_ct(bs)
with self.forward_stream_ctx:
self.forward_stream.wait_stream(self.default_stream)
self.future_map.resolve_future(model_worker_batch)
if batch.sampling_info.grammars is not None:
model_worker_batch.delay_sample_launch = True
batch_result = self.model_worker.forward_batch_generation(
batch_or_worker_batch
)
# FIXME(lsyin): maybe move this to forward_batch_generation
batch_result.copy_done = torch.get_device_module(
self.device
).Event()
if not model_worker_batch.delay_sample_launch:
self.future_map.store_to_map(
cur_future_map_ct, bs, batch_result.next_token_ids
)
batch_result.copy_to_cpu()
else:
batch_result.future_map_ct = cur_future_map_ct
# FIXME(lsyin): move this assignment elsewhere
maybe_future_next_token_ids = self.future_map.update_next_future(
cur_future_map_ct, bs
)
else:
batch_result = self.model_worker.forward_batch_generation(
batch_or_worker_batch
)
maybe_future_next_token_ids = batch_result.next_token_ids
copy_done = None
if not self.spec_algorithm.is_none():
# TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
self.udpate_spec_metrics(
batch.batch_size(), forward_batch_output.num_accepted_tokens
self.update_spec_metrics(
batch.batch_size(), batch_result.num_accepted_tokens
)
# update batch's output ids
batch.output_ids = forward_batch_output.next_token_ids
# NOTE: maybe_future_next_token_ids is used in ScheduleBatch,
# which can probably be replaced by future_indices later [TODO(lsyin)].
# we shall still keep the original outputs, e.g. next_token_ids
# in the GenerationBatchOutput for processing after copy_done.
batch.output_ids = maybe_future_next_token_ids
# These 2 values are needed for processing the output, but the values can be
# modified by overlap schedule. So we have to copy them here so that
......@@ -2084,36 +2165,60 @@ class Scheduler(
else:
extend_logprob_start_len_per_req = None
return GenerationBatchResult.from_forward_batch_output(
forward_batch_output=forward_batch_output,
extend_input_len_per_req=extend_input_len_per_req,
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
batch_result.extend_input_len_per_req = extend_input_len_per_req
batch_result.extend_logprob_start_len_per_req = (
extend_logprob_start_len_per_req
)
return batch_result
else: # embedding or reward model
model_worker_batch = batch.get_model_worker_batch()
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
ret = EmbeddingBatchResult(embeddings=embeddings)
return ret
def launch_last_batch_sample_if_needed(
self,
) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
if len(self.result_queue) == 0:
return
tmp_batch, tmp_result = self.result_queue.popleft()
tmp_result: GenerationBatchResult
if not tmp_result.delay_sample_launch:
self.result_queue.appendleft((tmp_batch, tmp_result))
return
with self.forward_stream_ctx:
self.forward_stream.wait_stream(self.default_stream)
tmp_result.next_token_ids = self.model_worker.model_runner.sample(
tmp_result.logits_output,
tmp_result.forward_batch,
)
ct, bs = tmp_result.future_map_ct, len(tmp_batch.reqs)
self.future_map.store_to_map(ct, bs, tmp_result.next_token_ids)
tmp_result.copy_to_cpu()
self.result_queue.appendleft((tmp_batch, tmp_result))
def process_batch_result(
self,
batch: ScheduleBatch,
result: Union[GenerationBatchResult, EmbeddingBatchResult],
launch_done: Optional[threading.Event] = None,
):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result, launch_done)
self.process_batch_result_decode(batch, result)
if self.enable_trace:
trace_slice_batch("decode loop", batch.reqs)
elif batch.forward_mode.is_extend():
self.process_batch_result_prefill(batch, result, launch_done)
self.process_batch_result_prefill(batch, result)
if self.enable_trace:
trace_slice_batch("prefill", batch.reqs)
elif batch.forward_mode.is_idle():
if self.enable_overlap:
self.tp_worker.resolve_last_batch_result(launch_done)
if result.copy_done is not None:
result.copy_done.synchronize()
self.set_next_batch_sampling_info_done(batch)
elif batch.forward_mode.is_dummy_first():
self.set_next_batch_sampling_info_done(batch)
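record_batch_in_overlap (called at the top of the overlap branch above) keeps a Python reference to the last couple of batches so tensors the in-flight forward stream may still read are not released early. A toy version of that two-slot ring buffer, as an illustration of the idea rather than the sglang class:

class ToyBatchRecorder:
    """Hold references to the most recent batches so their tensors stay alive."""

    def __init__(self, depth: int = 2):
        self._buf = [None] * depth
        self._ct = 0

    def record(self, batch) -> None:
        # Each call overwrites the slot written `depth` calls ago, so a batch
        # stays referenced for `depth` scheduling iterations.
        self._ct = (self._ct + 1) % len(self._buf)
        self._buf[self._ct] = batch

recorder = ToyBatchRecorder()
for step in range(5):
    recorder.record({"step": step})
print(recorder._buf)  # only the two most recent batches remain referenced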
......@@ -2330,7 +2435,7 @@ class Scheduler(
if batch.next_batch_sampling_info:
if batch.next_batch_sampling_info.grammars is not None:
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
self.default_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
def watchdog_thread(self):
......
......@@ -69,7 +69,7 @@ class SchedulerMetricsMixin:
kv_events_config, self.attn_dp_rank
)
def udpate_spec_metrics(self, bs: int, num_accepted_tokens: int):
def update_spec_metrics(self, bs: int, num_accepted_tokens: int):
self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
self.spec_num_total_forward_ct += bs
self.num_generated_tokens += num_accepted_tokens
......
......@@ -39,7 +39,6 @@ class SchedulerOutputProcessorMixin:
self: Scheduler,
batch: ScheduleBatch,
result: Union[GenerationBatchResult, EmbeddingBatchResult],
launch_done: Optional[threading.Event] = None,
):
skip_stream_req = None
......@@ -49,29 +48,29 @@ class SchedulerOutputProcessorMixin:
next_token_ids,
extend_input_len_per_req,
extend_logprob_start_len_per_req,
copy_done,
) = (
result.logits_output,
result.next_token_ids,
result.extend_input_len_per_req,
result.extend_logprob_start_len_per_req,
result.copy_done,
)
if self.enable_overlap:
logits_output, next_token_ids, _ = (
self.tp_worker.resolve_last_batch_result(launch_done)
)
else:
# Move next_token_ids and logprobs to cpu
next_token_ids = next_token_ids.tolist()
if batch.return_logprob:
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)
if copy_done is not None:
copy_done.synchronize()
# Move next_token_ids and logprobs to cpu
next_token_ids = next_token_ids.tolist()
if batch.return_logprob:
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)
hidden_state_offset = 0
......@@ -204,22 +203,19 @@ class SchedulerOutputProcessorMixin:
self: Scheduler,
batch: ScheduleBatch,
result: GenerationBatchResult,
launch_done: Optional[threading.Event] = None,
):
logits_output, next_token_ids, can_run_cuda_graph = (
logits_output, next_token_ids, can_run_cuda_graph, copy_done = (
result.logits_output,
result.next_token_ids,
result.can_run_cuda_graph,
result.copy_done,
)
self.num_generated_tokens += len(batch.reqs)
if self.enable_overlap:
logits_output, next_token_ids, can_run_cuda_graph = (
self.tp_worker.resolve_last_batch_result(launch_done)
)
next_token_logprobs = logits_output.next_token_logprobs
elif batch.spec_algorithm.is_none():
# spec decoding handles output logprobs inside verify process.
if copy_done is not None:
copy_done.synchronize()
if batch.spec_algorithm.is_none():
next_token_ids = next_token_ids.tolist()
if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs.tolist()
......
......@@ -15,14 +15,12 @@
from __future__ import annotations
import logging
import threading
from typing import TYPE_CHECKING, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional
import torch
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import get_pp_group, get_world_group
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
DestroyWeightsUpdateGroupReqInput,
GetWeightsByNameReqInput,
......@@ -36,13 +34,10 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardBatchOutput,
PPProxyTensors,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
......@@ -236,9 +231,8 @@ class TpModelWorker:
def forward_batch_generation(
self,
model_worker_batch: ModelWorkerBatch,
launch_done: Optional[threading.Event] = None,
is_verify: bool = False,
) -> ForwardBatchOutput:
) -> GenerationBatchResult:
# update the consumer index of hicache to the running batch
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
......@@ -256,32 +250,43 @@ class TpModelWorker:
logits_output, can_run_cuda_graph = self.model_runner.forward(
forward_batch, pp_proxy_tensors=pp_proxy_tensors
)
if launch_done is not None:
launch_done.set()
skip_sample = is_verify or model_worker_batch.is_prefill_only
next_token_ids = None
if not skip_sample:
next_token_ids = self.model_runner.sample(logits_output, forward_batch)
elif model_worker_batch.return_logprob and not is_verify:
# NOTE: Compute logprobs without full sampling
self.model_runner.compute_logprobs_only(
logits_output, model_worker_batch
)
return ForwardBatchOutput(
batch_result = GenerationBatchResult(
logits_output=logits_output,
next_token_ids=next_token_ids,
can_run_cuda_graph=can_run_cuda_graph,
)
if is_verify:
# Skip sampling and return logits for target forward
return batch_result
if model_worker_batch.delay_sample_launch:
batch_result.delay_sample_launch = True
batch_result.forward_batch = forward_batch
return batch_result
if model_worker_batch.is_prefill_only:
# For prefill-only requests, create dummy token IDs on CPU
batch_result.next_token_ids = torch.zeros_like(
model_worker_batch.input_ids, dtype=torch.long
)
if model_worker_batch.return_logprob:
# NOTE: Compute logprobs without full sampling
self.model_runner.compute_logprobs_only(
logits_output, model_worker_batch
)
else:
batch_result.next_token_ids = self.model_runner.sample(
logits_output, forward_batch
)
return batch_result
else:
pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
forward_batch,
pp_proxy_tensors=pp_proxy_tensors,
)
return ForwardBatchOutput(
pp_proxy_tensors=pp_proxy_tensors,
return GenerationBatchResult(
pp_hidden_states_proxy_tensors=pp_proxy_tensors,
can_run_cuda_graph=can_run_cuda_graph,
)
......
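When delay_sample_launch is set (for example, the batch has grammar constraints whose vocab masks are not ready yet), forward_batch_generation above returns logits without sampling, and the scheduler samples later via launch_last_batch_sample_if_needed. A hedged toy sketch of that two-phase flow; the classes and the argmax "sampler" are stand-ins, not sglang code:

from collections import deque
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class ToyBatchResult:
    logits: torch.Tensor
    next_token_ids: Optional[torch.Tensor] = None
    delay_sample_launch: bool = False

def toy_forward(hidden: torch.Tensor, delay_sample: bool) -> ToyBatchResult:
    logits = hidden @ torch.randn(hidden.shape[-1], 16)
    if delay_sample:
        # Phase 1: return logits only; sampling is deferred.
        return ToyBatchResult(logits=logits, delay_sample_launch=True)
    return ToyBatchResult(logits=logits, next_token_ids=logits.argmax(dim=-1))

def launch_last_batch_sample_if_needed(result_queue: deque) -> None:
    if not result_queue:
        return
    result = result_queue[0]
    if result.delay_sample_launch and result.next_token_ids is None:
        # Phase 2: constraints (e.g. grammar masks) are ready now, so sample.
        result.next_token_ids = result.logits.argmax(dim=-1)

result_queue = deque([toy_forward(torch.randn(2, 8), delay_sample=True)])
launch_last_batch_sample_if_needed(result_queue)
print(result_queue[0].next_token_ids.tolist())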
......@@ -232,12 +232,8 @@ class TpModelWorkerClient:
self, model_worker_batch: ModelWorkerBatch
) -> ForwardBatchOutput:
# Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
sampling_info = model_worker_batch.sampling_info
sampling_info.update_penalties()
model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
sampling_info,
sampling_info_done=threading.Event(),
penalizer_orchestrator=None,
model_worker_batch.sampling_info = self.cur_sampling_info = (
model_worker_batch.sampling_info.copy_for_forward()
)
# A cuda stream sync here to avoid the cuda illegal memory access error.
......
......@@ -902,17 +902,6 @@ class ForwardBatch:
return self.tbo_split_seq_index is not None
@dataclass
class ForwardBatchOutput:
# FIXME(lsyin): unify the forward batch output between different spec and parallelism
# need to be more organized
logits_output: Optional[torch.Tensor] = None
next_token_ids: Optional[torch.Tensor] = None
num_accepted_tokens: Optional[int] = None
pp_proxy_tensors: Optional[PPProxyTensors] = None
can_run_cuda_graph: bool = False
def enable_num_token_non_padded(server_args):
return get_moe_expert_parallel_world_size() > 1
......
......@@ -370,6 +370,15 @@ class SamplingBatchInfo:
self.need_top_k_sampling |= other.need_top_k_sampling
self.need_min_p_sampling |= other.need_min_p_sampling
def copy_for_forward(self):
# Accumulate the penalty into a pre-allocated buffer to get rid of the dependency of `penalizer_orchestrator` later
self.update_penalties()
return dataclasses.replace(
self,
sampling_info_done=threading.Event(),
penalizer_orchestrator=None,
)
def merge_bias_tensor(
lhs: Optional[torch.Tensor],
......
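copy_for_forward snapshots the mutable sampling state for the in-flight forward (fresh event, no live penalizer), so the scheduler can keep mutating its own copy while preparing the next batch. A small sketch of that dataclasses.replace pattern with illustrative fields:

import dataclasses
import threading
from typing import Optional

import torch

@dataclasses.dataclass
class ToySamplingInfo:
    temperatures: torch.Tensor
    penalizer_orchestrator: Optional[object] = None
    sampling_info_done: Optional[threading.Event] = None

    def update_penalties(self) -> None:
        # Stand-in for folding pending penalties into pre-allocated buffers.
        pass

    def copy_for_forward(self) -> "ToySamplingInfo":
        self.update_penalties()
        # Shallow copy: tensors are shared, but the copy gets a fresh done-event
        # and drops the orchestrator so it no longer depends on scheduler state.
        return dataclasses.replace(
            self,
            sampling_info_done=threading.Event(),
            penalizer_orchestrator=None,
        )

info = ToySamplingInfo(temperatures=torch.ones(4))
frozen = info.copy_for_forward()
info.temperatures = torch.full((4,), 0.7)   # scheduler moves on to the next batch
print(frozen.temperatures.tolist())          # the forward still sees [1.0, 1.0, 1.0, 1.0]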
......@@ -19,11 +19,11 @@ from sglang.srt.managers.schedule_batch import (
get_last_loc,
global_server_args_dict,
)
from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardBatchOutput,
ForwardMode,
)
from sglang.srt.server_args import ServerArgs
......@@ -429,7 +429,7 @@ class EAGLEWorker(TpModelWorker):
def draft_model_runner(self):
return self.model_runner
def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult:
"""Run speculative decoding forward.
NOTE: Many states of batch is modified as you go through. It is not guaranteed that
......@@ -449,7 +449,7 @@ class EAGLEWorker(TpModelWorker):
self.forward_draft_extend(
batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
)
return ForwardBatchOutput(
return GenerationBatchResult(
logits_output=logits_output,
next_token_ids=next_token_ids,
num_accepted_tokens=0,
......@@ -472,7 +472,7 @@ class EAGLEWorker(TpModelWorker):
# decode is not finished
self.forward_draft_extend_after_decode(batch)
return ForwardBatchOutput(
return GenerationBatchResult(
logits_output=logits_output,
next_token_ids=verify_output.verified_id,
num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
......@@ -513,12 +513,10 @@ class EAGLEWorker(TpModelWorker):
# We need the full hidden states to prefill the KV cache of the draft model.
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
forward_batch_output = self.target_worker.forward_batch_generation(
model_worker_batch
)
batch_result = self.target_worker.forward_batch_generation(model_worker_batch)
logits_output, next_token_ids = (
forward_batch_output.logits_output,
forward_batch_output.next_token_ids,
batch_result.logits_output,
batch_result.next_token_ids,
)
return (
logits_output,
......@@ -822,12 +820,12 @@ class EAGLEWorker(TpModelWorker):
).cpu()
# Forward
forward_batch_output = self.target_worker.forward_batch_generation(
batch_result = self.target_worker.forward_batch_generation(
model_worker_batch, is_verify=True
)
logits_output, can_run_cuda_graph = (
forward_batch_output.logits_output,
forward_batch_output.can_run_cuda_graph,
batch_result.logits_output,
batch_result.can_run_cuda_graph,
)
vocab_mask = None
......
......@@ -6,8 +6,9 @@ import torch
from sgl_kernel.speculative import reconstruct_indices_from_tree_mask
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput, ForwardMode
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
......@@ -207,18 +208,18 @@ class NGRAMWorker:
batch_tokens.append(put_ids)
self.ngram_cache.batch_put(batch_tokens)
def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult:
self._prepare_for_speculative_decoding(batch)
model_worker_batch = batch.get_model_worker_batch()
num_accepted_tokens = 0
if model_worker_batch.forward_mode.is_target_verify():
forward_batch_output = self.target_worker.forward_batch_generation(
batch_result = self.target_worker.forward_batch_generation(
model_worker_batch, is_verify=True
)
logits_output, can_run_cuda_graph = (
forward_batch_output.logits_output,
forward_batch_output.can_run_cuda_graph,
batch_result.logits_output,
batch_result.can_run_cuda_graph,
)
verify_input = model_worker_batch.spec_info
logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
......@@ -228,16 +229,16 @@ class NGRAMWorker:
batch.forward_mode = ForwardMode.DECODE
else:
forward_batch_output = self.target_worker.forward_batch_generation(
batch_result = self.target_worker.forward_batch_generation(
model_worker_batch
)
logits_output, next_token_ids, can_run_cuda_graph = (
forward_batch_output.logits_output,
forward_batch_output.next_token_ids,
forward_batch_output.can_run_cuda_graph,
batch_result.logits_output,
batch_result.next_token_ids,
batch_result.can_run_cuda_graph,
)
return ForwardBatchOutput(
return GenerationBatchResult(
logits_output=logits_output,
next_token_ids=next_token_ids,
num_accepted_tokens=num_accepted_tokens,
......
......@@ -1160,7 +1160,7 @@ def run_bench_offline_throughput(model, other_args):
*[str(x) for x in other_args],
]
print(f"{command=}")
print(f"command={' '.join(command)}")
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
try:
......