change / sglang · Commits · 458611de

Unify forward output datastructure (#11124)

Unverified commit 458611de, authored Oct 03, 2025 by Liangsheng Yin; committed by GitHub on Oct 03, 2025.
Parent: 3511b370

Showing 12 changed files with 181 additions and 136 deletions (+181, -136):
  python/sglang/srt/configs/model_config.py                         +3   -2
  python/sglang/srt/disaggregation/prefill.py                       +4   -9
  python/sglang/srt/managers/schedule_batch.py                      +0   -9
  python/sglang/srt/managers/scheduler.py                          +80  -64
  python/sglang/srt/managers/scheduler_metrics_mixin.py             +5   -0
  python/sglang/srt/managers/scheduler_output_processor_mixin.py    +1   -2
  python/sglang/srt/managers/tp_worker.py                          +15   -6
  python/sglang/srt/managers/tp_worker_overlap_thread.py           +16   -8
  python/sglang/srt/managers/utils.py                               +1   -2
  python/sglang/srt/model_executor/forward_batch_info.py           +11   -0
  python/sglang/srt/speculative/eagle_worker.py                    +26  -20
  python/sglang/srt/speculative/ngram_worker.py                    +19  -14
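The substance of the change: every worker forward call used to return a positionally-shaped tuple, and the tuple shape differed between the TP worker, the PP branch, and the speculative workers. After this commit they all return a single ForwardBatchOutput dataclass. A condensed before/after sketch, using only names that appear in the diffs below:

    # Before: tuple shapes differed per worker.
    logits_output, next_token_ids, can_run_cuda_graph = (
        self.tp_worker.forward_batch_generation(model_worker_batch)
    )
    (
        logits_output,
        next_token_ids,
        bid,
        num_accepted_tokens,
        can_run_cuda_graph,
    ) = self.draft_worker.forward_batch_speculative_generation(batch)

    # After: one structure, with unused fields left as None.
    forward_batch_output = self.model_worker.forward_batch_generation(batch_or_worker_batch)
    batch.output_ids = forward_batch_output.next_token_ids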
python/sglang/srt/configs/model_config.py

@@ -22,6 +22,7 @@ from typing import List, Optional, Set, Union
 import torch
 from transformers import PretrainedConfig
 
+from sglang.srt.environ import envs
 from sglang.srt.hf_transformers_utils import (
     get_config,
     get_context_length,
@@ -31,7 +32,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var, is_hip, retry
+from sglang.srt.utils import is_hip, retry
 from sglang.utils import is_in_ci
 
 logger = logging.getLogger(__name__)
@@ -237,7 +238,7 @@ class ModelConfig:
                 f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config."
             )
             if (
-                get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN")
+                envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.get()
                 or is_in_ci()  # FIXME: fix this special case
             ):
                 logger.warning(msg)
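Aside from the output unification, this hunk migrates one environment flag from get_bool_env_var to sglang's typed envs registry. A minimal self-contained sketch of that accessor pattern (an illustration only; EnvBool and _Envs here are hypothetical stand-ins, not sglang's actual implementation):

    import os

    class EnvBool:
        # Hypothetical typed wrapper around one boolean environment variable.
        def __init__(self, name: str, default: bool = False):
            self.name = name
            self.default = default

        def get(self) -> bool:
            raw = os.environ.get(self.name)
            if raw is None:
                return self.default
            return raw.strip().lower() in ("1", "true", "yes", "on")

    class _Envs:
        SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN = EnvBool(
            "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"
        )

    envs = _Envs()

    # Usage mirrors the new call site in model_config.py:
    if envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.get():
        print("longer context length may be overwritten")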
python/sglang/srt/disaggregation/prefill.py

@@ -689,7 +689,6 @@ class SchedulerDisaggregationPrefillMixin:
         self.running_mbs = [
             ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
         ]
-        bids = [None] * self.pp_size
         pp_outputs: Optional[PPProxyTensors] = None
 
         # Either success or failed
@@ -761,10 +760,7 @@ class SchedulerDisaggregationPrefillMixin:
             # send the outputs to the next step
             if self.pp_group.is_last_rank:
                 if self.cur_batch:
-                    next_token_ids, bids[mb_id] = (
-                        result.next_token_ids,
-                        result.bid,
-                    )
+                    next_token_ids = result.next_token_ids
                     pp_outputs = PPProxyTensors(
                         {
                             "next_token_ids": next_token_ids,
@@ -801,7 +797,6 @@ class SchedulerDisaggregationPrefillMixin:
                     next_token_ids=next_pp_outputs["next_token_ids"],
                     extend_input_len_per_req=None,
                     extend_logprob_start_len_per_req=None,
-                    bid=bids[next_mb_id],
                     can_run_cuda_graph=result.can_run_cuda_graph,
                 )
                 self.process_batch_result_disagg_prefill(
@@ -818,8 +813,6 @@ class SchedulerDisaggregationPrefillMixin:
 
             # carry the outputs to the next stage
             if not self.pp_group.is_last_rank:
-                if self.cur_batch:
-                    bids[mb_id] = result.bid
                 if pp_outputs:
                     # send the outputs from the last round to let the next stage worker run post processing
                     self.pp_group.send_tensor_dict(
@@ -838,8 +831,10 @@ class SchedulerDisaggregationPrefillMixin:
                 # send out proxy tensors to the next stage
                 if self.cur_batch:
+                    # FIXME(lsyin): remove this assert
+                    assert result.pp_hidden_states_proxy_tensors.tensors is not None
                     self.pp_group.send_tensor_dict(
-                        result.pp_hidden_states_proxy_tensors,
+                        result.pp_hidden_states_proxy_tensors.tensors,
                         all_gather_group=self.attn_tp_group,
                     )
python/sglang/srt/managers/schedule_batch.py

@@ -860,10 +860,6 @@ class Req:
         )
 
 
-# Batch id
-bid = 0
-
-
 @dataclasses.dataclass
 class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     """Store all information of a batch on the scheduler."""
@@ -1829,10 +1825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
         )
 
-        global bid
-        bid += 1
         return ModelWorkerBatch(
-            bid=bid,
             forward_mode=self.forward_mode,
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
@@ -1952,8 +1945,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
 @dataclasses.dataclass
 class ModelWorkerBatch:
-    # The batch id
-    bid: int
     # The forward mode
     forward_mode: ForwardMode
     # The input ids
python/sglang/srt/managers/scheduler.py

@@ -150,7 +150,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatchOutput,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -175,7 +179,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_int_env_var,
-    get_zmq_socket,
     is_cpu,
     kill_itself_when_parent_died,
     numa_bind_to_node,
     point_to_point_pyobj,
@@ -194,24 +197,59 @@ logger = logging.getLogger(__name__)
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
 
 _is_cpu = is_cpu()
 
 
 @dataclass
 class GenerationBatchResult:
     logits_output: Optional[LogitsProcessorOutput]
-    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
+    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
     next_token_ids: Optional[List[int]]
-    can_run_cuda_graph: bool
 
     # For output processing
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
-    bid: int
+    can_run_cuda_graph: bool
+
+    @classmethod
+    def from_forward_batch_output(
+        cls,
+        forward_batch_output: ForwardBatchOutput,
+        extend_input_len_per_req: List[int],
+        extend_logprob_start_len_per_req: List[int],
+    ):
+        # TODO(lsyin): remove this workaround logic and try to unify output classes
+        return cls(
+            logits_output=forward_batch_output.logits_output,
+            pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
+            next_token_ids=forward_batch_output.next_token_ids,
+            extend_input_len_per_req=extend_input_len_per_req,
+            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
+            can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
+        )
+
+    @classmethod
+    def from_pp_proxy(
+        cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
+    ):
+        # TODO(lsyin): also simplify this logic
+        # Current PP implementation in scheduler is not compatible with ForwardBatchOutput
+        # Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
+        proxy_dict = next_pp_outputs.tensors
+        return cls(
+            logits_output=logits_output,
+            pp_hidden_states_proxy_tensors=None,
+            next_token_ids=next_pp_outputs["next_token_ids"],
+            extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
+            extend_logprob_start_len_per_req=proxy_dict.get(
+                "extend_logprob_start_len_per_req", None
+            ),
+            can_run_cuda_graph=can_run_cuda_graph,
+        )
 
 
 @dataclass
 class EmbeddingBatchResult:
     embeddings: torch.Tensor
-    bid: int
 
 
 class Scheduler(
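The two classmethods above give the scheduler one construction path per parallelism mode. A usage sketch mirroring the call sites later in this diff:

    # TP / last PP rank: wrap the worker's unified output.
    ret = GenerationBatchResult.from_forward_batch_output(
        forward_batch_output=forward_batch_output,
        extend_input_len_per_req=extend_input_len_per_req,
        extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
    )

    # PP: rebuild the result from proxy tensors received from the previous stage.
    output_result = GenerationBatchResult.from_pp_proxy(
        logits_output=logits_output,
        next_pp_outputs=next_pp_outputs,
        can_run_cuda_graph=result.can_run_cuda_graph,
    )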
@@ -403,6 +441,12 @@ class Scheduler(
         else:
             self.draft_worker = None
 
+        # Dispatch the model worker
+        if self.spec_algorithm.is_none():
+            self.model_worker = self.tp_worker
+        else:
+            self.model_worker = self.draft_worker
+
         # Get token and memory info from the model worker
         (
             self.max_total_num_tokens,
@@ -959,7 +1003,6 @@ class Scheduler(
         self.running_mbs = [
             ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
         ]
-        bids = [None] * self.pp_size
         pp_outputs: Optional[PPProxyTensors] = None
         while True:
             server_is_idle = True
@@ -980,10 +1023,7 @@
             # (last rank) send the outputs to the next step
             if self.pp_group.is_last_rank:
                 if self.cur_batch:
-                    next_token_ids, bids[mb_id] = (
-                        result.next_token_ids,
-                        result.bid,
-                    )
+                    next_token_ids = result.next_token_ids
                     if self.cur_batch.return_logprob:
                         pp_outputs = PPProxyTensors(
                             {
@@ -1031,17 +1071,10 @@
                     logits_output = LogitsProcessorOutput(**logits_output_args)
                 else:
                     logits_output = None
 
-                output_result = GenerationBatchResult(
+                output_result = GenerationBatchResult.from_pp_proxy(
                     logits_output=logits_output,
-                    pp_hidden_states_proxy_tensors=None,
-                    next_token_ids=next_pp_outputs["next_token_ids"],
-                    extend_input_len_per_req=next_pp_outputs.tensors.get(
-                        "extend_input_len_per_req", None
-                    ),
-                    extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
-                        "extend_logprob_start_len_per_req", None
-                    ),
-                    bid=bids[next_mb_id],
+                    next_pp_outputs=next_pp_outputs,
                     can_run_cuda_graph=result.can_run_cuda_graph,
                 )
                 self.process_batch_result(mbs[next_mb_id], output_result)
@@ -1049,8 +1082,6 @@
             # (not last rank)
             if not self.pp_group.is_last_rank:
-                if self.cur_batch:
-                    bids[mb_id] = result.bid
                 # carry the outputs to the next stage
                 # send the outputs from the last round to let the next stage worker run post processing
                 if pp_outputs:
@@ -1072,8 +1103,10 @@
                 # send out proxy tensors to the next stage
                 if self.cur_batch:
+                    # FIXME(lsyin): remove this assert
+                    assert result.pp_hidden_states_proxy_tensors.tensors is not None
                     self.pp_group.send_tensor_dict(
-                        result.pp_hidden_states_proxy_tensors,
+                        result.pp_hidden_states_proxy_tensors.tensors,
                         all_gather_group=self.attn_tp_group,
                     )
@@ -2016,33 +2049,25 @@
         # Run forward
         if self.is_generation:
+            batch_or_worker_batch = batch
+
             if self.spec_algorithm.is_none():
-                model_worker_batch = batch.get_model_worker_batch()
-                if self.pp_group.is_last_rank:
-                    logits_output, next_token_ids, can_run_cuda_graph = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                else:
-                    pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                bid = model_worker_batch.bid
-            else:
-                (
-                    logits_output,
-                    next_token_ids,
-                    bid,
-                    num_accepted_tokens,
-                    can_run_cuda_graph,
-                ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                bs = batch.batch_size()
-                self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
-                self.spec_num_total_forward_ct += bs
-                self.num_generated_tokens += num_accepted_tokens
+                # FIXME(lsyin): remove this if and finally unify the abstraction
+                batch_or_worker_batch = batch.get_model_worker_batch()
 
-            if self.pp_group.is_last_rank:
-                batch.output_ids = next_token_ids
+            forward_batch_output = self.model_worker.forward_batch_generation(
+                batch_or_worker_batch
+            )
+
+            if not self.spec_algorithm.is_none():
+                # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
+                self.udpate_spec_metrics(
+                    batch.batch_size(), forward_batch_output.num_accepted_tokens
+                )
+
+            # update batch's output ids
+            batch.output_ids = forward_batch_output.next_token_ids
 
             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
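Together with the model-worker dispatch added in __init__ above, the scheduler now programs against one duck-typed interface: TpModelWorker, TpModelWorkerClient, EAGLEWorker, and NGRAMWorker all expose forward_batch_generation(...) -> ForwardBatchOutput. A condensed sketch of the resulting call path, assembled from the hunks in this file:

    # __init__: pick the worker once.
    if self.spec_algorithm.is_none():
        self.model_worker = self.tp_worker
    else:
        self.model_worker = self.draft_worker

    # run_batch: a single call site, no per-algorithm tuple unpacking.
    forward_batch_output = self.model_worker.forward_batch_generation(batch_or_worker_batch)
    batch.output_ids = forward_batch_output.next_token_ids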
@@ -2051,6 +2076,7 @@
             extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
         else:
             extend_input_len_per_req = None
+
         if batch.return_logprob:
             extend_logprob_start_len_per_req = [
                 req.extend_logprob_start_len for req in batch.reqs
@@ -2058,25 +2084,15 @@
             else:
                 extend_logprob_start_len_per_req = None
 
-            ret = GenerationBatchResult(
-                logits_output=logits_output if self.pp_group.is_last_rank else None,
-                pp_hidden_states_proxy_tensors=(
-                    pp_hidden_states_proxy_tensors
-                    if not self.pp_group.is_last_rank
-                    else None
-                ),
-                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
+            return GenerationBatchResult.from_forward_batch_output(
+                forward_batch_output=forward_batch_output,
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
-                bid=bid,
-                can_run_cuda_graph=can_run_cuda_graph,
             )
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret = EmbeddingBatchResult(embeddings=embeddings, bid=model_worker_batch.bid)
+            ret = EmbeddingBatchResult(embeddings=embeddings)
         return ret
 
     def process_batch_result(
python/sglang/srt/managers/scheduler_metrics_mixin.py

@@ -80,6 +80,11 @@ class SchedulerMetricsMixin:
             kv_events_config, self.attn_dp_rank
         )
 
+    def udpate_spec_metrics(self, bs: int, num_accepted_tokens: int):
+        self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
+        self.spec_num_total_forward_ct += bs
+        self.num_generated_tokens += num_accepted_tokens
+
     def log_prefill_stats(
         self: Scheduler,
         adder: PrefillAdder,
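For reference, these counters support the usual speculative-decoding summary statistic; the derived quantity below is an assumption about how the counters are meant to be read, not code from this commit:

    # Each forward over a batch of size bs adds bs to the forward count and
    # (num_accepted_tokens + bs) to the accepted-token count, so the mean number
    # of tokens emitted per request per verify step (including the one
    # non-draft token) is:
    avg_accept_length = (
        self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
    )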
python/sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -173,8 +173,7 @@ class SchedulerOutputProcessorMixin:
             self.set_next_batch_sampling_info_done(batch)
 
         else:  # embedding or reward model
-            embeddings, bid = result.embeddings, result.bid
-            embeddings = embeddings.tolist()
+            embeddings = result.embeddings.tolist()
 
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
python/sglang/srt/managers/tp_worker.py

@@ -43,7 +43,11 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardBatchOutput,
+    PPProxyTensors,
+)
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs

@@ -234,9 +238,7 @@ class TpModelWorker:
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
-    ) -> Tuple[
-        Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
-    ]:
+    ) -> ForwardBatchOutput:
         # update the consumer index of hicache to the running batch
         self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)

@@ -271,13 +273,20 @@ class TpModelWorker:
             else:
                 next_token_ids = self.model_runner.sample(logits_output, forward_batch)
 
-            return logits_output, next_token_ids, can_run_cuda_graph
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                can_run_cuda_graph=can_run_cuda_graph,
+            )
         else:
             pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return pp_proxy_tensors.tensors, None, can_run_cuda_graph
+            return ForwardBatchOutput(
+                pp_proxy_tensors=pp_proxy_tensors,
+                can_run_cuda_graph=can_run_cuda_graph,
+            )
 
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
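Callers now branch on which fields of the returned ForwardBatchOutput are populated rather than on tuple positions. A hedged consumption sketch (send_to_next_stage is a hypothetical helper, not part of this diff):

    out = tp_worker.forward_batch_generation(model_worker_batch)
    if out.pp_proxy_tensors is not None:
        # Non-last pipeline rank: pass intermediate tensors downstream.
        send_to_next_stage(out.pp_proxy_tensors.tensors)  # hypothetical helper
    else:
        # Last rank: sampled token ids are available.
        next_token_ids = out.next_token_ids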
python/sglang/srt/managers/tp_worker_overlap_thread.py

@@ -39,6 +39,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.overlap_utils import FutureMap
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import DynamicGradMode
 from sglang.utils import get_exception_traceback

@@ -160,13 +161,17 @@ class TpModelWorkerClient:
             self.future_map.resolve_future(model_worker_batch)
 
             # Run forward
-            logits_output, next_token_ids, can_run_cuda_graph = (
-                self.worker.forward_batch_generation(
-                    model_worker_batch,
-                    model_worker_batch.launch_done,
-                    # Skip sampling for prefill-only requests
-                    skip_sample=model_worker_batch.is_prefill_only,
-                )
-            )
+            forward_batch_output = self.worker.forward_batch_generation(
+                model_worker_batch,
+                model_worker_batch.launch_done,
+                # Skip sampling for prefill-only requests
+                skip_sample=model_worker_batch.is_prefill_only,
+            )
+
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                forward_batch_output.logits_output,
+                forward_batch_output.next_token_ids,
+                forward_batch_output.can_run_cuda_graph,
+            )
 
             # Update the future token ids map

@@ -227,7 +232,7 @@ class TpModelWorkerClient:
     def forward_batch_generation(
         self, model_worker_batch: ModelWorkerBatch
-    ) -> Tuple[None, torch.Tensor, bool]:
+    ) -> ForwardBatchOutput:
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
         sampling_info = model_worker_batch.sampling_info
         sampling_info.update_penalties()

@@ -250,7 +255,10 @@ class TpModelWorkerClient:
         future_next_token_ids = self.future_map.update_next_future(
             cur_future_map_ct, bs
         )
-        return None, future_next_token_ids, False
+        return ForwardBatchOutput(
+            next_token_ids=future_next_token_ids,
+            can_run_cuda_graph=False,
+        )
 
     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         success, message = self.worker.update_weights_from_disk(recv_req)
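Note the overlap client's contract is unchanged in substance: forward_batch_generation returns immediately with placeholder ids that the background forward thread later resolves (resolve_future in the hunk above); only the container changed. A sketch of what the caller receives, inferred from the hunks here rather than stated in the commit:

    out = tp_worker_client.forward_batch_generation(model_worker_batch)
    # out.next_token_ids holds future placeholders, filled in asynchronously by
    # the forward thread; logits are not available at submission time.
    assert out.logits_output is None
    assert out.can_run_cuda_graph is False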
python/sglang/srt/managers/utils.py

@@ -2,11 +2,10 @@ from __future__ import annotations
 
 import logging
 import multiprocessing as mp
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Dict, List, Optional
 
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
+from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
 
 if TYPE_CHECKING:
python/sglang/srt/model_executor/forward_batch_info.py

@@ -900,6 +900,17 @@ class ForwardBatch:
         return self.tbo_split_seq_index is not None
 
 
+@dataclass
+class ForwardBatchOutput:
+    # FIXME(lsyin): unify the forward batch output between different spec and parallelism
+    # need to be more organized
+    logits_output: Optional[torch.Tensor] = None
+    next_token_ids: Optional[torch.Tensor] = None
+    num_accepted_tokens: Optional[int] = None
+    pp_proxy_tensors: Optional[PPProxyTensors] = None
+    can_run_cuda_graph: bool = False
+
+
 def enable_num_token_non_padded(server_args):
     return get_moe_expert_parallel_world_size() > 1
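For orientation, the producers touched by this commit populate ForwardBatchOutput as follows (compiled from the hunks on this page):

    Producer                            Fields set
    ----------------------------------  -------------------------------------------------------------------
    TpModelWorker, last PP rank         logits_output, next_token_ids, can_run_cuda_graph
    TpModelWorker, non-last PP rank     pp_proxy_tensors, can_run_cuda_graph
    TpModelWorkerClient (overlap)       next_token_ids (future placeholders), can_run_cuda_graph=False
    EAGLEWorker / NGRAMWorker           logits_output, next_token_ids, num_accepted_tokens, can_run_cuda_graph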
python/sglang/srt/speculative/eagle_worker.py

@@ -14,7 +14,6 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
-from sglang.srt.managers.mm_utils import embed_mm_inputs
 from sglang.srt.managers.schedule_batch import (
     ScheduleBatch,
     get_last_loc,

@@ -24,6 +23,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
+    ForwardBatchOutput,
     ForwardMode,
 )
 from sglang.srt.server_args import ServerArgs

@@ -422,9 +422,7 @@ class EAGLEWorker(TpModelWorker):
     def draft_model_runner(self):
         return self.model_runner
 
-    def forward_batch_speculative_generation(
-        self, batch: ScheduleBatch
-    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
+    def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
         """Run speculative decoding forward.
 
         NOTE: Many states of batch is modified as you go through. It is not guaranteed that

@@ -437,14 +435,19 @@ class EAGLEWorker(TpModelWorker):
         the batch id (used for overlap schedule), and number of accepted tokens.
         """
         if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
-            logits_output, next_token_ids, bid, seq_lens_cpu = (
-                self.forward_target_extend(batch)
-            )
+            logits_output, next_token_ids, seq_lens_cpu = self.forward_target_extend(
+                batch
+            )
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 self.forward_draft_extend(
                     batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
                 )
-            return logits_output, next_token_ids, bid, 0, False
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                num_accepted_tokens=0,
+                can_run_cuda_graph=False,
+            )
         else:
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)

@@ -462,12 +465,11 @@ class EAGLEWorker(TpModelWorker):
                 # decode is not finished
                 self.forward_draft_extend_after_decode(batch)
 
-            return (
-                logits_output,
-                verify_output.verified_id,
-                model_worker_batch.bid,
-                sum(verify_output.accept_length_per_req_cpu),
-                can_run_cuda_graph,
-            )
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=verify_output.verified_id,
+                num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
+                can_run_cuda_graph=can_run_cuda_graph,
+            )
 
     def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):

@@ -499,19 +501,21 @@ class EAGLEWorker(TpModelWorker):
         Returns:
             logits_output: The output of logits. It will contain the full hidden states.
             next_token_ids: Next token ids generated.
-            bid: The model batch ID. Used for overlap schedule.
         """
         # Forward with the target model and get hidden states.
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
-            model_worker_batch
-        )
+        forward_batch_output = self.target_worker.forward_batch_generation(
+            model_worker_batch
+        )
+        logits_output, next_token_ids = (
+            forward_batch_output.logits_output,
+            forward_batch_output.next_token_ids,
+        )
         return (
             logits_output,
             next_token_ids,
-            model_worker_batch.bid,
             model_worker_batch.seq_lens_cpu,
         )

@@ -811,10 +815,12 @@ class EAGLEWorker(TpModelWorker):
         ).cpu()
 
         # Forward
-        logits_output, _, can_run_cuda_graph = (
-            self.target_worker.forward_batch_generation(
-                model_worker_batch, skip_sample=True
-            )
-        )
+        forward_batch_output = self.target_worker.forward_batch_generation(
+            model_worker_batch, skip_sample=True
+        )
+        logits_output, can_run_cuda_graph = (
+            forward_batch_output.logits_output,
+            forward_batch_output.can_run_cuda_graph,
+        )
 
         vocab_mask = None
python/sglang/srt/speculative/ngram_worker.py

@@ -7,7 +7,7 @@ from sgl_kernel.speculative import reconstruct_indices_from_tree_mask
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput, ForwardMode
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
 from sglang.srt.speculative.ngram_utils import NgramVerifyInput

@@ -207,17 +207,18 @@ class NGRAMWorker:
             batch_tokens.append(put_ids)
         self.ngram_cache.batch_put(batch_tokens)
 
-    def forward_batch_speculative_generation(self, batch: ScheduleBatch):
+    def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
         self._prepare_for_speculative_decoding(batch)
         model_worker_batch = batch.get_model_worker_batch()
-        bid = model_worker_batch.bid
         num_accepted_tokens = 0
 
         if model_worker_batch.forward_mode.is_target_verify():
-            logits_output, _, can_run_cuda_graph = (
-                self.target_worker.forward_batch_generation(
-                    model_worker_batch, skip_sample=True
-                )
-            )
+            forward_batch_output = self.target_worker.forward_batch_generation(
+                model_worker_batch, skip_sample=True
+            )
+            logits_output, can_run_cuda_graph = (
+                forward_batch_output.logits_output,
+                forward_batch_output.can_run_cuda_graph,
+            )
             verify_input = model_worker_batch.spec_info
             logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(

@@ -227,14 +228,18 @@ class NGRAMWorker:
             batch.forward_mode = ForwardMode.DECODE
         else:
-            logits_output, next_token_ids, can_run_cuda_graph = (
-                self.target_worker.forward_batch_generation(model_worker_batch)
-            )
+            forward_batch_output = self.target_worker.forward_batch_generation(
+                model_worker_batch
+            )
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                forward_batch_output.logits_output,
+                forward_batch_output.next_token_ids,
+                forward_batch_output.can_run_cuda_graph,
+            )
 
-        return (
-            logits_output,
-            next_token_ids,
-            bid,
-            num_accepted_tokens,
-            can_run_cuda_graph,
-        )
+        return ForwardBatchOutput(
+            logits_output=logits_output,
+            next_token_ids=next_token_ids,
+            num_accepted_tokens=num_accepted_tokens,
+            can_run_cuda_graph=can_run_cuda_graph,
+        )