change / sglang · Commits

Commit 458611de (unverified)
Unify forward output datastructure (#11124)

Authored Oct 03, 2025 by Liangsheng Yin; committed by GitHub on Oct 03, 2025.
Parent: 3511b370
Showing 12 changed files with 181 additions and 136 deletions (+181 / -136).
python/sglang/srt/configs/model_config.py                        +3   -2
python/sglang/srt/disaggregation/prefill.py                      +4   -9
python/sglang/srt/managers/schedule_batch.py                     +0   -9
python/sglang/srt/managers/scheduler.py                          +80  -64
python/sglang/srt/managers/scheduler_metrics_mixin.py            +5   -0
python/sglang/srt/managers/scheduler_output_processor_mixin.py   +1   -2
python/sglang/srt/managers/tp_worker.py                          +15  -6
python/sglang/srt/managers/tp_worker_overlap_thread.py           +16  -8
python/sglang/srt/managers/utils.py                              +1   -2
python/sglang/srt/model_executor/forward_batch_info.py           +11  -0
python/sglang/srt/speculative/eagle_worker.py                    +26  -20
python/sglang/srt/speculative/ngram_worker.py                    +19  -14
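Summary of the change: forward_batch_generation (TP worker, overlap worker) and forward_batch_speculative_generation (EAGLE and n-gram workers) previously returned positional tuples of different shapes and dragged a batch id (bid) through the scheduler. This commit replaces all of those returns with a single ForwardBatchOutput dataclass and drops bid entirely. A minimal sketch of the change in calling convention, using placeholder types instead of the real sglang classes (the helper function names below are illustrative only):

    from dataclasses import dataclass
    from typing import Any, List, Optional, Tuple


    # Old style (illustrative): each worker returned a differently shaped tuple.
    def tp_worker_forward_old(batch: Any) -> Tuple[Any, List[int], bool]:
        # logits_output, next_token_ids, can_run_cuda_graph
        return object(), [1, 2], True


    def spec_worker_forward_old(batch: Any) -> Tuple[Any, List[int], int, int, bool]:
        # logits_output, next_token_ids, bid, num_accepted_tokens, can_run_cuda_graph
        return object(), [1, 2], 0, 3, True


    # New style: every worker returns the same dataclass (field names from this commit).
    @dataclass
    class ForwardBatchOutput:
        logits_output: Optional[Any] = None
        next_token_ids: Optional[List[int]] = None
        num_accepted_tokens: Optional[int] = None
        pp_proxy_tensors: Optional[Any] = None
        can_run_cuda_graph: bool = False


    def any_worker_forward_new(batch: Any) -> ForwardBatchOutput:
        # One return shape for TP, overlap, and speculative workers alike.
        return ForwardBatchOutput(
            logits_output=object(), next_token_ids=[1, 2], can_run_cuda_graph=True
        )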
python/sglang/srt/configs/model_config.py (+3 -2)

@@ -22,6 +22,7 @@ from typing import List, Optional, Set, Union
 import torch
 from transformers import PretrainedConfig

+from sglang.srt.environ import envs
 from sglang.srt.hf_transformers_utils import (
     get_config,
     get_context_length,

@@ -31,7 +32,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var, is_hip, retry
+from sglang.srt.utils import is_hip, retry
 from sglang.utils import is_in_ci

 logger = logging.getLogger(__name__)

@@ -237,7 +238,7 @@ class ModelConfig:
                 f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config."
             )
             if (
-                get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN")
+                envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.get()
                 or is_in_ci()  # FIXME: fix this special case
             ):
                 logger.warning(msg)
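The last hunk above swaps a raw get_bool_env_var lookup for a typed accessor on sglang.srt.environ.envs. The diff does not show how that registry is implemented; a minimal sketch of what such a typed boolean accessor could look like (EnvBool and _Envs are hypothetical names, not the sglang implementation):

    import os
    from dataclasses import dataclass


    @dataclass(frozen=True)
    class EnvBool:
        """Typed wrapper around a boolean environment variable (illustrative only)."""

        name: str
        default: bool = False

        def get(self) -> bool:
            # Treat common truthy spellings as True; fall back to the default otherwise.
            raw = os.environ.get(self.name)
            if raw is None:
                return self.default
            return raw.strip().lower() in ("1", "true", "yes", "on")


    # Hypothetical registry mirroring the envs.<NAME>.get() call site in the hunk above.
    class _Envs:
        SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN = EnvBool(
            "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default=False
        )


    envs = _Envs()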
python/sglang/srt/disaggregation/prefill.py (+4 -9)

@@ -689,7 +689,6 @@ class SchedulerDisaggregationPrefillMixin:
         self.running_mbs = [
             ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
         ]
-        bids = [None] * self.pp_size
         pp_outputs: Optional[PPProxyTensors] = None

         # Either success or failed

@@ -761,10 +760,7 @@ class SchedulerDisaggregationPrefillMixin:
             # send the outputs to the next step
             if self.pp_group.is_last_rank:
                 if self.cur_batch:
-                    next_token_ids, bids[mb_id] = (
-                        result.next_token_ids,
-                        result.bid,
-                    )
+                    next_token_ids = result.next_token_ids
                     pp_outputs = PPProxyTensors(
                         {
                             "next_token_ids": next_token_ids,

@@ -801,7 +797,6 @@ class SchedulerDisaggregationPrefillMixin:
                         next_token_ids=next_pp_outputs["next_token_ids"],
                         extend_input_len_per_req=None,
                         extend_logprob_start_len_per_req=None,
-                        bid=bids[next_mb_id],
                         can_run_cuda_graph=result.can_run_cuda_graph,
                     )
                     self.process_batch_result_disagg_prefill(

@@ -818,8 +813,6 @@ class SchedulerDisaggregationPrefillMixin:
             # carry the outputs to the next stage
             if not self.pp_group.is_last_rank:
-                if self.cur_batch:
-                    bids[mb_id] = result.bid
                 if pp_outputs:
                     # send the outputs from the last round to let the next stage worker run post processing
                     self.pp_group.send_tensor_dict(

@@ -838,8 +831,10 @@ class SchedulerDisaggregationPrefillMixin:
                 # send out proxy tensors to the next stage
                 if self.cur_batch:
+                    # FIXME(lsyin): remove this assert
+                    assert result.pp_hidden_states_proxy_tensors.tensors is not None
                     self.pp_group.send_tensor_dict(
-                        result.pp_hidden_states_proxy_tensors,
+                        result.pp_hidden_states_proxy_tensors.tensors,
                         all_gather_group=self.attn_tp_group,
                     )
python/sglang/srt/managers/schedule_batch.py (+0 -9)

@@ -860,10 +860,6 @@ class Req:
         )

-
-# Batch id
-bid = 0
-

 @dataclasses.dataclass
 class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     """Store all information of a batch on the scheduler."""

@@ -1829,10 +1825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
         )

-        global bid
-        bid += 1
         return ModelWorkerBatch(
-            bid=bid,
             forward_mode=self.forward_mode,
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,

@@ -1952,8 +1945,6 @@
 @dataclasses.dataclass
 class ModelWorkerBatch:
-    # The batch id
-    bid: int
     # The forward mode
     forward_mode: ForwardMode
     # The input ids
python/sglang/srt/managers/scheduler.py (+80 -64)

@@ -150,7 +150,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatchOutput,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

@@ -175,7 +179,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_int_env_var,
     get_zmq_socket,
     is_cpu,
     kill_itself_when_parent_died,
     numa_bind_to_node,
     point_to_point_pyobj,

@@ -194,24 +197,59 @@ logger = logging.getLogger(__name__)
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
 _is_cpu = is_cpu()


 @dataclass
 class GenerationBatchResult:
     logits_output: Optional[LogitsProcessorOutput]
-    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
+    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
     next_token_ids: Optional[List[int]]
-    can_run_cuda_graph: bool

     # For output processing
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
-    bid: int
+    can_run_cuda_graph: bool

+    @classmethod
+    def from_forward_batch_output(
+        cls,
+        forward_batch_output: ForwardBatchOutput,
+        extend_input_len_per_req: List[int],
+        extend_logprob_start_len_per_req: List[int],
+    ):
+        # TODO(lsyin): remove this workaround logic and try to unify output classes
+        return cls(
+            logits_output=forward_batch_output.logits_output,
+            pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
+            next_token_ids=forward_batch_output.next_token_ids,
+            extend_input_len_per_req=extend_input_len_per_req,
+            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
+            can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
+        )
+
+    @classmethod
+    def from_pp_proxy(
+        cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
+    ):
+        # TODO(lsyin): also simplify this logic
+        # Current PP implementation in scheduler is not compatible with ForwardBatchOutput
+        # Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
+        proxy_dict = next_pp_outputs.tensors
+        return cls(
+            logits_output=logits_output,
+            pp_hidden_states_proxy_tensors=None,
+            next_token_ids=next_pp_outputs["next_token_ids"],
+            extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
+            extend_logprob_start_len_per_req=proxy_dict.get(
+                "extend_logprob_start_len_per_req", None
+            ),
+            can_run_cuda_graph=can_run_cuda_graph,
+        )


 @dataclass
 class EmbeddingBatchResult:
     embeddings: torch.Tensor
-    bid: int


 class Scheduler(

@@ -403,6 +441,12 @@ class Scheduler(
         else:
             self.draft_worker = None

+        # Dispatch the model worker
+        if self.spec_algorithm.is_none():
+            self.model_worker = self.tp_worker
+        else:
+            self.model_worker = self.draft_worker
+
         # Get token and memory info from the model worker
         (
             self.max_total_num_tokens,

@@ -959,7 +1003,6 @@ class Scheduler(
         self.running_mbs = [
             ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
         ]
-        bids = [None] * self.pp_size
         pp_outputs: Optional[PPProxyTensors] = None
         while True:
             server_is_idle = True

@@ -980,10 +1023,7 @@ class Scheduler(
             # (last rank) send the outputs to the next step
             if self.pp_group.is_last_rank:
                 if self.cur_batch:
-                    next_token_ids, bids[mb_id] = (
-                        result.next_token_ids,
-                        result.bid,
-                    )
+                    next_token_ids = result.next_token_ids
                     if self.cur_batch.return_logprob:
                         pp_outputs = PPProxyTensors(
                             {

@@ -1031,17 +1071,10 @@ class Scheduler(
                         logits_output = LogitsProcessorOutput(**logits_output_args)
                     else:
                         logits_output = None

-                    output_result = GenerationBatchResult(
+                    output_result = GenerationBatchResult.from_pp_proxy(
                         logits_output=logits_output,
-                        pp_hidden_states_proxy_tensors=None,
-                        next_token_ids=next_pp_outputs["next_token_ids"],
-                        extend_input_len_per_req=next_pp_outputs.tensors.get(
-                            "extend_input_len_per_req", None
-                        ),
-                        extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
-                            "extend_logprob_start_len_per_req", None
-                        ),
-                        bid=bids[next_mb_id],
+                        next_pp_outputs=next_pp_outputs,
                         can_run_cuda_graph=result.can_run_cuda_graph,
                     )
                     self.process_batch_result(mbs[next_mb_id], output_result)

@@ -1049,8 +1082,6 @@
             # (not last rank)
             if not self.pp_group.is_last_rank:
-                if self.cur_batch:
-                    bids[mb_id] = result.bid
                 # carry the outputs to the next stage
                 # send the outputs from the last round to let the next stage worker run post processing
                 if pp_outputs:

@@ -1072,8 +1103,10 @@
                 # send out proxy tensors to the next stage
                 if self.cur_batch:
+                    # FIXME(lsyin): remove this assert
+                    assert result.pp_hidden_states_proxy_tensors.tensors is not None
                     self.pp_group.send_tensor_dict(
-                        result.pp_hidden_states_proxy_tensors,
+                        result.pp_hidden_states_proxy_tensors.tensors,
                        all_gather_group=self.attn_tp_group,
                     )

@@ -2016,33 +2049,25 @@
         # Run forward
         if self.is_generation:
+            batch_or_worker_batch = batch
+
             if self.spec_algorithm.is_none():
-                model_worker_batch = batch.get_model_worker_batch()
-
-                if self.pp_group.is_last_rank:
-                    logits_output, next_token_ids, can_run_cuda_graph = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                else:
-                    pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                bid = model_worker_batch.bid
-            else:
-                (
-                    logits_output,
-                    next_token_ids,
-                    bid,
-                    num_accepted_tokens,
-                    can_run_cuda_graph,
-                ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                bs = batch.batch_size()
-                self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
-                self.spec_num_total_forward_ct += bs
-                self.num_generated_tokens += num_accepted_tokens
+                # FIXME(lsyin): remove this if and finally unify the abstraction
+                batch_or_worker_batch = batch.get_model_worker_batch()
+
+            forward_batch_output = self.model_worker.forward_batch_generation(
+                batch_or_worker_batch
+            )
+
+            if not self.spec_algorithm.is_none():
+                # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
+                self.udpate_spec_metrics(
+                    batch.batch_size(), forward_batch_output.num_accepted_tokens
+                )

             if self.pp_group.is_last_rank:
-                batch.output_ids = next_token_ids
+                # update batch's output ids
+                batch.output_ids = forward_batch_output.next_token_ids

             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that

@@ -2051,6 +2076,7 @@
             extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
         else:
             extend_input_len_per_req = None
+
         if batch.return_logprob:
             extend_logprob_start_len_per_req = [
                 req.extend_logprob_start_len for req in batch.reqs

@@ -2058,25 +2084,15 @@
             else:
                 extend_logprob_start_len_per_req = None

-            ret = GenerationBatchResult(
-                logits_output=logits_output if self.pp_group.is_last_rank else None,
-                pp_hidden_states_proxy_tensors=(
-                    pp_hidden_states_proxy_tensors
-                    if not self.pp_group.is_last_rank
-                    else None
-                ),
-                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
+            return GenerationBatchResult.from_forward_batch_output(
+                forward_batch_output=forward_batch_output,
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
-                bid=bid,
-                can_run_cuda_graph=can_run_cuda_graph,
             )
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret = EmbeddingBatchResult(
-                embeddings=embeddings, bid=model_worker_batch.bid
-            )
+            ret = EmbeddingBatchResult(embeddings=embeddings)
         return ret

     def process_batch_result(
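The two classmethods added to GenerationBatchResult are thin adapters from the worker-facing ForwardBatchOutput to the scheduler-facing result object. A self-contained sketch of that mapping, with toy stand-ins for the real sglang classes so the snippet runs without the package installed:

    from dataclasses import dataclass
    from typing import Any, List, Optional


    @dataclass
    class ForwardBatchOutput:
        logits_output: Optional[Any] = None
        next_token_ids: Optional[List[int]] = None
        num_accepted_tokens: Optional[int] = None
        pp_proxy_tensors: Optional[Any] = None
        can_run_cuda_graph: bool = False


    @dataclass
    class GenerationBatchResult:
        logits_output: Optional[Any]
        pp_hidden_states_proxy_tensors: Optional[Any]
        next_token_ids: Optional[List[int]]
        extend_input_len_per_req: Optional[List[int]]
        extend_logprob_start_len_per_req: Optional[List[int]]
        can_run_cuda_graph: bool

        @classmethod
        def from_forward_batch_output(cls, out, extend_input_len_per_req, extend_logprob_start_len_per_req):
            # Same field mapping as the classmethod added in scheduler.py above.
            return cls(
                logits_output=out.logits_output,
                pp_hidden_states_proxy_tensors=out.pp_proxy_tensors,
                next_token_ids=out.next_token_ids,
                extend_input_len_per_req=extend_input_len_per_req,
                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                can_run_cuda_graph=out.can_run_cuda_graph,
            )


    out = ForwardBatchOutput(next_token_ids=[7, 11], can_run_cuda_graph=True)
    result = GenerationBatchResult.from_forward_batch_output(out, [4, 9], [0, 0])
    assert result.next_token_ids == [7, 11] and result.can_run_cuda_graph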
python/sglang/srt/managers/scheduler_metrics_mixin.py (+5 -0)

@@ -80,6 +80,11 @@ class SchedulerMetricsMixin:
             kv_events_config, self.attn_dp_rank
         )

+    def udpate_spec_metrics(self, bs: int, num_accepted_tokens: int):
+        self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
+        self.spec_num_total_forward_ct += bs
+        self.num_generated_tokens += num_accepted_tokens
+
     def log_prefill_stats(
         self: Scheduler,
         adder: PrefillAdder,
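udpate_spec_metrics (the identifier is spelled this way in the code) folds the per-step acceptance counts into running counters. A small worked example of what the counters represent, with made-up numbers; the reading of the "+ bs" term as one base token per request is an assumption about how speculative verify steps are usually counted, not something the diff states:

    # One speculative verify step over bs = 4 requests that accepted 6 draft tokens in total.
    bs = 4
    num_accepted_tokens = 6

    spec_num_total_accepted_tokens = num_accepted_tokens + bs  # 10: accepted drafts plus one token per request
    spec_num_total_forward_ct = bs                             # 4: one forward counted per request
    num_generated_tokens = num_accepted_tokens                 # 6

    # Average tokens emitted per request per verify step:
    mean_accept_length = spec_num_total_accepted_tokens / spec_num_total_forward_ct  # 2.5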
python/sglang/srt/managers/scheduler_output_processor_mixin.py (+1 -2)

@@ -173,8 +173,7 @@ class SchedulerOutputProcessorMixin:
                     self.set_next_batch_sampling_info_done(batch)

         else:  # embedding or reward model
-            embeddings, bid = result.embeddings, result.bid
-            embeddings = embeddings.tolist()
+            embeddings = result.embeddings.tolist()

             # Check finish conditions
             for i, req in enumerate(batch.reqs):
python/sglang/srt/managers/tp_worker.py (+15 -6)

@@ -43,7 +43,11 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardBatchOutput,
+    PPProxyTensors,
+)
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs

@@ -234,9 +238,7 @@ class TpModelWorker:
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
-    ) -> Tuple[
-        Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
-    ]:
+    ) -> ForwardBatchOutput:
         # update the consumer index of hicache to the running batch
         self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)

@@ -271,13 +273,20 @@ class TpModelWorker:
             else:
                 next_token_ids = self.model_runner.sample(logits_output, forward_batch)

-            return logits_output, next_token_ids, can_run_cuda_graph
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                can_run_cuda_graph=can_run_cuda_graph,
+            )
         else:
             pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return pp_proxy_tensors.tensors, None, can_run_cuda_graph
+            return ForwardBatchOutput(
+                pp_proxy_tensors=pp_proxy_tensors,
+                can_run_cuda_graph=can_run_cuda_graph,
+            )

     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
python/sglang/srt/managers/tp_worker_overlap_thread.py (+16 -8)

@@ -39,6 +39,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.overlap_utils import FutureMap
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import DynamicGradMode
 from sglang.utils import get_exception_traceback

@@ -160,13 +161,17 @@ class TpModelWorkerClient:
             self.future_map.resolve_future(model_worker_batch)

             # Run forward
-            logits_output, next_token_ids, can_run_cuda_graph = (
-                self.worker.forward_batch_generation(
-                    model_worker_batch,
-                    model_worker_batch.launch_done,
-                    # Skip sampling for prefill-only requests
-                    skip_sample=model_worker_batch.is_prefill_only,
-                )
-            )
+            forward_batch_output = self.worker.forward_batch_generation(
+                model_worker_batch,
+                model_worker_batch.launch_done,
+                # Skip sampling for prefill-only requests
+                skip_sample=model_worker_batch.is_prefill_only,
+            )
+
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                forward_batch_output.logits_output,
+                forward_batch_output.next_token_ids,
+                forward_batch_output.can_run_cuda_graph,
+            )

             # Update the future token ids map

@@ -227,7 +232,7 @@ class TpModelWorkerClient:
     def forward_batch_generation(
         self, model_worker_batch: ModelWorkerBatch
-    ) -> Tuple[None, torch.Tensor, bool]:
+    ) -> ForwardBatchOutput:
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
         sampling_info = model_worker_batch.sampling_info
         sampling_info.update_penalties()

@@ -250,7 +255,10 @@
         future_next_token_ids = self.future_map.update_next_future(
             cur_future_map_ct, bs
         )
-        return None, future_next_token_ids, False
+        return ForwardBatchOutput(
+            next_token_ids=future_next_token_ids,
+            can_run_cuda_graph=False,
+        )

     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         success, message = self.worker.update_weights_from_disk(recv_req)
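In the overlap worker, forward_batch_generation now returns a ForwardBatchOutput whose next_token_ids are placeholder "future" ids; the real ids are filled in later by the forward thread via FutureMap. A toy sketch of that placeholder pattern (ToyFutureMap is illustrative only and does not mirror the real FutureMap API):

    from dataclasses import dataclass
    from typing import Dict, List, Optional


    @dataclass
    class ForwardBatchOutput:
        next_token_ids: Optional[List[int]] = None
        can_run_cuda_graph: bool = False


    class ToyFutureMap:
        """Hands out negative placeholder ids now, resolves them to real ids later."""

        def __init__(self) -> None:
            self._next_key = -1
            self._resolved: Dict[int, int] = {}

        def allocate(self, bs: int) -> List[int]:
            keys = [self._next_key - i for i in range(bs)]
            self._next_key -= bs
            return keys

        def resolve(self, keys: List[int], real_ids: List[int]) -> None:
            self._resolved.update(dict(zip(keys, real_ids)))

        def lookup(self, key: int) -> int:
            return self._resolved[key]


    future_map = ToyFutureMap()
    placeholders = future_map.allocate(bs=2)
    # Scheduler thread gets placeholders immediately, wrapped in the unified output shape.
    out = ForwardBatchOutput(next_token_ids=placeholders, can_run_cuda_graph=False)
    # Later, on the forward thread, sampling produces the real token ids.
    future_map.resolve(placeholders, [101, 202])
    assert [future_map.lookup(k) for k in out.next_token_ids] == [101, 202]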
python/sglang/srt/managers/utils.py (+1 -2)

@@ -2,11 +2,10 @@ from __future__ import annotations
 import logging
 import multiprocessing as mp
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Dict, List, Optional

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
+from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.model_executor.forward_batch_info import PPProxyTensors

 if TYPE_CHECKING:
python/sglang/srt/model_executor/forward_batch_info.py (+11 -0)

@@ -900,6 +900,17 @@ class ForwardBatch:
         return self.tbo_split_seq_index is not None


+@dataclass
+class ForwardBatchOutput:
+    # FIXME(lsyin): unify the forward batch output between different spec and parallelism
+    # need to be more organized
+    logits_output: Optional[torch.Tensor] = None
+    next_token_ids: Optional[torch.Tensor] = None
+    num_accepted_tokens: Optional[int] = None
+    pp_proxy_tensors: Optional[PPProxyTensors] = None
+    can_run_cuda_graph: bool = False
+
+
 def enable_num_token_non_padded(server_args):
     return get_moe_expert_parallel_world_size() > 1
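All fields of ForwardBatchOutput default to None/False, so each code path only fills in what it produces. A standalone sketch of the two shapes used elsewhere in this commit (last pipeline-parallel rank vs. an intermediate rank); the torch and PPProxyTensors types are relaxed to Any and the values are dummies so the snippet runs without sglang:

    from dataclasses import dataclass
    from typing import Any, Optional


    @dataclass
    class ForwardBatchOutput:
        logits_output: Optional[Any] = None
        next_token_ids: Optional[Any] = None
        num_accepted_tokens: Optional[int] = None
        pp_proxy_tensors: Optional[Any] = None
        can_run_cuda_graph: bool = False


    # Last pipeline-parallel rank: logits and sampled token ids are populated.
    last_rank_out = ForwardBatchOutput(
        logits_output=object(),  # stand-in for LogitsProcessorOutput
        next_token_ids=[5, 42],
        can_run_cuda_graph=True,
    )

    # Intermediate rank: only the hidden-state proxy tensors are populated.
    mid_rank_out = ForwardBatchOutput(
        pp_proxy_tensors=object(),  # stand-in for PPProxyTensors
        can_run_cuda_graph=True,
    )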
python/sglang/srt/speculative/eagle_worker.py (+26 -20)

@@ -14,7 +14,6 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.mm_utils import embed_mm_inputs
 from sglang.srt.managers.schedule_batch import (
     ScheduleBatch,
     get_last_loc,

@@ -24,6 +23,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
+    ForwardBatchOutput,
     ForwardMode,
 )
 from sglang.srt.server_args import ServerArgs

@@ -422,9 +422,7 @@ class EAGLEWorker(TpModelWorker):
     def draft_model_runner(self):
         return self.model_runner

-    def forward_batch_speculative_generation(
-        self, batch: ScheduleBatch
-    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
+    def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
         """Run speculative decoding forward.

         NOTE: Many states of batch is modified as you go through. It is not guaranteed that

@@ -437,14 +435,19 @@ class EAGLEWorker(TpModelWorker):
             the batch id (used for overlap schedule), and number of accepted tokens.
         """
         if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
-            logits_output, next_token_ids, bid, seq_lens_cpu = (
-                self.forward_target_extend(batch)
-            )
+            logits_output, next_token_ids, seq_lens_cpu = self.forward_target_extend(
+                batch
+            )
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 self.forward_draft_extend(
                     batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
                 )
-            return logits_output, next_token_ids, bid, 0, False
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                num_accepted_tokens=0,
+                can_run_cuda_graph=False,
+            )
         else:
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)

@@ -462,12 +465,11 @@
                 # decode is not finished
                 self.forward_draft_extend_after_decode(batch)

-            return (
-                logits_output,
-                verify_output.verified_id,
-                model_worker_batch.bid,
-                sum(verify_output.accept_length_per_req_cpu),
-                can_run_cuda_graph,
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=verify_output.verified_id,
+                num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
+                can_run_cuda_graph=can_run_cuda_graph,
             )

     def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):

@@ -499,19 +501,21 @@
         Returns:
             logits_output: The output of logits. It will contain the full hidden states.
             next_token_ids: Next token ids generated.
-            bid: The model batch ID. Used for overlap schedule.
         """
         # Forward with the target model and get hidden states.
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
-            model_worker_batch
-        )
+        forward_batch_output = self.target_worker.forward_batch_generation(
+            model_worker_batch
+        )
+        logits_output, next_token_ids = (
+            forward_batch_output.logits_output,
+            forward_batch_output.next_token_ids,
+        )
         return (
             logits_output,
             next_token_ids,
-            model_worker_batch.bid,
             model_worker_batch.seq_lens_cpu,
         )

@@ -811,10 +815,12 @@
         ).cpu()

         # Forward
-        logits_output, _, can_run_cuda_graph = (
-            self.target_worker.forward_batch_generation(
-                model_worker_batch, skip_sample=True
-            )
-        )
+        forward_batch_output = self.target_worker.forward_batch_generation(
+            model_worker_batch, skip_sample=True
+        )
+        logits_output, can_run_cuda_graph = (
+            forward_batch_output.logits_output,
+            forward_batch_output.can_run_cuda_graph,
+        )

         vocab_mask = None
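For the verify path above, num_accepted_tokens is simply the sum of the per-request accepted draft lengths. A tiny worked example with made-up numbers:

    # After one verify step, each request reports how many draft tokens it accepted.
    accept_length_per_req_cpu = [3, 0, 2, 1]
    num_accepted_tokens = sum(accept_length_per_req_cpu)  # 6
    # The scheduler later passes this value to udpate_spec_metrics() together with
    # the batch size, as shown in the scheduler.py hunk earlier in this commit.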
python/sglang/srt/speculative/ngram_worker.py (+19 -14)

@@ -7,7 +7,7 @@ from sgl_kernel.speculative import reconstruct_indices_from_tree_mask
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput, ForwardMode
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
 from sglang.srt.speculative.ngram_utils import NgramVerifyInput

@@ -207,17 +207,18 @@ class NGRAMWorker:
             batch_tokens.append(put_ids)
         self.ngram_cache.batch_put(batch_tokens)

-    def forward_batch_speculative_generation(self, batch: ScheduleBatch):
+    def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
         self._prepare_for_speculative_decoding(batch)
         model_worker_batch = batch.get_model_worker_batch()
-        bid = model_worker_batch.bid
         num_accepted_tokens = 0

         if model_worker_batch.forward_mode.is_target_verify():
-            logits_output, _, can_run_cuda_graph = (
-                self.target_worker.forward_batch_generation(
-                    model_worker_batch, skip_sample=True
-                )
-            )
+            forward_batch_output = self.target_worker.forward_batch_generation(
+                model_worker_batch, skip_sample=True
+            )
+            logits_output, can_run_cuda_graph = (
+                forward_batch_output.logits_output,
+                forward_batch_output.can_run_cuda_graph,
+            )
             verify_input = model_worker_batch.spec_info
             logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(

@@ -227,14 +228,18 @@ class NGRAMWorker:
             batch.forward_mode = ForwardMode.DECODE
         else:
-            logits_output, next_token_ids, can_run_cuda_graph = (
-                self.target_worker.forward_batch_generation(model_worker_batch)
-            )
+            forward_batch_output = self.target_worker.forward_batch_generation(
+                model_worker_batch
+            )
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                forward_batch_output.logits_output,
+                forward_batch_output.next_token_ids,
+                forward_batch_output.can_run_cuda_graph,
+            )

-        return (
-            logits_output,
-            next_token_ids,
-            bid,
-            num_accepted_tokens,
-            can_run_cuda_graph,
+        return ForwardBatchOutput(
+            logits_output=logits_output,
+            next_token_ids=next_token_ids,
+            num_accepted_tokens=num_accepted_tokens,
+            can_run_cuda_graph=can_run_cuda_graph,
         )
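With every worker returning ForwardBatchOutput, the scheduler's dispatch (the new self.model_worker assignment earlier in this commit) reduces to one polymorphic call site. A toy sketch of that idea; the worker classes here are stubs, not the real TpModelWorker, EAGLEWorker, or NGRAMWorker:

    from dataclasses import dataclass
    from typing import List, Optional


    @dataclass
    class ForwardBatchOutput:
        next_token_ids: Optional[List[int]] = None
        num_accepted_tokens: Optional[int] = None
        can_run_cuda_graph: bool = False


    class ToyTpWorker:
        def forward_batch_generation(self, batch) -> ForwardBatchOutput:
            return ForwardBatchOutput(next_token_ids=[1, 2], can_run_cuda_graph=True)


    class ToySpecWorker:
        def forward_batch_generation(self, batch) -> ForwardBatchOutput:
            return ForwardBatchOutput(next_token_ids=[1, 2], num_accepted_tokens=3)


    def run_batch(model_worker, batch) -> ForwardBatchOutput:
        # One call site, one return shape, regardless of which worker is plugged in.
        out = model_worker.forward_batch_generation(batch)
        if out.num_accepted_tokens is not None:
            pass  # speculative path: update acceptance metrics here
        return out


    for worker in (ToyTpWorker(), ToySpecWorker()):
        assert isinstance(run_batch(worker, batch=None), ForwardBatchOutput)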