"vscode:/vscode.git/clone" did not exist on "738b75f41e5d3229e5ccda52d76e1297d7b0520d"
Unverified commit 458611de, authored by Liangsheng Yin, committed by GitHub

Unify forward output datastructure (#11124)

parent 3511b370
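At a glance, this commit replaces the per-path tuple returns (and the `bid` bookkeeping the overlap schedule relied on) with a single `ForwardBatchOutput` dataclass that every worker's `forward_batch_generation` returns. The following is a minimal, self-contained sketch of the new calling convention; the mock classes only mirror the shapes introduced in the diff below and are not sglang imports.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class ForwardBatchOutput:
    # Mirrors the dataclass added to forward_batch_info.py in this commit.
    logits_output: Optional[Any] = None
    next_token_ids: Optional[List[int]] = None
    num_accepted_tokens: Optional[int] = None
    pp_proxy_tensors: Optional[Any] = None
    can_run_cuda_graph: bool = False


class MockWorker:
    """Stand-in for TpModelWorker / EAGLEWorker / NGRAMWorker after the change."""

    def forward_batch_generation(self, batch: Any) -> ForwardBatchOutput:
        # A real worker runs the model here; we just return dummy values.
        return ForwardBatchOutput(next_token_ids=[1, 2, 3], can_run_cuda_graph=False)


# Before: logits_output, next_token_ids, can_run_cuda_graph = worker.forward_batch_generation(b)
# After: one object with named fields, regardless of which worker produced it.
out = MockWorker().forward_batch_generation(batch=None)
print(out.next_token_ids, out.can_run_cuda_graph)
```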
@@ -22,6 +22,7 @@ from typing import List, Optional, Set, Union
 import torch
 from transformers import PretrainedConfig
+from sglang.srt.environ import envs
 from sglang.srt.hf_transformers_utils import (
     get_config,
     get_context_length,
@@ -31,7 +32,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var, is_hip, retry
+from sglang.srt.utils import is_hip, retry
 from sglang.utils import is_in_ci

 logger = logging.getLogger(__name__)
@@ -237,7 +238,7 @@ class ModelConfig:
                 f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config."
             )
             if (
-                get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN")
+                envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.get()
                 or is_in_ci()  # FIXME: fix this special case
             ):
                 logger.warning(msg)
...
@@ -689,7 +689,6 @@ class SchedulerDisaggregationPrefillMixin:
         self.running_mbs = [
             ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
         ]
-        bids = [None] * self.pp_size
         pp_outputs: Optional[PPProxyTensors] = None
         # Either success or failed
@@ -761,10 +760,7 @@ class SchedulerDisaggregationPrefillMixin:
                 # send the outputs to the next step
                 if self.pp_group.is_last_rank:
                     if self.cur_batch:
-                        next_token_ids, bids[mb_id] = (
-                            result.next_token_ids,
-                            result.bid,
-                        )
+                        next_token_ids = result.next_token_ids
                         pp_outputs = PPProxyTensors(
                             {
                                 "next_token_ids": next_token_ids,
@@ -801,7 +797,6 @@ class SchedulerDisaggregationPrefillMixin:
                         next_token_ids=next_pp_outputs["next_token_ids"],
                         extend_input_len_per_req=None,
                         extend_logprob_start_len_per_req=None,
-                        bid=bids[next_mb_id],
                         can_run_cuda_graph=result.can_run_cuda_graph,
                     )
                     self.process_batch_result_disagg_prefill(
@@ -818,8 +813,6 @@ class SchedulerDisaggregationPrefillMixin:
                 # carry the outputs to the next stage
                 if not self.pp_group.is_last_rank:
-                    if self.cur_batch:
-                        bids[mb_id] = result.bid
                     if pp_outputs:
                         # send the outputs from the last round to let the next stage worker run post processing
                         self.pp_group.send_tensor_dict(
@@ -838,8 +831,10 @@ class SchedulerDisaggregationPrefillMixin:
                     # send out proxy tensors to the next stage
                     if self.cur_batch:
+                        # FIXME(lsyin): remove this assert
+                        assert result.pp_hidden_states_proxy_tensors.tensors is not None
                         self.pp_group.send_tensor_dict(
-                            result.pp_hidden_states_proxy_tensors,
+                            result.pp_hidden_states_proxy_tensors.tensors,
                             all_gather_group=self.attn_tp_group,
                         )
...
@@ -860,10 +860,6 @@ class Req:
     )
-# Batch id
-bid = 0
 @dataclasses.dataclass
 class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     """Store all information of a batch on the scheduler."""
@@ -1829,10 +1825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
         )
-        global bid
-        bid += 1
         return ModelWorkerBatch(
-            bid=bid,
             forward_mode=self.forward_mode,
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
@@ -1952,8 +1945,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 @dataclasses.dataclass
 class ModelWorkerBatch:
-    # The batch id
-    bid: int
     # The forward mode
     forward_mode: ForwardMode
     # The input ids
...
@@ -150,7 +150,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatchOutput,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -175,7 +179,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_int_env_var,
     get_zmq_socket,
-    is_cpu,
     kill_itself_when_parent_died,
     numa_bind_to_node,
     point_to_point_pyobj,
@@ -194,24 +197,59 @@ logger = logging.getLogger(__name__)
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
-_is_cpu = is_cpu()

 @dataclass
 class GenerationBatchResult:
     logits_output: Optional[LogitsProcessorOutput]
-    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
+    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
     next_token_ids: Optional[List[int]]
+    can_run_cuda_graph: bool
+
+    # For output processing
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
-    bid: int
-    can_run_cuda_graph: bool
+
+    @classmethod
+    def from_forward_batch_output(
+        cls,
+        forward_batch_output: ForwardBatchOutput,
+        extend_input_len_per_req: List[int],
+        extend_logprob_start_len_per_req: List[int],
+    ):
+        # TODO(lsyin): remove this workaround logic and try to unify output classes
+        return cls(
+            logits_output=forward_batch_output.logits_output,
+            pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
+            next_token_ids=forward_batch_output.next_token_ids,
+            extend_input_len_per_req=extend_input_len_per_req,
+            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
+            can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
+        )
+
+    @classmethod
+    def from_pp_proxy(
+        cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
+    ):
+        # TODO(lsyin): also simplify this logic
+        # Current PP implementation in scheduler is not compatible with ForwardBatchOutput
+        # Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
+        proxy_dict = next_pp_outputs.tensors
+        return cls(
+            logits_output=logits_output,
+            pp_hidden_states_proxy_tensors=None,
+            next_token_ids=next_pp_outputs["next_token_ids"],
+            extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
+            extend_logprob_start_len_per_req=proxy_dict.get(
+                "extend_logprob_start_len_per_req", None
+            ),
+            can_run_cuda_graph=can_run_cuda_graph,
+        )

 @dataclass
 class EmbeddingBatchResult:
     embeddings: torch.Tensor
-    bid: int

 class Scheduler(
@@ -403,6 +441,12 @@ class Scheduler(
         else:
             self.draft_worker = None
+        # Dispatch the model worker
+        if self.spec_algorithm.is_none():
+            self.model_worker = self.tp_worker
+        else:
+            self.model_worker = self.draft_worker
         # Get token and memory info from the model worker
         (
             self.max_total_num_tokens,
@@ -959,7 +1003,6 @@ class Scheduler(
         self.running_mbs = [
             ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
         ]
-        bids = [None] * self.pp_size
         pp_outputs: Optional[PPProxyTensors] = None
         while True:
             server_is_idle = True
@@ -980,10 +1023,7 @@ class Scheduler(
                 # (last rank) send the outputs to the next step
                 if self.pp_group.is_last_rank:
                     if self.cur_batch:
-                        next_token_ids, bids[mb_id] = (
-                            result.next_token_ids,
-                            result.bid,
-                        )
+                        next_token_ids = result.next_token_ids
                         if self.cur_batch.return_logprob:
                             pp_outputs = PPProxyTensors(
                                 {
@@ -1031,17 +1071,10 @@ class Scheduler(
                         logits_output = LogitsProcessorOutput(**logits_output_args)
                     else:
                         logits_output = None
-                    output_result = GenerationBatchResult(
+                    output_result = GenerationBatchResult.from_pp_proxy(
                         logits_output=logits_output,
-                        pp_hidden_states_proxy_tensors=None,
-                        next_token_ids=next_pp_outputs["next_token_ids"],
-                        extend_input_len_per_req=next_pp_outputs.tensors.get(
-                            "extend_input_len_per_req", None
-                        ),
-                        extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
-                            "extend_logprob_start_len_per_req", None
-                        ),
-                        bid=bids[next_mb_id],
+                        next_pp_outputs=next_pp_outputs,
                         can_run_cuda_graph=result.can_run_cuda_graph,
                     )
                     self.process_batch_result(mbs[next_mb_id], output_result)
@@ -1049,8 +1082,6 @@ class Scheduler(
                 # (not last rank)
                 if not self.pp_group.is_last_rank:
-                    if self.cur_batch:
-                        bids[mb_id] = result.bid
                     # carry the outputs to the next stage
                     # send the outputs from the last round to let the next stage worker run post processing
                     if pp_outputs:
@@ -1072,8 +1103,10 @@ class Scheduler(
                     # send out proxy tensors to the next stage
                     if self.cur_batch:
+                        # FIXME(lsyin): remove this assert
+                        assert result.pp_hidden_states_proxy_tensors.tensors is not None
                         self.pp_group.send_tensor_dict(
-                            result.pp_hidden_states_proxy_tensors,
+                            result.pp_hidden_states_proxy_tensors.tensors,
                             all_gather_group=self.attn_tp_group,
                         )
@@ -2016,33 +2049,25 @@ class Scheduler(
         # Run forward
         if self.is_generation:
+            batch_or_worker_batch = batch
+
             if self.spec_algorithm.is_none():
-                model_worker_batch = batch.get_model_worker_batch()
-
-                if self.pp_group.is_last_rank:
-                    logits_output, next_token_ids, can_run_cuda_graph = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                else:
-                    pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                bid = model_worker_batch.bid
-            else:
-                (
-                    logits_output,
-                    next_token_ids,
-                    bid,
-                    num_accepted_tokens,
-                    can_run_cuda_graph,
-                ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                bs = batch.batch_size()
-                self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
-                self.spec_num_total_forward_ct += bs
-                self.num_generated_tokens += num_accepted_tokens
+                # FIXME(lsyin): remove this if and finally unify the abstraction
+                batch_or_worker_batch = batch.get_model_worker_batch()

-            if self.pp_group.is_last_rank:
-                batch.output_ids = next_token_ids
+            forward_batch_output = self.model_worker.forward_batch_generation(
+                batch_or_worker_batch
+            )
+
+            if not self.spec_algorithm.is_none():
+                # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
+                self.udpate_spec_metrics(
+                    batch.batch_size(), forward_batch_output.num_accepted_tokens
+                )
+
+            # update batch's output ids
+            batch.output_ids = forward_batch_output.next_token_ids

             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
@@ -2051,6 +2076,7 @@ class Scheduler(
                 extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
             else:
                 extend_input_len_per_req = None
+
             if batch.return_logprob:
                 extend_logprob_start_len_per_req = [
                     req.extend_logprob_start_len for req in batch.reqs
@@ -2058,25 +2084,15 @@ class Scheduler(
             else:
                 extend_logprob_start_len_per_req = None
-            ret = GenerationBatchResult(
-                logits_output=logits_output if self.pp_group.is_last_rank else None,
-                pp_hidden_states_proxy_tensors=(
-                    pp_hidden_states_proxy_tensors
-                    if not self.pp_group.is_last_rank
-                    else None
-                ),
-                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
+            return GenerationBatchResult.from_forward_batch_output(
+                forward_batch_output=forward_batch_output,
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
-                bid=bid,
-                can_run_cuda_graph=can_run_cuda_graph,
             )
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret = EmbeddingBatchResult(
-                embeddings=embeddings, bid=model_worker_batch.bid
-            )
+            ret = EmbeddingBatchResult(embeddings=embeddings)
         return ret
     def process_batch_result(
...
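The scheduler-side effect of the hunks above is that `run_batch` no longer branches on the speculative algorithm to decide how to unpack the worker's return value: the worker is chosen once at init (`self.model_worker`) and the result is always wrapped via `GenerationBatchResult.from_forward_batch_output`. Below is a rough, self-contained sketch of that dispatch pattern with mock names; it is not the actual `Scheduler` class.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class MockOutput:
    next_token_ids: Any = None
    num_accepted_tokens: Optional[int] = None
    can_run_cuda_graph: bool = False


class TinyScheduler:
    def __init__(self, tp_worker: Any, draft_worker: Any, spec_is_none: bool) -> None:
        # Mirrors the new "Dispatch the model worker" block: pick the worker once.
        self.model_worker = tp_worker if spec_is_none else draft_worker
        self.spec_is_none = spec_is_none

    def run_batch(self, batch_or_worker_batch: Any) -> MockOutput:
        out = self.model_worker.forward_batch_generation(batch_or_worker_batch)
        if not self.spec_is_none:
            # Spec metrics now come from a named field instead of a tuple slot.
            assert out.num_accepted_tokens is not None
        return out


class _Worker:
    def forward_batch_generation(self, batch: Any) -> MockOutput:
        return MockOutput(next_token_ids=[0], can_run_cuda_graph=True)


sched = TinyScheduler(tp_worker=_Worker(), draft_worker=_Worker(), spec_is_none=True)
print(sched.run_batch(batch_or_worker_batch=None).next_token_ids)
```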
@@ -80,6 +80,11 @@ class SchedulerMetricsMixin:
             kv_events_config, self.attn_dp_rank
         )
+    def udpate_spec_metrics(self, bs: int, num_accepted_tokens: int):
+        self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
+        self.spec_num_total_forward_ct += bs
+        self.num_generated_tokens += num_accepted_tokens
+
     def log_prefill_stats(
         self: Scheduler,
         adder: PrefillAdder,
...
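The three counters previously updated inline in `run_batch` now live in this mixin method. A quick, self-contained sketch of the arithmetic follows (mock counters only; the real method is bound to the `Scheduler` through this mixin, and the `+ bs` term appears to account for the one token each request in the batch produces per forward on top of the accepted draft tokens).

```python
class SpecMetricsSketch:
    def __init__(self) -> None:
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.num_generated_tokens = 0

    def udpate_spec_metrics(self, bs: int, num_accepted_tokens: int) -> None:
        # Same arithmetic as the method added above (name kept as in the diff).
        self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
        self.spec_num_total_forward_ct += bs
        self.num_generated_tokens += num_accepted_tokens


m = SpecMetricsSketch()
m.udpate_spec_metrics(bs=4, num_accepted_tokens=6)
print(m.spec_num_total_accepted_tokens)  # 10 = 6 accepted draft tokens + 4 (one per request)
```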
@@ -173,8 +173,7 @@ class SchedulerOutputProcessorMixin:
             self.set_next_batch_sampling_info_done(batch)
         else:  # embedding or reward model
-            embeddings, bid = result.embeddings, result.bid
-            embeddings = embeddings.tolist()
+            embeddings = result.embeddings.tolist()
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
...
@@ -43,7 +43,11 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardBatchOutput,
+    PPProxyTensors,
+)
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs
@@ -234,9 +238,7 @@ class TpModelWorker:
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
-    ) -> Tuple[
-        Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
-    ]:
+    ) -> ForwardBatchOutput:
         # update the consumer index of hicache to the running batch
         self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
@@ -271,13 +273,20 @@ class TpModelWorker:
             else:
                 next_token_ids = self.model_runner.sample(logits_output, forward_batch)
-            return logits_output, next_token_ids, can_run_cuda_graph
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                can_run_cuda_graph=can_run_cuda_graph,
+            )
         else:
             pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return pp_proxy_tensors.tensors, None, can_run_cuda_graph
+            return ForwardBatchOutput(
+                pp_proxy_tensors=pp_proxy_tensors,
+                can_run_cuda_graph=can_run_cuda_graph,
+            )
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
...
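Callers of `TpModelWorker.forward_batch_generation` now receive a `ForwardBatchOutput` whose populated fields depend on the pipeline rank: the last rank fills `logits_output`/`next_token_ids`, earlier ranks fill `pp_proxy_tensors`. A small illustrative sketch of handling both shapes uniformly (mock dataclass, not an sglang import):

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Output:
    logits_output: Optional[Any] = None
    next_token_ids: Optional[Any] = None
    pp_proxy_tensors: Optional[Any] = None
    can_run_cuda_graph: bool = False


def handle(out: Output) -> str:
    if out.pp_proxy_tensors is not None:
        # Non-last pipeline rank: pass hidden-state proxies to the next stage.
        return "send proxy tensors downstream"
    # Last rank: logits were computed and tokens sampled (unless sampling was skipped).
    return "process sampled tokens"


print(handle(Output(next_token_ids=[5])))         # process sampled tokens
print(handle(Output(pp_proxy_tensors={"h": 0})))  # send proxy tensors downstream
```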
@@ -39,6 +39,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.overlap_utils import FutureMap
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import DynamicGradMode
 from sglang.utils import get_exception_traceback
@@ -160,13 +161,17 @@ class TpModelWorkerClient:
             self.future_map.resolve_future(model_worker_batch)
             # Run forward
-            logits_output, next_token_ids, can_run_cuda_graph = (
-                self.worker.forward_batch_generation(
-                    model_worker_batch,
-                    model_worker_batch.launch_done,
-                    # Skip sampling for prefill-only requests
-                    skip_sample=model_worker_batch.is_prefill_only,
-                )
+            forward_batch_output = self.worker.forward_batch_generation(
+                model_worker_batch,
+                model_worker_batch.launch_done,
+                # Skip sampling for prefill-only requests
+                skip_sample=model_worker_batch.is_prefill_only,
+            )
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                forward_batch_output.logits_output,
+                forward_batch_output.next_token_ids,
+                forward_batch_output.can_run_cuda_graph,
             )
             # Update the future token ids map
@@ -227,7 +232,7 @@ class TpModelWorkerClient:
     def forward_batch_generation(
         self, model_worker_batch: ModelWorkerBatch
-    ) -> Tuple[None, torch.Tensor, bool]:
+    ) -> ForwardBatchOutput:
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
         sampling_info = model_worker_batch.sampling_info
         sampling_info.update_penalties()
@@ -250,7 +255,10 @@ class TpModelWorkerClient:
         future_next_token_ids = self.future_map.update_next_future(
             cur_future_map_ct, bs
         )
-        return None, future_next_token_ids, False
+        return ForwardBatchOutput(
+            next_token_ids=future_next_token_ids,
+            can_run_cuda_graph=False,
+        )
     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         success, message = self.worker.update_weights_from_disk(recv_req)
...
@@ -2,11 +2,10 @@ from __future__ import annotations
 import logging
 import multiprocessing as mp
-from http import HTTPStatus
 from typing import TYPE_CHECKING, Dict, List, Optional
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
+from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
 if TYPE_CHECKING:
...
@@ -900,6 +900,17 @@ class ForwardBatch:
         return self.tbo_split_seq_index is not None
+@dataclass
+class ForwardBatchOutput:
+    # FIXME(lsyin): unify the forward batch output between different spec and parallelism
+    # need to be more organized
+    logits_output: Optional[torch.Tensor] = None
+    next_token_ids: Optional[torch.Tensor] = None
+    num_accepted_tokens: Optional[int] = None
+    pp_proxy_tensors: Optional[PPProxyTensors] = None
+    can_run_cuda_graph: bool = False
+
 def enable_num_token_non_padded(server_args):
     return get_moe_expert_parallel_world_size() > 1
...
@@ -14,7 +14,6 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
-from sglang.srt.managers.mm_utils import embed_mm_inputs
 from sglang.srt.managers.schedule_batch import (
     ScheduleBatch,
     get_last_loc,
@@ -24,6 +23,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
+    ForwardBatchOutput,
     ForwardMode,
 )
 from sglang.srt.server_args import ServerArgs
@@ -422,9 +422,7 @@ class EAGLEWorker(TpModelWorker):
     def draft_model_runner(self):
         return self.model_runner
-    def forward_batch_speculative_generation(
-        self, batch: ScheduleBatch
-    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
+    def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
         """Run speculative decoding forward.
         NOTE: Many states of batch is modified as you go through. It is not guaranteed that
@@ -437,14 +435,19 @@ class EAGLEWorker(TpModelWorker):
            the batch id (used for overlap schedule), and number of accepted tokens.
         """
         if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
-            logits_output, next_token_ids, bid, seq_lens_cpu = (
-                self.forward_target_extend(batch)
+            logits_output, next_token_ids, seq_lens_cpu = self.forward_target_extend(
+                batch
             )
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 self.forward_draft_extend(
                     batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
                 )
-            return logits_output, next_token_ids, bid, 0, False
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                num_accepted_tokens=0,
+                can_run_cuda_graph=False,
+            )
         else:
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
@@ -462,12 +465,11 @@ class EAGLEWorker(TpModelWorker):
                 # decode is not finished
                 self.forward_draft_extend_after_decode(batch)
-            return (
-                logits_output,
-                verify_output.verified_id,
-                model_worker_batch.bid,
-                sum(verify_output.accept_length_per_req_cpu),
-                can_run_cuda_graph,
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=verify_output.verified_id,
+                num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
+                can_run_cuda_graph=can_run_cuda_graph,
             )
     def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
@@ -499,19 +501,21 @@ class EAGLEWorker(TpModelWorker):
         Returns:
             logits_output: The output of logits. It will contain the full hidden states.
             next_token_ids: Next token ids generated.
-            bid: The model batch ID. Used for overlap schedule.
         """
         # Forward with the target model and get hidden states.
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
+        forward_batch_output = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
+        logits_output, next_token_ids = (
+            forward_batch_output.logits_output,
+            forward_batch_output.next_token_ids,
+        )
         return (
             logits_output,
             next_token_ids,
-            model_worker_batch.bid,
             model_worker_batch.seq_lens_cpu,
         )
@@ -811,10 +815,12 @@ class EAGLEWorker(TpModelWorker):
         ).cpu()
         # Forward
-        logits_output, _, can_run_cuda_graph = (
-            self.target_worker.forward_batch_generation(
-                model_worker_batch, skip_sample=True
-            )
+        forward_batch_output = self.target_worker.forward_batch_generation(
+            model_worker_batch, skip_sample=True
+        )
+        logits_output, can_run_cuda_graph = (
+            forward_batch_output.logits_output,
+            forward_batch_output.can_run_cuda_graph,
         )
         vocab_mask = None
...
@@ -7,7 +7,7 @@ from sgl_kernel.speculative import reconstruct_indices_from_tree_mask
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput, ForwardMode
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
 from sglang.srt.speculative.ngram_utils import NgramVerifyInput
@@ -207,17 +207,18 @@ class NGRAMWorker:
             batch_tokens.append(put_ids)
         self.ngram_cache.batch_put(batch_tokens)
-    def forward_batch_speculative_generation(self, batch: ScheduleBatch):
+    def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
         self._prepare_for_speculative_decoding(batch)
         model_worker_batch = batch.get_model_worker_batch()
-        bid = model_worker_batch.bid
         num_accepted_tokens = 0
         if model_worker_batch.forward_mode.is_target_verify():
-            logits_output, _, can_run_cuda_graph = (
-                self.target_worker.forward_batch_generation(
-                    model_worker_batch, skip_sample=True
-                )
+            forward_batch_output = self.target_worker.forward_batch_generation(
+                model_worker_batch, skip_sample=True
+            )
+            logits_output, can_run_cuda_graph = (
+                forward_batch_output.logits_output,
+                forward_batch_output.can_run_cuda_graph,
             )
             verify_input = model_worker_batch.spec_info
             logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
@@ -227,14 +228,18 @@ class NGRAMWorker:
             batch.forward_mode = ForwardMode.DECODE
         else:
+            forward_batch_output = self.target_worker.forward_batch_generation(
+                model_worker_batch
+            )
             logits_output, next_token_ids, can_run_cuda_graph = (
-                self.target_worker.forward_batch_generation(model_worker_batch)
+                forward_batch_output.logits_output,
+                forward_batch_output.next_token_ids,
+                forward_batch_output.can_run_cuda_graph,
             )
-        return (
-            logits_output,
-            next_token_ids,
-            bid,
-            num_accepted_tokens,
-            can_run_cuda_graph,
+        return ForwardBatchOutput(
+            logits_output=logits_output,
+            next_token_ids=next_token_ids,
+            num_accepted_tokens=num_accepted_tokens,
+            can_run_cuda_graph=can_run_cuda_graph,
         )