Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
48761171
Unverified
Commit
48761171
authored
Oct 13, 2024
by
Ying Sheng
Committed by
GitHub
Oct 13, 2024
Browse files
[Fix] fix eos trim inconsistency (#1650)
parent
c3f2fc5a
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
77 additions
and
27 deletions
+77
-27
python/sglang/srt/managers/detokenizer_manager.py
python/sglang/srt/managers/detokenizer_manager.py
+31
-10
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+1
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+3
-0
python/sglang/srt/openai_api/adapter.py
python/sglang/srt/openai_api/adapter.py
+32
-17
python/sglang/srt/openai_api/protocol.py
python/sglang/srt/openai_api/protocol.py
+1
-0
python/sglang/srt/sampling/sampling_params.py
python/sglang/srt/sampling/sampling_params.py
+2
-0
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+7
-0
No files found.
python/sglang/srt/managers/detokenizer_manager.py
View file @
48761171
...
...
@@ -18,7 +18,7 @@ limitations under the License.
import
dataclasses
import
logging
from
collections
import
OrderedDict
from
typing
import
List
from
typing
import
List
,
Union
import
zmq
...
...
@@ -29,7 +29,7 @@ from sglang.srt.managers.io_struct import (
BatchTokenIDOut
,
UpdateWeightReqOutput
,
)
from
sglang.srt.managers.schedule_batch
import
FINISH_MATCHED_STR
from
sglang.srt.managers.schedule_batch
import
FINISH_MATCHED_STR
,
FINISH_MATCHED_TOKEN
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.utils
import
configure_logger
,
kill_parent_process
from
sglang.utils
import
find_printable_text
,
get_exception_traceback
...
...
@@ -75,6 +75,21 @@ class DetokenizerManager:
self
.
decode_status
=
LimitedCapacityDict
()
def trim_eos(self, output: Union[str, List[int]], finished_reason, no_eos_trim):
    """Strip the matched stop marker from the tail of a decoded output.

    Args:
        output: Either the decoded text (str) or the raw token id list.
        finished_reason: The finish-reason object attached to the request;
            only FINISH_MATCHED_STR / FINISH_MATCHED_TOKEN trigger trimming.
        no_eos_trim: Per-request flag; when set, the output is returned
            untouched (caller asked to keep the stop marker).

    Returns:
        The output with the stop string / final stop token removed, or the
        input unchanged when nothing matched.
    """
    if no_eos_trim:
        return output

    # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
    if isinstance(output, str) and isinstance(finished_reason, FINISH_MATCHED_STR):
        hit = output.find(finished_reason.matched)
        if hit == -1:
            return output
        return output[:hit]
    if isinstance(output, list) and isinstance(finished_reason, FINISH_MATCHED_TOKEN):
        # The matched stop token is always the last id appended.
        assert len(output) > 0
        return output[:-1]
    return output
def
event_loop
(
self
):
"""The event loop that handles requests"""
...
...
@@ -122,7 +137,13 @@ class DetokenizerManager:
s
=
self
.
decode_status
[
rid
]
s
.
decode_ids
=
recv_obj
.
decode_ids
[
i
]
read_ids
.
append
(
s
.
decode_ids
[
s
.
surr_offset
:])
read_ids
.
append
(
self
.
trim_eos
(
s
.
decode_ids
[
s
.
surr_offset
:],
recv_obj
.
finished_reason
[
i
],
recv_obj
.
no_eos_trim
[
i
],
)
)
surr_ids
.
append
(
s
.
decode_ids
[
s
.
surr_offset
:
s
.
read_offset
])
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
...
...
@@ -152,13 +173,13 @@ class DetokenizerManager:
else
:
new_text
=
find_printable_text
(
new_text
)
output_strs
.
append
(
s
.
decoded_text
+
new_text
)
# Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
if
isinstance
(
recv_obj
.
finished_reason
[
i
],
FINISH_MATCHED_STR
):
pos
=
output_strs
[
i
].
find
(
recv_obj
.
finished_reason
[
i
].
matched
)
if
pos
!=
-
1
:
output_strs
[
i
]
=
output_strs
[
i
][:
pos
]
output_strs
.
append
(
self
.
trim_eos
(
s
.
decoded_text
+
new_text
,
recv_obj
.
finished_reason
[
i
],
recv_obj
.
no_eos_trim
[
i
],
)
)
self
.
send_to_tokenizer
.
send_pyobj
(
BatchStrOut
(
...
...
python/sglang/srt/managers/io_struct.py
View file @
48761171
...
...
@@ -295,6 +295,7 @@ class BatchTokenIDOut:
spaces_between_special_tokens
:
List
[
bool
]
meta_info
:
List
[
Dict
]
finished_reason
:
List
[
BaseFinishReason
]
no_eos_trim
:
List
[
bool
]
@
dataclass
...
...
python/sglang/srt/managers/scheduler.py
View file @
48761171
...
...
@@ -883,6 +883,7 @@ class Scheduler:
output_read_offsets
=
[]
output_skip_special_tokens
=
[]
output_spaces_between_special_tokens
=
[]
output_no_eos_trim
=
[]
else
:
# embedding or reward model
output_embeddings
=
[]
unfinished_indices
=
[]
...
...
@@ -914,6 +915,7 @@ class Scheduler:
output_spaces_between_special_tokens
.
append
(
req
.
sampling_params
.
spaces_between_special_tokens
)
output_no_eos_trim
.
append
(
req
.
sampling_params
.
no_eos_trim
)
meta_info
=
{
"prompt_tokens"
:
len
(
req
.
origin_input_ids
),
...
...
@@ -961,6 +963,7 @@ class Scheduler:
output_spaces_between_special_tokens
,
output_meta_info
,
output_finished_reason
,
output_no_eos_trim
,
)
)
else
:
# embedding or reward model
...
...
python/sglang/srt/openai_api/adapter.py
View file @
48761171
...
...
@@ -493,23 +493,38 @@ def v1_generate_request(
top_logprobs_nums
.
append
(
request
.
logprobs
if
request
.
logprobs
is
not
None
else
0
)
sampling_params_list
.
append
(
{
"temperature"
:
request
.
temperature
,
"max_new_tokens"
:
request
.
max_tokens
,
"min_new_tokens"
:
request
.
min_tokens
,
"stop"
:
request
.
stop
,
"stop_token_ids"
:
request
.
stop_token_ids
,
"top_p"
:
request
.
top_p
,
"presence_penalty"
:
request
.
presence_penalty
,
"frequency_penalty"
:
request
.
frequency_penalty
,
"repetition_penalty"
:
request
.
repetition_penalty
,
"regex"
:
request
.
regex
,
"json_schema"
:
request
.
json_schema
,
"n"
:
request
.
n
,
"ignore_eos"
:
request
.
ignore_eos
,
}
)
sampling_params
=
[]
if
isinstance
(
request
.
no_eos_trim
,
list
):
num_reqs
=
len
(
request
.
prompt
)
else
:
num_reqs
=
1
for
i
in
range
(
num_reqs
):
sampling_params
.
append
(
{
"temperature"
:
request
.
temperature
,
"max_new_tokens"
:
request
.
max_tokens
,
"min_new_tokens"
:
request
.
min_tokens
,
"stop"
:
request
.
stop
,
"stop_token_ids"
:
request
.
stop_token_ids
,
"top_p"
:
request
.
top_p
,
"presence_penalty"
:
request
.
presence_penalty
,
"frequency_penalty"
:
request
.
frequency_penalty
,
"repetition_penalty"
:
request
.
repetition_penalty
,
"regex"
:
request
.
regex
,
"json_schema"
:
request
.
json_schema
,
"n"
:
request
.
n
,
"ignore_eos"
:
request
.
ignore_eos
,
"no_eos_trim"
:
(
request
.
no_eos_trim
if
not
isinstance
(
request
.
no_eos_trim
,
list
)
else
request
.
no_eos_trim
[
i
]
),
}
)
if
num_reqs
==
1
:
sampling_params_list
.
append
(
sampling_params
[
0
])
else
:
sampling_params_list
.
append
(
sampling_params
)
if
len
(
all_requests
)
==
1
:
prompt
=
prompts
[
0
]
...
...
python/sglang/srt/openai_api/protocol.py
View file @
48761171
...
...
@@ -174,6 +174,7 @@ class CompletionRequest(BaseModel):
min_tokens
:
int
=
0
repetition_penalty
:
Optional
[
float
]
=
1.0
stop_token_ids
:
Optional
[
List
[
int
]]
=
Field
(
default_factory
=
list
)
no_eos_trim
:
Union
[
bool
,
List
[
bool
]]
=
False
class
CompletionResponseChoice
(
BaseModel
):
...
...
python/sglang/srt/sampling/sampling_params.py
View file @
48761171
...
...
@@ -40,6 +40,7 @@ class SamplingParams:
regex
:
Optional
[
str
]
=
None
,
n
:
int
=
1
,
json_schema
:
Optional
[
str
]
=
None
,
no_eos_trim
:
bool
=
False
,
)
->
None
:
self
.
temperature
=
temperature
self
.
top_p
=
top_p
...
...
@@ -60,6 +61,7 @@ class SamplingParams:
self
.
regex
=
regex
self
.
n
=
n
self
.
json_schema
=
json_schema
self
.
no_eos_trim
=
no_eos_trim
# Process some special cases
if
self
.
temperature
<
_SAMPLING_EPS
:
...
...
python/sglang/srt/utils.py
View file @
48761171
...
...
@@ -690,3 +690,10 @@ def pytorch_profile(name, func, *args, data_size=-1):
prof
.
export_chrome_trace
(
f
"trace/
{
name
}
_
{
step_counter
}
.json"
)
step_counter
+=
1
return
result
def first_rank_print(*args, **kwargs):
    """Print only on the first CUDA device (rank 0) to avoid duplicated logs.

    Forwards *args / **kwargs to the builtin print when the current CUDA
    device index is 0; otherwise does nothing.

    NOTE(review): torch.cuda.current_device() raises when no CUDA context is
    available — callers must only use this in a CUDA environment. Confirm
    whether a torch.cuda.is_available() guard is wanted here.
    """
    # Removed the redundant `else: pass` branch — suppressing output on
    # non-zero devices is simply falling through.
    if torch.cuda.current_device() == 0:
        print(*args, **kwargs)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment