change / sglang · Commit e35a93fa

Unverified commit e35a93fa, authored Mar 12, 2025 by Lianmin Zheng, committed via GitHub on Mar 12, 2025.

Move output processing logic from scheduler.py into a separate file (#4354)
Parent: 2c3656f2

Showing 6 changed files with 634 additions and 609 deletions (+634, -609).
python/sglang/srt/layers/sampler.py                                +1    -1
python/sglang/srt/managers/schedule_batch.py                       +0    -22
python/sglang/srt/managers/scheduler.py                            +4    -576
python/sglang/srt/managers/scheduler_output_processor_mixin.py     +602  -0
python/sglang/srt/model_executor/model_runner.py                   +15   -4
python/sglang/srt/server_args.py                                   +12   -6
python/sglang/srt/layers/sampler.py

  import logging
- from typing import List, Optional
+ from typing import List

  import torch
  import torch.distributed as dist
  ...
python/sglang/srt/managers/schedule_batch.py

  ...
  @@ -441,28 +441,6 @@ class Req:
          all_ids = self.origin_input_ids_unpadded + self.output_ids
          return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

-     def get_next_inc_detokenization(self):
-         if self.tokenizer is None:
-             return False, ""
-         read_ids, read_offset = self.init_incremental_detokenize()
-         surr_ids = read_ids[:read_offset]
-
-         surr_text = self.tokenizer.decode(
-             surr_ids,
-             skip_special_tokens=self.sampling_params.skip_special_tokens,
-             spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
-         )
-         new_text = self.tokenizer.decode(
-             read_ids,
-             skip_special_tokens=self.sampling_params.skip_special_tokens,
-             spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
-         )
-
-         if len(new_text) > len(surr_text) and not new_text.endswith("�"):
-             return True, new_text[len(surr_text) :]
-
-         return False, ""
-
      def check_finished(self):
          if self.finished():
              return
  ...
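The helper removed above implements incremental detokenization: decode a short surrogate window together with the new ids, and only emit the suffix once it no longer ends in the U+FFFD replacement character (which signals an incomplete byte/BPE sequence). A standalone sketch of the same idea, assuming a Hugging Face tokenizer and using gpt2 purely as an example model (not sglang code):

from transformers import AutoTokenizer  # assumption: HF tokenizer with a .decode() method


def next_incremental_text(tokenizer, read_ids, read_offset):
    # Decode a "surrogate" prefix and the full window, then emit only the new suffix
    # if it decodes cleanly (no trailing replacement character).
    surr_ids = read_ids[:read_offset]
    surr_text = tokenizer.decode(surr_ids, skip_special_tokens=True)
    new_text = tokenizer.decode(read_ids, skip_special_tokens=True)
    if len(new_text) > len(surr_text) and not new_text.endswith("�"):
        return True, new_text[len(surr_text):]
    return False, ""


tok = AutoTokenizer.from_pretrained("gpt2")  # example model, for illustration only
ids = tok.encode("Hello, incremental world")
print(next_incremental_text(tok, ids, 2))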
python/sglang/srt/managers/scheduler.py

  ...
  @@ -41,8 +41,6 @@ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.managers.io_struct import (
      AbortReq,
-     BatchEmbeddingOut,
-     BatchTokenIDOut,
      CloseSessionReqInput,
      FlushCacheReq,
      GetInternalStateReq,
  ...
  @@ -74,7 +72,6 @@ from sglang.srt.managers.io_struct import (
  )
  from sglang.srt.managers.schedule_batch import (
      FINISH_ABORT,
-     BaseFinishReason,
      ImageInputs,
      Req,
      ScheduleBatch,
  ...
  @@ -85,6 +82,9 @@ from sglang.srt.managers.schedule_policy import (
      PrefillAdder,
      SchedulePolicy,
  )
+ from sglang.srt.managers.scheduler_output_processor_mixin import (
+     SchedulerOutputProcessorMixin,
+ )
  from sglang.srt.managers.session_controller import Session
  from sglang.srt.managers.tp_worker import TpModelWorker
  from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
  ...
  @@ -132,7 +132,7 @@ class EmbeddingBatchResult:
      bid: int


- class Scheduler:
+ class Scheduler(SchedulerOutputProcessorMixin):
      """A scheduler that manages a tensor parallel GPU worker."""

      def __init__(
  ...
  @@ -1256,578 +1256,6 @@ class Scheduler:
              self.return_health_check_ct -= 1
              self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

  [572 lines deleted in this hunk: the bodies of process_batch_result_prefill,
   process_batch_result_decode, add_input_logprob_return_values,
   add_logprob_return_values, and stream_output. They move, essentially verbatim,
   into the new scheduler_output_processor_mixin.py shown in full below. The only
   substantive differences in the new file are that stream_output is split into
   stream_output_generation / stream_output_embedding, its multimodal-gen branch
   now returns instead of raising NotImplementedError, and the overlap path of
   process_batch_result_decode drops an assert batch.spec_algorithm.is_none().]

      def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
          # Check if other DP workers have running batches
          if local_batch is None:
  ...
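The only structural change to the class is the new base class, class Scheduler(SchedulerOutputProcessorMixin). A minimal sketch of the mixin pattern this relies on (illustrative names, not sglang code): methods defined on the mixin resolve attributes on self at call time, so they can use state that the host class creates in __init__.

# output_processor.py (illustrative)
class OutputProcessorMixin:
    def process_batch_result(self, batch):
        # `self.results` is created by the host class, not by the mixin.
        self.results.extend(batch)


# scheduler.py (illustrative)
class Scheduler(OutputProcessorMixin):
    def __init__(self):
        self.results = []


s = Scheduler()
s.process_batch_result([1, 2, 3])
print(s.results)  # [1, 2, 3]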
python/sglang/srt/managers/scheduler_output_processor_mixin.py  (new file)

from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import BatchEmbeddingOut, BatchTokenIDOut
from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch

if TYPE_CHECKING:
    from sglang.srt.managers.scheduler import (
        EmbeddingBatchResult,
        GenerationBatchResult,
        ScheduleBatch,
    )


class SchedulerOutputProcessorMixin:
    """
    This class implements the output processing logic for Scheduler.
    We put them into a separate file to make the `scheduler.py` shorter.
    """

    def process_batch_result_prefill(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        skip_stream_req = None

        if self.is_generation:
            (
                logits_output,
                next_token_ids,
                extend_input_len_per_req,
                extend_logprob_start_len_per_req,
                bid,
            ) = (
                result.logits_output,
                result.next_token_ids,
                result.extend_input_len_per_req,
                result.extend_logprob_start_len_per_req,
                result.bid,
            )

            if self.enable_overlap:
                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            else:
                # Move next_token_ids and logprobs to cpu
                next_token_ids = next_token_ids.tolist()
                if batch.return_logprob:
                    if logits_output.next_token_logprobs is not None:
                        logits_output.next_token_logprobs = (
                            logits_output.next_token_logprobs.tolist()
                        )
                    if logits_output.input_token_logprobs is not None:
                        logits_output.input_token_logprobs = tuple(
                            logits_output.input_token_logprobs.tolist()
                        )

            hidden_state_offset = 0

            # Check finish conditions
            logprob_pt = 0
            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
                if req.is_retracted:
                    continue

                if self.is_mixed_chunk and self.enable_overlap and req.finished():
                    # Free the one delayed token for the mixed decode batch
                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
                    self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
                    continue

                if req.is_chunked <= 0:
                    # req output_ids are set here
                    req.output_ids.append(next_token_id)
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                        # This updates radix so others can match
                        self.tree_cache.cache_unfinished_req(req)

                    if req.return_logprob:
                        assert extend_logprob_start_len_per_req is not None
                        assert extend_input_len_per_req is not None
                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                        extend_input_len = extend_input_len_per_req[i]

                        num_input_logprobs = extend_input_len - extend_logprob_start_len

                        self.add_logprob_return_values(
                            i,
                            req,
                            logprob_pt,
                            next_token_ids,
                            num_input_logprobs,
                            logits_output,
                        )
                        logprob_pt += num_input_logprobs

                    if (
                        req.return_hidden_states
                        and logits_output.hidden_states is not None
                    ):
                        req.hidden_states.append(
                            logits_output.hidden_states[
                                hidden_state_offset : (
                                    hidden_state_offset := hidden_state_offset
                                    + len(req.origin_input_ids)
                                )
                            ]
                            .cpu()
                            .clone()
                        )

                    if req.grammar is not None:
                        req.grammar.accept_token(next_token_id)
                        req.grammar.finished = req.finished()
                else:
                    # being chunked reqs' prefill is not finished
                    req.is_chunked -= 1
                    # There is only at most one request being currently chunked.
                    # Because this request does not finish prefill,
                    # we don't want to stream the request currently being chunked.
                    skip_stream_req = req

                    # Incrementally update input logprobs.
                    if req.return_logprob:
                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                        extend_input_len = extend_input_len_per_req[i]
                        if extend_logprob_start_len < extend_input_len:
                            # Update input logprobs.
                            num_input_logprobs = (
                                extend_input_len - extend_logprob_start_len
                            )
                            self.add_input_logprob_return_values(
                                i,
                                req,
                                logits_output,
                                logprob_pt,
                                num_input_logprobs,
                                last_prefill_chunk=False,
                            )
                            logprob_pt += num_input_logprobs

            if batch.next_batch_sampling_info:
                batch.next_batch_sampling_info.update_regex_vocab_mask()
                self.current_stream.synchronize()
                batch.next_batch_sampling_info.sampling_info_done.set()

        else:  # embedding or reward model
            embeddings, bid = result.embeddings, result.bid
            embeddings = embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                if req.is_retracted:
                    continue

                req.embedding = embeddings[i]
                if req.is_chunked <= 0:
                    # Dummy output token for embedding models
                    req.output_ids.append(0)
                    req.check_finished()
                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    else:
                        self.tree_cache.cache_unfinished_req(req)
                else:
                    # being chunked reqs' prefill is not finished
                    req.is_chunked -= 1

        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

    def process_batch_result_decode(
        self,
        batch: ScheduleBatch,
        result: GenerationBatchResult,
    ):
        logits_output, next_token_ids, bid = (
            result.logits_output,
            result.next_token_ids,
            result.bid,
        )
        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        elif batch.spec_algorithm.is_none():
            # spec decoding handles output logprobs inside verify process.
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs.tolist()

        self.token_to_kv_pool_allocator.free_group_begin()

        # Check finish condition
        # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
        # We should ignore using next_token_ids for spec decoding cases.
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                # Free the one delayed token
                self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
                continue

            if batch.spec_algorithm.is_none():
                # speculative worker will solve the output_ids in speculative decoding
                req.output_ids.append(next_token_id)

            req.check_finished()
            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if req.return_logprob and batch.spec_algorithm.is_none():
                # speculative worker handles logprob in speculative decoding
                req.output_token_logprobs_val.append(next_token_logprobs[i])
                req.output_token_logprobs_idx.append(next_token_id)
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs_val.append(
                        logits_output.next_token_top_logprobs_val[i]
                    )
                    req.output_top_logprobs_idx.append(
                        logits_output.next_token_top_logprobs_idx[i]
                    )
                if req.token_ids_logprob is not None:
                    req.output_token_ids_logprobs_val.append(
                        logits_output.next_token_token_ids_logprobs_val[i]
                    )
                    req.output_token_ids_logprobs_idx.append(
                        logits_output.next_token_token_ids_logprobs_idx[i]
                    )

            if req.return_hidden_states and logits_output.hidden_states is not None:
                req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())

            if req.grammar is not None and batch.spec_algorithm.is_none():
                req.grammar.accept_token(next_token_id)
                req.grammar.finished = req.finished()

        if batch.next_batch_sampling_info:
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()

        self.stream_output(batch.reqs, batch.return_logprob)

        self.token_to_kv_pool_allocator.free_group_end()

        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        if (
            self.attn_tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        ):
            self.log_decode_stats()

    def add_input_logprob_return_values(
        self,
        i: int,
        req: Req,
        output: LogitsProcessorOutput,
        logprob_pt: int,
        num_input_logprobs: int,
        last_prefill_chunk: bool,  # If True, it means prefill is finished.
    ):
        """Incrementally add input logprobs to `req`.

        Args:
            i: The request index in a batch.
            req: The request. Input logprobs inside req are modified as a
                consequence of the API
            fill_ids: The prefill ids processed.
            output: Logit processor output that's used to compute input logprobs
            last_prefill_chunk: True if it is the last prefill (when chunked).
                Some of input logprob operation should only happen at the last
                prefill (e.g., computing input token logprobs).
        """
        assert output.input_token_logprobs is not None
        if req.input_token_logprobs is None:
            req.input_token_logprobs = []
        if req.temp_input_top_logprobs_val is None:
            req.temp_input_top_logprobs_val = []
        if req.temp_input_top_logprobs_idx is None:
            req.temp_input_top_logprobs_idx = []
        if req.temp_input_token_ids_logprobs_val is None:
            req.temp_input_token_ids_logprobs_val = []
        if req.temp_input_token_ids_logprobs_idx is None:
            req.temp_input_token_ids_logprobs_idx = []

        if req.input_token_logprobs_val is not None:
            # The input logprob has been already computed. It only happens
            # upon retract.
            if req.top_logprobs_num > 0:
                assert req.input_token_logprobs_val is not None
            return

        # Important for the performance.
        assert isinstance(output.input_token_logprobs, tuple)
        input_token_logprobs: Tuple[int] = output.input_token_logprobs
        input_token_logprobs = input_token_logprobs[
            logprob_pt : logprob_pt + num_input_logprobs
        ]
        req.input_token_logprobs.extend(input_token_logprobs)

        if req.top_logprobs_num > 0:
            req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
            req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])

        if req.token_ids_logprob is not None:
            req.temp_input_token_ids_logprobs_val.append(
                output.input_token_ids_logprobs_val[i]
            )
            req.temp_input_token_ids_logprobs_idx.append(
                output.input_token_ids_logprobs_idx[i]
            )

        if last_prefill_chunk:
            input_token_logprobs = req.input_token_logprobs
            req.input_token_logprobs = None
            assert req.input_token_logprobs_val is None
            assert req.input_token_logprobs_idx is None
            assert req.input_top_logprobs_val is None
            assert req.input_top_logprobs_idx is None

            # Compute input_token_logprobs_val
            # Always pad the first one with None.
            req.input_token_logprobs_val = [None]
            req.input_token_logprobs_val.extend(input_token_logprobs)
            # The last input logprob is for sampling, so just pop it out.
            req.input_token_logprobs_val.pop()

            # Compute input_token_logprobs_idx
            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
            # Clip the padded hash values from image tokens.
            # Otherwise, it will lead to detokenization errors.
            input_token_logprobs_idx = [
                x if x < self.model_config.vocab_size - 1 else 0
                for x in input_token_logprobs_idx
            ]
            req.input_token_logprobs_idx = input_token_logprobs_idx

            if req.top_logprobs_num > 0:
                req.input_top_logprobs_val = [None]
                req.input_top_logprobs_idx = [None]
                assert len(req.temp_input_token_ids_logprobs_val) == len(
                    req.temp_input_token_ids_logprobs_idx
                )
                for val, idx in zip(
                    req.temp_input_top_logprobs_val,
                    req.temp_input_top_logprobs_idx,
                    strict=True,
                ):
                    req.input_top_logprobs_val.extend(val)
                    req.input_top_logprobs_idx.extend(idx)

                # Last token is a sample token.
                req.input_top_logprobs_val.pop()
                req.input_top_logprobs_idx.pop()
                req.temp_input_top_logprobs_idx = None
                req.temp_input_top_logprobs_val = None

            if req.token_ids_logprob is not None:
                req.input_token_ids_logprobs_val = [None]
                req.input_token_ids_logprobs_idx = [None]

                for val, idx in zip(
                    req.temp_input_token_ids_logprobs_val,
                    req.temp_input_token_ids_logprobs_idx,
                    strict=True,
                ):
                    req.input_token_ids_logprobs_val.extend(val)
                    req.input_token_ids_logprobs_idx.extend(idx)

                # Last token is a sample token.
                req.input_token_ids_logprobs_val.pop()
                req.input_token_ids_logprobs_idx.pop()
                req.temp_input_token_ids_logprobs_idx = None
                req.temp_input_token_ids_logprobs_val = None

            if req.return_logprob:
                relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
                assert len(req.input_token_logprobs_val) == relevant_tokens_len
                assert len(req.input_token_logprobs_idx) == relevant_tokens_len
                if req.top_logprobs_num > 0:
                    assert len(req.input_top_logprobs_val) == relevant_tokens_len
                    assert len(req.input_top_logprobs_idx) == relevant_tokens_len
                if req.token_ids_logprob is not None:
                    assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
                    assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len

    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        num_input_logprobs: int,
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values."""
        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
        req.output_token_logprobs_idx.append(next_token_ids[i])

        self.add_input_logprob_return_values(
            i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
        )

        if req.top_logprobs_num > 0:
            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

        if req.token_ids_logprob is not None:
            req.output_token_ids_logprobs_val.append(
                output.next_token_token_ids_logprobs_val[i]
            )
            req.output_token_ids_logprobs_idx.append(
                output.next_token_token_ids_logprobs_idx[i]
            )

        return num_input_logprobs

    def stream_output(
        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
    ):
        """Stream the output to detokenizer."""
        if self.is_generation:
            self.stream_output_generation(reqs, return_logprob, skip_req)
        else:  # embedding or reward model
            self.stream_output_embedding(reqs)

    def stream_output_generation(
        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
    ):
        rids = []
        finished_reasons: List[BaseFinishReason] = []

        decoded_texts = []
        decode_ids_list = []
        read_offsets = []
        output_ids = []

        skip_special_tokens = []
        spaces_between_special_tokens = []
        no_stop_trim = []
        prompt_tokens = []
        completion_tokens = []
        cached_tokens = []
        spec_verify_ct = []
        output_hidden_states = None

        if return_logprob:
            input_token_logprobs_val = []
            input_token_logprobs_idx = []
            output_token_logprobs_val = []
            output_token_logprobs_idx = []
            input_top_logprobs_val = []
            input_top_logprobs_idx = []
            output_top_logprobs_val = []
            output_top_logprobs_idx = []
            input_token_ids_logprobs_val = []
            input_token_ids_logprobs_idx = []
            output_token_ids_logprobs_val = []
            output_token_ids_logprobs_idx = []
        else:
            input_token_logprobs_val = input_token_logprobs_idx = (
                output_token_logprobs_val
            ) = output_token_logprobs_idx = input_top_logprobs_val = (
                input_top_logprobs_idx
            ) = output_top_logprobs_val = output_top_logprobs_idx = (
                input_token_ids_logprobs_val
            ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
                output_token_ids_logprobs_idx
            ) = None

        for req in reqs:
            if req is skip_req:
                continue

            # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
            if self.model_config.is_multimodal_gen and req.to_abort:
                continue

            if (
                req.finished()
                # If stream, follow the given stream_interval
                or (req.stream and len(req.output_ids) % self.stream_interval == 0)
                # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
                # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
                # always increase one-by-one.
                or (
                    not req.stream
                    and len(req.output_ids) % 50 == 0
                    and not self.model_config.is_multimodal_gen
                )
            ):
                rids.append(req.rid)
                finished_reasons.append(
                    req.finished_reason.to_json() if req.finished_reason else None
                )
                decoded_texts.append(req.decoded_text)
                decode_ids, read_offset = req.init_incremental_detokenize()
                decode_ids_list.append(decode_ids)
                read_offsets.append(read_offset)
                if self.skip_tokenizer_init:
                    output_ids.append(req.output_ids)
                skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                spaces_between_special_tokens.append(
                    req.sampling_params.spaces_between_special_tokens
                )
                no_stop_trim.append(req.sampling_params.no_stop_trim)
                prompt_tokens.append(len(req.origin_input_ids))
                completion_tokens.append(len(req.output_ids))
                cached_tokens.append(req.cached_tokens)

                if not self.spec_algorithm.is_none():
                    spec_verify_ct.append(req.spec_verify_ct)

                if return_logprob:
                    input_token_logprobs_val.append(req.input_token_logprobs_val)
                    input_token_logprobs_idx.append(req.input_token_logprobs_idx)
                    output_token_logprobs_val.append(req.output_token_logprobs_val)
                    output_token_logprobs_idx.append(req.output_token_logprobs_idx)
                    input_top_logprobs_val.append(req.input_top_logprobs_val)
                    input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                    output_top_logprobs_val.append(req.output_top_logprobs_val)
                    output_top_logprobs_idx.append(req.output_top_logprobs_idx)
                    input_token_ids_logprobs_val.append(
                        req.input_token_ids_logprobs_val
                    )
                    input_token_ids_logprobs_idx.append(
                        req.input_token_ids_logprobs_idx
                    )
                    output_token_ids_logprobs_val.append(
                        req.output_token_ids_logprobs_val
                    )
                    output_token_ids_logprobs_idx.append(
                        req.output_token_ids_logprobs_idx
                    )

                if req.return_hidden_states:
                    if output_hidden_states is None:
                        output_hidden_states = []
                    output_hidden_states.append(req.hidden_states)

        # Send to detokenizer
        if rids:
            if self.model_config.is_multimodal_gen:
                return
            self.send_to_detokenizer.send_pyobj(
                BatchTokenIDOut(
                    rids,
                    finished_reasons,
                    decoded_texts,
                    decode_ids_list,
                    read_offsets,
                    output_ids,
                    skip_special_tokens,
                    spaces_between_special_tokens,
                    no_stop_trim,
                    prompt_tokens,
                    completion_tokens,
                    cached_tokens,
                    spec_verify_ct,
                    input_token_logprobs_val,
                    input_token_logprobs_idx,
                    output_token_logprobs_val,
                    output_token_logprobs_idx,
                    input_top_logprobs_val,
                    input_top_logprobs_idx,
                    output_top_logprobs_val,
                    output_top_logprobs_idx,
                    input_token_ids_logprobs_val,
                    input_token_ids_logprobs_idx,
                    output_token_ids_logprobs_val,
                    output_token_ids_logprobs_idx,
                    output_hidden_states,
                )
            )

    def stream_output_embedding(self, reqs: List[Req]):
        rids = []
        finished_reasons: List[BaseFinishReason] = []

        embeddings = []
        prompt_tokens = []
        cached_tokens = []
        for req in reqs:
            if req.finished():
                rids.append(req.rid)
                finished_reasons.append(req.finished_reason.to_json())
                embeddings.append(req.embedding)
                prompt_tokens.append(len(req.origin_input_ids))
                cached_tokens.append(req.cached_tokens)
        self.send_to_detokenizer.send_pyobj(
            BatchEmbeddingOut(
                rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
            )
        )
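Note the if TYPE_CHECKING: guard around the imports from sglang.srt.managers.scheduler: the scheduler imports this module at runtime, so the reverse import exists only for type checkers, which avoids a circular import. A minimal sketch of the pattern with hypothetical module names:

# a.py (illustrative)
from __future__ import annotations
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; at runtime no import happens,
    # so a.py <-> b.py cannot form an import cycle.
    from b import Result


def handle(result: Result) -> None:  # annotation stays a string at runtime
    print(result)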
python/sglang/srt/model_executor/model_runner.py

  ...
  @@ -82,7 +82,6 @@ from sglang.srt.utils import (
  logger = logging.getLogger(__name__)

  SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
  UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
  ...
  @@ -119,6 +118,7 @@ class ModelRunner:
          self.spec_algorithm = SpeculativeAlgorithm.from_string(
              server_args.speculative_algorithm
          )
+         self.page_size = server_args.page_size
          self.req_to_token_pool = req_to_token_pool
          self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
  ...
  @@ -161,6 +161,11 @@ class ModelRunner:
          # Get memory before model loading
          min_per_gpu_memory = self.init_torch_distributed()

+         # If it is a draft model tp_group can be different.
+         self.initialize(min_per_gpu_memory)
+
+     def initialize(self, min_per_gpu_memory: float):
+         server_args = self.server_args
          self.memory_saver_adapter = TorchMemorySaverAdapter.create(
              enable=self.server_args.enable_memory_saver
          )
  ...
  @@ -300,15 +305,16 @@ class ModelRunner:
          min_per_gpu_memory = get_available_gpu_memory(
              self.device, self.gpu_id, distributed=self.tp_size > 1
          )
-         local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
          self.tp_group = get_tp_group()
          self.attention_tp_group = get_attention_tp_group()

          # Check memory for tensor parallelism
+         local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
          if self.tp_size > 1:
              if min_per_gpu_memory < local_gpu_memory * 0.9:
                  raise ValueError(
-                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
+                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
+                     f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                  )

          logger.info(
  ...
  @@ -698,6 +704,12 @@ class ModelRunner:
          )
          self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

+         self.max_total_num_tokens = (
+             self.max_total_num_tokens
+             // self.server_args.page_size
+             * self.server_args.page_size
+         )
+
          if self.max_total_num_tokens <= 0:
              raise RuntimeError(
                  "Not enough memory. Please try to increase --mem-fraction-static."
  ...
  @@ -783,7 +795,6 @@ class ModelRunner:
              # Init streams
              if self.server_args.speculative_algorithm == "EAGLE":
                  self.plan_stream_for_flashinfer = torch.cuda.Stream()
              self.attn_backend = FlashInferAttnBackend(self)
          elif self.server_args.attention_backend == "triton":
              assert self.sliding_window_size is None, (
  ...
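The new block rounds max_total_num_tokens down to a multiple of page_size via floor division. A quick worked example with hypothetical numbers:

# Hypothetical numbers, just to show the rounding the new code performs.
max_total_num_tokens = 163_843
page_size = 16
max_total_num_tokens = max_total_num_tokens // page_size * page_size
print(max_total_num_tokens)  # 163840, i.e. rounded down to a multiple of 16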
python/sglang/srt/server_args.py

  ...
  @@ -20,14 +20,13 @@ import random
  import tempfile
  from typing import List, Optional

- import torch
  from sglang.srt.hf_transformers_utils import check_gguf_file
  from sglang.srt.reasoning_parser import ReasoningParser
  from sglang.srt.utils import (
      get_amdgpu_memory_capacity,
      get_hpu_memory_capacity,
      get_nvgpu_memory_capacity,
+     is_cuda,
      is_flashinfer_available,
      is_hip,
      is_port_available,
  ...
  @@ -71,6 +70,7 @@ class ServerArgs:
      schedule_policy: str = "fcfs"
      schedule_conservativeness: float = 1.0
      cpu_offload_gb: int = 0
+     page_size: int = 1

      # Other runtime options
      tp_size: int = 1
  ...
  @@ -190,10 +190,10 @@ class ServerArgs:
          if self.random_seed is None:
              self.random_seed = random.randint(0, 1 << 30)

-         if is_hip():
+         if is_cuda():
+             gpu_mem = get_nvgpu_memory_capacity()
+         elif is_hip():
              gpu_mem = get_amdgpu_memory_capacity()
-         elif torch.cuda.is_available():
-             gpu_mem = get_nvgpu_memory_capacity()
          elif self.device == "hpu":
              gpu_mem = get_hpu_memory_capacity()
          else:
  ...
  @@ -258,7 +258,7 @@ class ServerArgs:
                  f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
              )

-         # Others
+         # Data parallelism attention
          if self.enable_dp_attention:
              self.dp_size = self.tp_size
              assert self.tp_size % self.dp_size == 0
  ...
  @@ -507,6 +507,12 @@ class ServerArgs:
              default=ServerArgs.cpu_offload_gb,
              help="How many GBs of RAM to reserve for CPU offloading.",
          )
+         parser.add_argument(
+             "--page-size",
+             type=int,
+             default=ServerArgs.page_size,
+             help="The number of tokens in a page.",
+         )

          # Other runtime options
          parser.add_argument(
  ...
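The new page_size field (default 1) is also exposed on the command line as --page-size. A sketch of programmatic use, assuming the usual dataclass construction of ServerArgs and an arbitrary example model path:

from sglang.srt.server_args import ServerArgs

# model_path is an arbitrary example; all other fields keep their defaults.
args = ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct", page_size=16)
print(args.page_size)  # 16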