sglang / Commits / a6ca736c

Commit a6ca736c (Unverified), authored Dec 08, 2024 by Lianmin Zheng, committed by GitHub on Dec 08, 2024
Parent: f62055b5

Simplify stream_output (#2398)

Showing 9 changed files with 425 additions and 289 deletions (+425 -289)
python/sglang/srt/layers/logits_processor.py           +54   -40
python/sglang/srt/managers/detokenizer_manager.py      +35   -16
python/sglang/srt/managers/io_struct.py                +39   -10
python/sglang/srt/managers/schedule_batch.py           +30   -16
python/sglang/srt/managers/scheduler.py               +160  -137
python/sglang/srt/managers/tokenizer_manager.py        +98   -57
python/sglang/srt/model_executor/cuda_graph_runner.py   +7    -2
python/sglang/test/test_utils.py                        +2    -2
test/srt/test_json_constrained.py                       +0    -9
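The same data-layout refactor recurs across these files: logprob results that previously traveled between the scheduler, detokenizer, and tokenizer manager as lists of (logprob, token_id) tuples are split into parallel *_val and *_idx lists, and the per-request meta_info dicts built in the scheduler are replaced by flat per-field lists that TokenizerManager reassembles. A minimal sketch of the representation change (illustrative only, with made-up values; not code from the diff):

# Old style: one list of (logprob, token_id) pairs per position.
old_style = [(-0.12, 9906), (-2.31, 13), (-0.05, 1917)]

# New style: two parallel lists, which are simpler to build per batch and to ship
# over ZMQ between the scheduler, detokenizer, and tokenizer manager.
new_val = [logprob for logprob, _ in old_style]    # [-0.12, -2.31, -0.05]
new_idx = [token_id for _, token_id in old_style]  # [9906, 13, 1917]

# Consumers that still want pairs can zip the two lists back together.
assert list(zip(new_val, new_idx)) == old_style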
python/sglang/srt/layers/logits_processor.py

@@ -39,10 +39,12 @@ class LogitsProcessorOutput:
     # The logprobs of input tokens. shape: [#token, vocab_size]
     input_token_logprobs: torch.Tensor = None

-    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    input_top_logprobs: List = None
-    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    output_top_logprobs: List = None
+    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k]
+    input_top_logprobs_val: List = None
+    input_top_logprobs_idx: List = None
+    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k]
+    output_top_logprobs_val: List = None
+    output_top_logprobs_idx: List = None


 @dataclasses.dataclass

@@ -125,12 +127,15 @@ class LogitsProcessor(nn.Module):
         indices = ret.indices.tolist()

         if logits_metadata.forward_mode.is_decode():
-            output_top_logprobs = []
+            output_top_logprobs_val = []
+            output_top_logprobs_idx = []
             for i, k in enumerate(logits_metadata.top_logprobs_nums):
-                output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
-            return None, output_top_logprobs
+                output_top_logprobs_val.append(values[i][:k])
+                output_top_logprobs_idx.append(indices[i][:k])
+            return None, None, output_top_logprobs_val, output_top_logprobs_idx
         else:
-            input_top_logprobs, output_top_logprobs = [], []
+            input_top_logprobs_val, input_top_logprobs_idx = [], []
+            output_top_logprobs_val, output_top_logprobs_idx = [], []

             pt = 0
             for k, pruned_len in zip(

@@ -138,27 +143,36 @@ class LogitsProcessor(nn.Module):
                 logits_metadata.extend_logprob_pruned_lens_cpu,
             ):
                 if pruned_len <= 0:
-                    input_top_logprobs.append([])
-                    output_top_logprobs.append([])
+                    input_top_logprobs_val.append([])
+                    input_top_logprobs_idx.append([])
+                    output_top_logprobs_val.append([])
+                    output_top_logprobs_idx.append([])
                     continue

-                input_top_logprobs.append(
-                    [
-                        list(zip(values[pt + j][:k], indices[pt + j][:k]))
-                        for j in range(pruned_len - 1)
-                    ]
+                input_top_logprobs_val.append(
+                    [values[pt + j][:k] for j in range(pruned_len - 1)]
+                )
+                input_top_logprobs_idx.append(
+                    [indices[pt + j][:k] for j in range(pruned_len - 1)]
                 )
-                output_top_logprobs.append(
+                output_top_logprobs_val.append(
                     list(
-                        zip(
-                            values[pt + pruned_len - 1][:k],
-                            indices[pt + pruned_len - 1][:k],
-                        )
+                        values[pt + pruned_len - 1][:k],
+                    )
+                )
+                output_top_logprobs_idx.append(
+                    list(
+                        indices[pt + pruned_len - 1][:k],
                     )
                 )
                 pt += pruned_len

-            return input_top_logprobs, output_top_logprobs
+            return (
+                input_top_logprobs_val,
+                input_top_logprobs_idx,
+                output_top_logprobs_val,
+                output_top_logprobs_idx,
+            )

     def forward(
         self,

@@ -193,29 +207,22 @@ class LogitsProcessor(nn.Module):
         if not logits_metadata.return_logprob:
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
-                next_token_logprobs=None,
-                normalized_prompt_logprobs=None,
-                input_token_logprobs=None,
-                input_top_logprobs=None,
-                output_top_logprobs=None,
             )
         else:
             last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)

             if logits_metadata.forward_mode.is_decode():
                 if logits_metadata.return_top_logprob:
-                    output_top_logprobs = self.get_top_logprobs(
-                        last_logprobs, logits_metadata
-                    )[1]
+                    output_top_logprobs_val, output_top_logprobs_idx = (
+                        self.get_top_logprobs(last_logprobs, logits_metadata)[2:4]
+                    )
                 else:
-                    output_top_logprobs = None
+                    output_top_logprobs_val = output_top_logprobs_idx = None

                 return LogitsProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=last_logprobs,
-                    normalized_prompt_logprobs=None,
-                    input_token_logprobs=None,
-                    input_top_logprobs=None,
-                    output_top_logprobs=output_top_logprobs,
+                    output_top_logprobs_val=output_top_logprobs_val,
+                    output_top_logprobs_idx=output_top_logprobs_idx,
                 )
             else:
                 # Slice the requested tokens to compute logprob

@@ -246,11 +253,16 @@ class LogitsProcessor(nn.Module):
                 # Get the logprob of top-k tokens
                 if logits_metadata.return_top_logprob:
-                    input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
-                        all_logprobs, logits_metadata
-                    )
+                    (
+                        input_top_logprobs_val,
+                        input_top_logprobs_idx,
+                        output_top_logprobs_val,
+                        output_top_logprobs_idx,
+                    ) = self.get_top_logprobs(all_logprobs, logits_metadata)
                 else:
-                    input_top_logprobs = output_top_logprobs = None
+                    input_top_logprobs_val = input_top_logprobs_idx = (
+                        output_top_logprobs_val
+                    ) = output_top_logprobs_idx = None

                 # Compute the normalized logprobs for the requested tokens.
                 # Note that we pad a zero at the end for easy batching.

@@ -273,8 +285,10 @@ class LogitsProcessor(nn.Module):
                     next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=normalized_prompt_logprobs,
                     input_token_logprobs=input_token_logprobs,
-                    input_top_logprobs=input_top_logprobs,
-                    output_top_logprobs=output_top_logprobs,
+                    input_top_logprobs_val=input_top_logprobs_val,
+                    input_top_logprobs_idx=input_top_logprobs_idx,
+                    output_top_logprobs_val=output_top_logprobs_val,
+                    output_top_logprobs_idx=output_top_logprobs_idx,
                 )

     def _get_logits(
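For reference, a toy stand-in (not sglang's implementation) that mimics the return shape get_top_logprobs now produces, and why decode-mode callers such as forward() and CudaGraphRunner slice elements [2:4] of the result:

import torch

def toy_get_top_logprobs(logprobs: torch.Tensor, top_k: int):
    # Toy sketch only: returns (input_val, input_idx, output_val, output_idx).
    # In decode mode the input-side slots stay None, so callers take [2:4].
    values, indices = logprobs.topk(top_k, dim=-1)
    output_val = [row.tolist() for row in values]
    output_idx = [row.tolist() for row in indices]
    return None, None, output_val, output_idx

logprobs = torch.log_softmax(torch.randn(2, 8), dim=-1)  # 2 sequences, toy vocab of 8
_, _, val, idx = toy_get_top_logprobs(logprobs, top_k=3)
assert len(val) == len(idx) == 2 and len(val[0]) == 3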
python/sglang/srt/managers/detokenizer_manager.py

@@ -17,7 +17,7 @@ import dataclasses
 import logging
 import signal
 from collections import OrderedDict
-from typing import List, Union
+from typing import Dict, List, Union

 import psutil
 import setproctitle

@@ -76,17 +76,25 @@ class DetokenizerManager:
         self.decode_status = LimitedCapacityDict()

-    def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim):
-        if no_stop_trim:
+    def trim_matched_stop(
+        self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
+    ):
+        if no_stop_trim or not finished_reason:
+            return output
+
+        matched = finished_reason.get("matched", None)
+        if not matched:
             return output

-        # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
-        if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
-            pos = output.find(finished_reason.matched)
+        # TODO(lmzheng): handle the case where multiple stop strs are hit
+
+        # Trim stop str.
+        if isinstance(matched, str) and isinstance(output, str):
+            pos = output.find(matched)
             return output[:pos] if pos != -1 else output
-        if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(
-            output, list
-        ):
+
+        # Trim stop token.
+        if isinstance(matched, int) and isinstance(output, list):
             assert len(output) > 0
             return output[:-1]
         return output

@@ -125,9 +133,9 @@ class DetokenizerManager:
                 s.decode_ids = recv_obj.decode_ids[i]

             read_ids.append(
-                self.trim_eos(
+                self.trim_matched_stop(
                     s.decode_ids[s.surr_offset :],
-                    recv_obj.finished_reason[i],
+                    recv_obj.finished_reasons[i],
                     recv_obj.no_stop_trim[i],
                 )
             )

@@ -150,7 +158,7 @@ class DetokenizerManager:
         for i in range(bs):
             s = self.decode_status[recv_obj.rids[i]]
             new_text = read_texts[i][len(surr_texts[i]) :]
-            if recv_obj.finished_reason[i] is None:
+            if recv_obj.finished_reasons[i] is None:
                 # Streaming chunk: update the decode status
                 if len(new_text) > 0 and not new_text.endswith("�"):
                     s.decoded_text = s.decoded_text + new_text

@@ -161,9 +169,9 @@ class DetokenizerManager:
                     new_text = find_printable_text(new_text)

             output_strs.append(
-                self.trim_eos(
+                self.trim_matched_stop(
                     s.decoded_text + new_text,
-                    recv_obj.finished_reason[i],
+                    recv_obj.finished_reasons[i],
                     recv_obj.no_stop_trim[i],
                 )
             )

@@ -171,9 +179,20 @@ class DetokenizerManager:
         self.send_to_tokenizer.send_pyobj(
             BatchStrOut(
                 rids=recv_obj.rids,
+                finished_reasons=recv_obj.finished_reasons,
                 output_strs=output_strs,
-                meta_info=recv_obj.meta_info,
-                finished_reason=recv_obj.finished_reason,
+                prompt_tokens=recv_obj.prompt_tokens,
+                completion_tokens=recv_obj.completion_tokens,
+                cached_tokens=recv_obj.cached_tokens,
+                input_token_logprobs_val=recv_obj.input_token_logprobs_val,
+                input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
+                output_token_logprobs_val=recv_obj.output_token_logprobs_val,
+                output_token_logprobs_idx=recv_obj.output_token_logprobs_idx,
+                input_top_logprobs_val=recv_obj.input_top_logprobs_val,
+                input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
+                output_top_logprobs_val=recv_obj.output_top_logprobs_val,
+                output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
+                normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
             )
         )
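A standalone sketch of the trimming rule trim_matched_stop implements after this change: the detokenizer now receives finish reasons as plain dicts (the to_json() form) and branches on the type of their "matched" entry instead of checking FINISH_MATCHED_STR / FINISH_MATCHED_TOKEN classes. The helper below is illustrative only, not the sglang method itself:

from typing import Dict, List, Optional, Union

def trim_matched_stop(
    output: Union[str, List[int]],
    finished_reason: Optional[Dict],
    no_stop_trim: bool,
):
    # Illustrative copy of the new rule: a string "matched" means a stop string,
    # an int "matched" means a stop token id.
    if no_stop_trim or not finished_reason:
        return output
    matched = finished_reason.get("matched", None)
    if not matched:
        return output
    if isinstance(matched, str) and isinstance(output, str):
        pos = output.find(matched)
        return output[:pos] if pos != -1 else output
    if isinstance(matched, int) and isinstance(output, list):
        assert len(output) > 0
        return output[:-1]
    return output

assert trim_matched_stop("Paris.</s>", {"matched": "</s>"}, False) == "Paris."
assert trim_matched_stop([11, 42, 2], {"matched": 2}, False) == [11, 42]
assert trim_matched_stop("Paris.</s>", {"matched": "</s>"}, True) == "Paris.</s>"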
python/sglang/srt/managers/io_struct.py

@@ -308,6 +308,9 @@ class TokenizedEmbeddingReqInput:
 class BatchTokenIDOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[BaseFinishReason]
     # For incremental decoding
     # The version id to sync decode status with in detokenizer_manager
     vids: List[int]
     decoded_texts: List[str]

@@ -315,35 +318,61 @@ class BatchTokenIDOut:
     read_offsets: List[int]
     # Only used when `--skip-tokenizer-init`
     output_ids: Optional[List[int]]
+    # Detokenization configs
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
-    meta_info: List[Dict]
-    finished_reason: List[BaseFinishReason]
     no_stop_trim: List[bool]
+    # Token counts
+    prompt_tokens: List[int]
+    completion_tokens: List[int]
+    cached_tokens: List[int]
+    # Logprobs
+    input_token_logprobs_val: List[float]
+    input_token_logprobs_idx: List[int]
+    output_token_logprobs_val: List[float]
+    output_token_logprobs_idx: List[int]
+    input_top_logprobs_val: List[List]
+    input_top_logprobs_idx: List[List]
+    output_top_logprobs_val: List[List]
+    output_top_logprobs_idx: List[List]
+    normalized_prompt_logprob: List[float]


 @dataclass
 class BatchStrOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[dict]
     # The output decoded strings
     output_strs: List[str]
-    # The meta info
-    meta_info: List[Dict]
-    # The finish reason
-    finished_reason: List[BaseFinishReason]
+    # Token counts
+    prompt_tokens: List[int]
+    completion_tokens: List[int]
+    cached_tokens: List[int]
+    # Logprobs
+    input_token_logprobs_val: List[float]
+    input_token_logprobs_idx: List[int]
+    output_token_logprobs_val: List[float]
+    output_token_logprobs_idx: List[int]
+    input_top_logprobs_val: List[List]
+    input_top_logprobs_idx: List[List]
+    output_top_logprobs_val: List[List]
+    output_top_logprobs_idx: List[List]
+    normalized_prompt_logprob: List[float]


 @dataclass
 class BatchEmbeddingOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[BaseFinishReason]
     # The output embedding
     embeddings: List[List[float]]
-    # The meta info
-    meta_info: List[Dict]
-    # The finish reason
-    finished_reason: List[BaseFinishReason]
+    # Token counts
+    prompt_tokens: List[int]


 @dataclass
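To make the flattened layout concrete, here is a toy example (hypothetical values, trimmed field set) of how one request is reassembled from a BatchStrOut-style batch by indexing every per-request list at the same position, which is essentially what TokenizerManager does when it builds meta_info:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ToyBatchStrOut:  # trimmed stand-in for sglang's BatchStrOut, not the real class
    rids: List[str]
    finished_reasons: List[Optional[dict]]
    output_strs: List[str]
    prompt_tokens: List[int]
    completion_tokens: List[int]

batch = ToyBatchStrOut(
    rids=["req-0", "req-1"],
    finished_reasons=[None, {"type": "stop", "matched": 2}],
    output_strs=["The capital", "Paris."],
    prompt_tokens=[12, 9],
    completion_tokens=[2, 3],
)

i = 1  # reassemble request i by indexing each parallel list
record = {
    "id": batch.rids[i],
    "finish_reason": batch.finished_reasons[i],
    "text": batch.output_strs[i],
    "prompt_tokens": batch.prompt_tokens[i],
    "completion_tokens": batch.completion_tokens[i],
}
assert record["text"] == "Paris."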
python/sglang/srt/managers/schedule_batch.py

@@ -200,6 +200,9 @@ class Req:
         origin_input_text: str,
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
+        return_logprob: bool = False,
+        top_logprobs_num: int = 0,
+        stream: bool = False,
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,

@@ -217,10 +220,11 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
         self.session_id = session_id
+        self.input_embeds = input_embeds

+        # Sampling info
         self.sampling_params = sampling_params
         self.lora_path = lora_path
-        self.input_embeds = input_embeds

         # Memory pool info
         self.req_pool_idx = None

@@ -228,8 +232,8 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
-        self.stream = False
         self.to_abort = False
+        self.stream = stream

         # For incremental decoding
         # ----- | --------- read_ids -------|

@@ -241,13 +245,9 @@ class Req:
         # 2: read_offset
         # 3: last token
         self.vid = 0  # version id to sync decode status with in detokenizer_manager
+        self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None

-        # The number of decoded tokens for token usage report. Note that
-        # this does not include the jump forward tokens.
-        self.completion_tokens_wo_jump_forward = 0
-        self.decoded_text = ""
-
         # For multimodal inputs
         self.image_inputs: Optional[ImageInputs] = None

@@ -256,22 +256,34 @@ class Req:
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
+
+        # Chunked prefill
         self.is_being_chunked = 0

         # For retraction
         self.is_retracted = False

         # Logprobs (arguments)
-        self.return_logprob = False
+        self.return_logprob = return_logprob
         self.logprob_start_len = 0
-        self.top_logprobs_num = 0
+        self.top_logprobs_num = top_logprobs_num

         # Logprobs (return value)
         self.normalized_prompt_logprob = None
-        self.input_token_logprobs = None
-        self.input_top_logprobs = None
-        self.output_token_logprobs = []
-        self.output_top_logprobs = []
+        self.input_token_logprobs_val = None
+        self.input_token_logprobs_idx = None
+        self.input_top_logprobs_val = None
+        self.input_top_logprobs_idx = None
+
+        if return_logprob:
+            self.output_token_logprobs_val = []
+            self.output_token_logprobs_idx = []
+            self.output_top_logprobs_val = []
+            self.output_top_logprobs_idx = []
+        else:
+            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
+                self.output_top_logprobs_val
+            ) = self.output_top_logprobs_idx = None

         # Logprobs (internal values)
         # The tokens is prefilled but need to be considered as decode tokens

@@ -295,8 +307,8 @@ class Req:
         else:
             self.image_inputs.merge(image_inputs)

-    # whether request reached finished condition
     def finished(self) -> bool:
+        # Whether request reached finished condition
         return self.finished_reason is not None

     def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):

@@ -454,8 +466,10 @@ class Req:
                 k = k + 1
             else:
                 break

-        self.output_token_logprobs = self.output_token_logprobs[:k]
-        self.output_top_logprobs = self.output_top_logprobs[:k]
+        self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
+        self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
+        self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
+        self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
         self.logprob_start_len = prompt_tokens + k
         self.last_update_decode_tokens = len(self.output_ids) - k
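A toy reduction of the constructor-level change above (ToyReq is made up, not sglang's Req): the logprob-related flags now arrive as __init__ arguments, and the output logprob buffers are only allocated when return_logprob is set rather than always starting as empty lists:

class ToyReq:
    def __init__(self, return_logprob: bool = False, top_logprobs_num: int = 0):
        self.return_logprob = return_logprob
        self.top_logprobs_num = top_logprobs_num
        if return_logprob:
            # Buffers only exist for requests that actually asked for logprobs.
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = None

req = ToyReq(return_logprob=True, top_logprobs_num=5)
assert req.output_token_logprobs_val == []
assert ToyReq().output_token_logprobs_val is None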
python/sglang/srt/managers/scheduler.py

@@ -515,6 +515,9 @@ class Scheduler:
             recv_req.input_text,
             recv_req.input_ids,
             recv_req.sampling_params,
+            return_logprob=recv_req.return_logprob,
+            top_logprobs_num=recv_req.top_logprobs_num,
+            stream=recv_req.stream,
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
         )

@@ -558,9 +561,6 @@ class Scheduler:
             return

         # Copy more attributes
-        req.return_logprob = recv_req.return_logprob
-        req.top_logprobs_num = recv_req.top_logprobs_num
-        req.stream = recv_req.stream
         req.logprob_start_len = recv_req.logprob_start_len

         if req.logprob_start_len == -1:

@@ -982,7 +982,6 @@ class Scheduler:
                 continue

             if req.is_being_chunked <= 0:
-                req.completion_tokens_wo_jump_forward += 1
                 req.output_ids.append(next_token_id)
                 req.check_finished()

@@ -1035,7 +1034,7 @@ class Scheduler:
                 # being chunked reqs' prefill is not finished
                 req.is_being_chunked -= 1

-        self.stream_output(batch.reqs, skip_stream_req)
+        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids, bid = result

@@ -1065,7 +1064,6 @@ class Scheduler:
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue

-            req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_token_id)
             req.check_finished()

@@ -1073,11 +1071,15 @@ class Scheduler:
                 self.tree_cache.cache_finished_req(req)

             if req.return_logprob:
-                req.output_token_logprobs.append(
-                    (next_token_logprobs[i], next_token_id)
-                )
+                req.output_token_logprobs_val.append(next_token_logprobs[i])
+                req.output_token_logprobs_idx.append(next_token_id)
                 if req.top_logprobs_num > 0:
-                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
+                    req.output_top_logprobs_val.append(
+                        logits_output.output_top_logprobs_val[i]
+                    )
+                    req.output_top_logprobs_idx.append(
+                        logits_output.output_top_logprobs_idx[i]
+                    )

             if req.grammar is not None:
                 req.grammar.accept_token(next_token_id)

@@ -1088,7 +1090,7 @@ class Scheduler:
             self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

-        self.stream_output(batch.reqs)
+        self.stream_output(batch.reqs, batch.return_logprob)

         self.token_to_kv_pool.free_group_end()

@@ -1108,9 +1110,8 @@ class Scheduler:
         output: LogitsProcessorOutput,
     ):
         """Attach logprobs to the return values."""
-        req.output_token_logprobs.append(
-            (output.next_token_logprobs[i], next_token_ids[i])
-        )
+        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
+        req.output_token_logprobs_idx.append(next_token_ids[i])

         # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
         num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

@@ -1118,38 +1119,36 @@ class Scheduler:
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

-        if req.input_token_logprobs is None:
-            input_token_logprobs = output.input_token_logprobs[
+        if req.input_token_logprobs_val is None:
+            input_token_logprobs_val = output.input_token_logprobs[
                 pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
             ]
-            input_token_ids = req.fill_ids[
+            input_token_logprobs_idx = req.fill_ids[
                 len(req.fill_ids)
                 - num_input_logprobs
                 + 1 : len(req.fill_ids)
                 - req.last_update_decode_tokens
             ]
             # Clip the padded hash values from image tokens.
             # Otherwise, it will lead to detokenization errors.
-            input_token_ids = [
-                x if x < self.model_config.vocab_size - 1 else 0
-                for x in input_token_ids
+            input_token_logprobs_idx = [
+                x if x < self.model_config.vocab_size - 1 else 0
+                for x in input_token_logprobs_idx
             ]

-            req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))
-
             if (
                 req.logprob_start_len == 0
             ):  # The first token does not have logprob, pad it.
-                req.input_token_logprobs = [
-                    (None, req.fill_ids[0])
-                ] + req.input_token_logprobs
+                input_token_logprobs_val = [None] + input_token_logprobs_val
+                input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx
+
+            req.input_token_logprobs_val = input_token_logprobs_val
+            req.input_token_logprobs_idx = input_token_logprobs_idx

         if req.last_update_decode_tokens != 0:
             # Some decode tokens are re-computed in an extend batch
-            req.output_token_logprobs.extend(
-                list(
-                    zip(
-                        output.input_token_logprobs[
-                            pt
-                            + num_input_logprobs
+            req.output_token_logprobs_val.extend(
+                output.input_token_logprobs[
+                    pt
+                    + num_input_logprobs

@@ -1158,132 +1157,156 @@ class Scheduler:
-                            - 1
-                            - req.last_update_decode_tokens : pt
-                            + num_input_logprobs
-                            - 1
-                        ],
-                        req.fill_ids[
-                            len(req.fill_ids)
-                            - req.last_update_decode_tokens : len(req.fill_ids)
-                        ],
-                    )
-                )
-            )
+                    - 1
+                    - req.last_update_decode_tokens : pt
+                    + num_input_logprobs
+                    - 1
+                ],
+            )
+            req.output_token_logprobs_idx.extend(
+                req.fill_ids[
+                    len(req.fill_ids)
+                    - req.last_update_decode_tokens : len(req.fill_ids)
+                ],
+            )

         if req.top_logprobs_num > 0:
-            if req.input_top_logprobs is None:
-                req.input_top_logprobs = output.input_top_logprobs[i]
+            if req.input_top_logprobs_val is None:
+                req.input_top_logprobs_val = output.input_top_logprobs_val[i]
+                req.input_top_logprobs_idx = output.input_top_logprobs_idx[i]
                 if req.logprob_start_len == 0:
-                    req.input_top_logprobs = [None] + req.input_top_logprobs
+                    req.input_top_logprobs_val = [None] + req.input_top_logprobs_val
+                    req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx

             if req.last_update_decode_tokens != 0:
-                req.output_top_logprobs.extend(
-                    output.input_top_logprobs[i][-req.last_update_decode_tokens :]
+                req.output_top_logprobs_val.extend(
+                    output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
                 )
+                req.output_top_logprobs_idx.extend(
+                    output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
+                )

-            req.output_top_logprobs.append(output.output_top_logprobs[i])
+            req.output_top_logprobs_val.append(output.output_top_logprobs_val[i])
+            req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i])

         return num_input_logprobs

-    def stream_output(self, reqs: List[Req], skip_req: Optional[Req] = None):
+    def stream_output(
+        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+    ):
         """Stream the output to detokenizer."""
-        output_rids = []
-        output_meta_info: List[dict] = []
-        output_finished_reason: List[BaseFinishReason] = []
+        rids = []
+        finished_reasons: List[BaseFinishReason] = []
+
         if self.is_generation:
-            output_vids = []
+            vids = []
             decoded_texts = []
-            output_read_ids = []
-            output_read_offsets = []
+            decode_ids_list = []
+            read_offsets = []
             output_ids = []
-            output_skip_special_tokens = []
-            output_spaces_between_special_tokens = []
-            output_no_stop_trim = []
-        else:  # embedding or reward model
-            output_embeddings = []
-
-        is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
+            skip_special_tokens = []
+            spaces_between_special_tokens = []
+            no_stop_trim = []
+            prompt_tokens = []
+            completion_tokens = []
+            cached_tokens = []
+
+            if return_logprob:
+                input_token_logprobs_val = []
+                input_token_logprobs_idx = []
+                output_token_logprobs_val = []
+                output_token_logprobs_idx = []
+                input_top_logprobs_val = []
+                input_top_logprobs_idx = []
+                output_top_logprobs_val = []
+                output_top_logprobs_idx = []
+                normalized_prompt_logprob = []
+            else:
+                input_token_logprobs_val = input_token_logprobs_idx = (
+                    output_token_logprobs_val
+                ) = output_token_logprobs_idx = input_top_logprobs_val = (
+                    input_top_logprobs_idx
+                ) = output_top_logprobs_val = output_top_logprobs_idx = (
+                    normalized_prompt_logprob
+                ) = None

-        for req in reqs:
-            if req is skip_req:
-                continue
+            for req in reqs:
+                if req is skip_req:
+                    continue

-            # TODO(lianmin): revisit this for overlap + retract + stream
-            if req.finished() or (
-                req.stream and (is_stream_iter or len(req.output_ids) == 1)
-            ):
-                output_rids.append(req.rid)
-                output_finished_reason.append(req.finished_reason)
-                if self.is_generation:
-                    output_vids.append(req.vid)
+                # TODO(lianmin): revisit this for overlap + retract + stream
+                if (
+                    req.finished()
+                    # If stream, follow the given stream_interval
+                    or (req.stream and len(req.output_ids) % self.stream_interval == 0)
+                    # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
+                    or (not req.stream and len(req.output_ids) % 50 == 0)
+                ):
+                    rids.append(req.rid)
+                    finished_reasons.append(
+                        req.finished_reason.to_json() if req.finished_reason else None
+                    )
+                    vids.append(req.vid)
                     decoded_texts.append(req.decoded_text)
-                    read_ids, read_offset = req.init_incremental_detokenize()
-                    output_read_ids.append(read_ids)
-                    output_read_offsets.append(read_offset)
+                    decode_ids, read_offset = req.init_incremental_detokenize()
+                    decode_ids_list.append(decode_ids)
+                    read_offsets.append(read_offset)
                     if self.skip_tokenizer_init:
                         output_ids.append(req.output_ids)
-                    output_skip_special_tokens.append(
-                        req.sampling_params.skip_special_tokens
-                    )
-                    output_spaces_between_special_tokens.append(
+                    skip_special_tokens.append(req.sampling_params.skip_special_tokens)
+                    spaces_between_special_tokens.append(
                         req.sampling_params.spaces_between_special_tokens
                     )
-                    output_no_stop_trim.append(req.sampling_params.no_stop_trim)
-
-                    meta_info = {
-                        "prompt_tokens": len(req.origin_input_ids),
-                        "completion_tokens": len(req.output_ids),
-                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                        "cached_tokens": req.cached_tokens,
-                        "finish_reason": (
-                            req.finished_reason.to_json()
-                            if req.finished_reason is not None
-                            else None
-                        ),
-                    }
-                    if req.return_logprob:
-                        (
-                            meta_info["input_token_logprobs"],
-                            meta_info["output_token_logprobs"],
-                            meta_info["input_top_logprobs"],
-                            meta_info["output_top_logprobs"],
-                            meta_info["normalized_prompt_logprob"],
-                        ) = (
-                            req.input_token_logprobs,
-                            req.output_token_logprobs,
-                            req.input_top_logprobs,
-                            req.output_top_logprobs,
-                            req.normalized_prompt_logprob,
-                        )
-                    output_meta_info.append(meta_info)
-                else:
-                    # embedding or reward model
-                    output_embeddings.append(req.embedding)
-                    meta_info = {
-                        "prompt_tokens": len(req.origin_input_ids),
-                    }
-                    output_meta_info.append(meta_info)
+                    no_stop_trim.append(req.sampling_params.no_stop_trim)
+
+                    prompt_tokens.append(len(req.origin_input_ids))
+                    completion_tokens.append(len(req.output_ids))
+                    cached_tokens.append(req.cached_tokens)
+
+                    if return_logprob:
+                        input_token_logprobs_val.append(req.input_token_logprobs_val)
+                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
+                        output_token_logprobs_val.append(req.output_token_logprobs_val)
+                        output_token_logprobs_idx.append(req.output_token_logprobs_idx)
+                        input_top_logprobs_val.append(req.input_top_logprobs_val)
+                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
+                        output_top_logprobs_val.append(req.output_top_logprobs_val)
+                        output_top_logprobs_idx.append(req.output_top_logprobs_idx)
+                        normalized_prompt_logprob.append(req.normalized_prompt_logprob)

-        # Send to detokenizer
-        if output_rids:
-            if self.is_generation:
+            # Send to detokenizer
+            if rids:
                 self.send_to_detokenizer.send_pyobj(
                     BatchTokenIDOut(
-                        output_rids,
-                        output_vids,
+                        rids,
+                        finished_reasons,
+                        vids,
                         decoded_texts,
-                        output_read_ids,
-                        output_read_offsets,
+                        decode_ids_list,
+                        read_offsets,
                         output_ids,
-                        output_skip_special_tokens,
-                        output_spaces_between_special_tokens,
-                        output_meta_info,
-                        output_finished_reason,
-                        output_no_stop_trim,
+                        skip_special_tokens,
+                        spaces_between_special_tokens,
+                        no_stop_trim,
+                        prompt_tokens,
+                        completion_tokens,
+                        cached_tokens,
+                        input_token_logprobs_val,
+                        input_token_logprobs_idx,
+                        output_token_logprobs_val,
+                        output_token_logprobs_idx,
+                        input_top_logprobs_val,
+                        input_top_logprobs_idx,
+                        output_top_logprobs_val,
+                        output_top_logprobs_idx,
+                        normalized_prompt_logprob,
                     )
                 )
-            else:
-                # embedding or reward model
-                self.send_to_detokenizer.send_pyobj(
-                    BatchEmbeddingOut(
-                        output_rids,
-                        output_embeddings,
-                        output_meta_info,
-                        output_finished_reason,
-                    )
-                )
+        else:  # embedding or reward model
+            embeddings = []
+            prompt_tokens = []
+            for req in reqs:
+                assert req.finished()
+                rids.append(req.rid)
+                finished_reasons.append(req.finished_reason.to_json())
+                embeddings.append(req.embedding)
+                prompt_tokens.append(len(req.origin_input_ids))
+            self.send_to_detokenizer.send_pyobj(
+                BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
+            )

     def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
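The new per-request emission condition in stream_output can be read in isolation as follows. This is a simplified sketch; the constant 50 and the stream_interval modulus come from the diff above, while the helper name and argument list are made up for illustration:

def should_emit(
    finished: bool, stream: bool, num_output_ids: int, stream_interval: int
) -> bool:
    # Sketch of the trigger logic: always emit finished requests, emit streaming
    # requests every stream_interval tokens, and still flush non-streaming
    # requests every 50 tokens to benefit from incremental detokenization.
    return (
        finished
        or (stream and num_output_ids % stream_interval == 0)
        or (not stream and num_output_ids % 50 == 0)
    )

assert should_emit(False, True, 8, stream_interval=8)
assert not should_emit(False, True, 9, stream_interval=8)
assert should_emit(False, False, 100, stream_interval=8)
assert should_emit(True, False, 3, stream_interval=8)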
python/sglang/srt/managers/tokenizer_manager.py

@@ -22,7 +22,7 @@ import signal
 import sys
 import time
 import uuid
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import fastapi
 import uvloop

@@ -76,6 +76,7 @@ class ReqState:
     out_list: List
     finished: bool
     event: asyncio.Event
+    obj: Any

     # For metrics
     created_time: float

@@ -283,7 +284,7 @@ class TokenizerManager:
     ):
         """Wait for the response of one request."""
         event = asyncio.Event()
-        state = ReqState([], False, event, created_time=created_time)
+        state = ReqState([], False, event, obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state

         while True:

@@ -295,14 +296,6 @@ class TokenizerManager:
                     raise ValueError(f"Abort request {obj.rid}")
                 continue

-            if isinstance(obj, GenerateReqInput):
-                out = self.convert_logprob_style(
-                    state.out_list[-1],
-                    obj.return_logprob,
-                    obj.top_logprobs_num,
-                    obj.return_text_in_logprobs,
-                )
-            else:  # isinstance(obj, (EmbeddingReqInput,))
-                out = state.out_list[-1]
+            out = state.out_list[-1]

             state.out_list = []

@@ -315,7 +308,13 @@ class TokenizerManager:
                 break

             state.event.clear()
-            yield out
+
+            if obj.stream:
+                yield out
+            else:
+                if request is not None and await request.is_disconnected():
+                    self.abort_request(obj.rid)
+                    raise ValueError(f"Abort request {obj.rid}")

     async def _handle_batch_request(
         self,

@@ -609,29 +608,55 @@ class TokenizerManager:
                 if state is None:
                     continue

-                recv_obj.meta_info[i]["id"] = rid
+                meta_info = {
+                    "id": rid,
+                    "finish_reason": recv_obj.finished_reasons[i],
+                    "prompt_tokens": recv_obj.prompt_tokens[i],
+                }
+
+                if getattr(state.obj, "return_logprob", False):
+                    self.convert_logprob_style(
+                        meta_info,
+                        state.obj.top_logprobs_num,
+                        state.obj.return_text_in_logprobs,
+                        recv_obj,
+                        i,
+                    )
+
                 if isinstance(recv_obj, BatchStrOut):
                     out_dict = {
                         "text": recv_obj.output_strs[i],
-                        "meta_info": recv_obj.meta_info[i],
+                        "meta_info": {
+                            **meta_info,
+                            "completion_tokens": recv_obj.completion_tokens[i],
+                            "cached_tokens": recv_obj.cached_tokens[i],
+                        },
                     }
                 elif isinstance(recv_obj, BatchTokenIDOut):
                     out_dict = {
                         "token_ids": recv_obj.output_ids[i],
-                        "meta_info": recv_obj.meta_info[i],
+                        "meta_info": {
+                            **meta_info,
+                            "completion_tokens": recv_obj.completion_tokens[i],
+                            "cached_tokens": recv_obj.cached_tokens[i],
+                        },
                     }
                 else:
                     assert isinstance(recv_obj, BatchEmbeddingOut)
                     out_dict = {
                         "embedding": recv_obj.embeddings[i],
-                        "meta_info": recv_obj.meta_info[i],
+                        "meta_info": meta_info,
                     }
                 state.out_list.append(out_dict)
-                state.finished = recv_obj.finished_reason[i] is not None
+                state.finished = recv_obj.finished_reasons[i] is not None
                 state.event.set()

                 if self.enable_metrics:
-                    completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
+                    completion_tokens = (
+                        recv_obj.completion_tokens[i]
+                        if recv_obj.completion_tokens
+                        else 0
+                    )

                     if state.first_token_time is None:
                         state.first_token_time = time.time()

@@ -647,7 +672,7 @@ class TokenizerManager:
                     if state.finished:
                         self.metrics_collector.inc_prompt_tokens(
-                            recv_obj.meta_info[i]["prompt_tokens"]
+                            recv_obj.prompt_tokens[i]
                         )
                         self.metrics_collector.inc_generation_tokens(
                             completion_tokens

@@ -696,57 +721,73 @@ class TokenizerManager:
     def convert_logprob_style(
         self,
-        ret: dict,
-        return_logprob: bool,
+        meta_info: dict,
         top_logprobs_num: int,
         return_text_in_logprobs: bool,
+        recv_obj: BatchStrOut,
+        recv_obj_index: int,
     ):
-        if return_logprob:
-            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
-                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
-            )
-            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
-                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
+        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
+            recv_obj.input_token_logprobs_val[recv_obj_index],
+            recv_obj.input_token_logprobs_idx[recv_obj_index],
+            return_text_in_logprobs,
+        )
+        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
+            recv_obj.output_token_logprobs_val[recv_obj_index],
+            recv_obj.output_token_logprobs_idx[recv_obj_index],
+            return_text_in_logprobs,
+        )
+        meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
+            recv_obj_index
+        ]
+
+        if top_logprobs_num > 0:
+            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                recv_obj.input_top_logprobs_val[recv_obj_index],
+                recv_obj.input_top_logprobs_idx[recv_obj_index],
+                return_text_in_logprobs,
             )
-            if top_logprobs_num > 0:
-                ret["meta_info"]["input_top_logprobs"] = (
-                    self.detokenize_top_logprobs_tokens(
-                        ret["meta_info"]["input_top_logprobs"],
-                        return_text_in_logprobs,
-                    )
-                )
-                ret["meta_info"]["output_top_logprobs"] = (
-                    self.detokenize_top_logprobs_tokens(
-                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
-                    )
-                )
-        return ret
+            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                recv_obj.output_top_logprobs_val[recv_obj_index],
+                recv_obj.output_top_logprobs_idx[recv_obj_index],
+                return_text_in_logprobs,
+            )

     def detokenize_logprob_tokens(
-        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
+        self,
+        token_logprobs_val: List[float],
+        token_logprobs_idx: List[int],
+        decode_to_text: bool,
     ):
         # TODO(lianmin): This should run on DetokenizerManager
         if not decode_to_text:
-            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
-
-        assert self.tokenizer is not None
-        token_ids = [tid for _, tid in token_logprobs]
-        token_texts = self.tokenizer.batch_decode(token_ids)
-        return [
-            (logprob, token_id, token_text)
-            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
-        ]
+            return [
+                (logprob, token_id, None)
+                for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
+            ]
+        else:
+            assert self.tokenizer is not None
+            token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
+            return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))

-    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
+    def detokenize_top_logprobs_tokens(
+        self,
+        token_logprobs_val: List[float],
+        token_logprobs_idx: List[int],
+        decode_to_text: bool,
+    ):
         # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
         # We should batch all top-k tokens in all positions.
-        for i, token_top_logprobs in enumerate(top_logprobs):
-            if token_top_logprobs:
-                top_logprobs[i] = self.detokenize_logprob_tokens(
-                    token_top_logprobs, decode_to_text
-                )
-        return top_logprobs
+        ret = []
+        for i in range(len(token_logprobs_val)):
+            if token_logprobs_val[i]:
+                ret.append(
+                    self.detokenize_logprob_tokens(
+                        token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
+                    )
+                )
+            else:
+                ret.append(None)
+        return ret


 class SignalHandler:
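A self-contained sketch of the detokenization step under the new layout. The batch_decode callable here is a stand-in for self.tokenizer.batch_decode (an assumption for the sake of a runnable example); the rest mirrors the new detokenize_logprob_tokens above:

from typing import List

def detokenize_logprob_tokens(
    token_logprobs_val: List[float],
    token_logprobs_idx: List[int],
    decode_to_text: bool,
    batch_decode=lambda ids: [f"<tok{i}>" for i in ids],  # stand-in tokenizer
):
    # Values and ids arrive as two parallel lists and are only zipped back into
    # (logprob, token_id, text) triples at the API boundary.
    if not decode_to_text:
        return [
            (logprob, token_id, None)
            for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
        ]
    token_texts = batch_decode(token_logprobs_idx)
    return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))

triples = detokenize_logprob_tokens([-0.1, -2.3], [9906, 13], True)
assert triples == [(-0.1, 9906, "<tok9906>"), (-2.3, 13, "<tok13>")]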
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -400,9 +400,14 @@ class CudaGraphRunner:
                 forward_mode=ForwardMode.DECODE,
                 top_logprobs_nums=forward_batch.top_logprobs_nums,
             )
-            logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
-                next_token_logprobs, logits_metadata
-            )[1]
+            (
+                logits_output.output_top_logprobs_val,
+                logits_output.output_top_logprobs_idx,
+            ) = LogitsProcessor.get_top_logprobs(
+                next_token_logprobs, logits_metadata
+            )[2:4]
         else:
             logits_output = LogitsProcessorOutput(
                 next_token_logits=next_token_logits,
python/sglang/test/test_utils.py

@@ -720,13 +720,13 @@ def run_and_check_memory_leak(
     # Clean up everything
     kill_process_tree(process.pid)
-    kill_process_tree(process.pid)
     stdout.close()
     stderr.close()
     if os.path.exists(STDOUT_FILENAME):
         os.remove(STDOUT_FILENAME)
     if os.path.exists(STDERR_FILENAME):
         os.remove(STDERR_FILENAME)
+    kill_process_tree(process.pid)
     t.join()

@@ -734,7 +734,7 @@ def run_and_check_memory_leak(
     has_leak = False
     has_abort = False
     for line in output_lines:
-        if "The server is fired" in line:
+        if "Uvicorn running" in line:
             has_new_server = True
         if "leak" in line:
             has_leak = True
test/srt/test_json_constrained.py

@@ -95,15 +95,6 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
         self.assertIsInstance(js_obj["name"], str)
         self.assertIsInstance(js_obj["population"], int)

-        # Make sure jump forward is triggered
-        # NOTE: The overlap scheduler does not support jump forward so we only do this test
-        # when --disable-overlap-schedule is set.
-        if self.check_jump_forward:
-            self.assertGreater(
-                ret["meta_info"]["completion_tokens"],
-                ret["meta_info"]["completion_tokens_wo_jump_forward"],
-            )
-
     def test_json_generate(self):
         self.run_decode(json_schema=self.json_schema)