sglang · Commits · fb99aaa5

[Fix] Fix --skip-tokenizer-init (#1798)

Commit fb99aaa5 (unverified), authored Oct 25, 2024 by Lianmin Zheng and committed by GitHub on Oct 25, 2024. Parent: b77a02cd.

Showing 5 changed files with 50 additions and 25 deletions (+50, -25).
python/sglang/srt/managers/detokenizer_manager.py   +2, -5
python/sglang/srt/managers/io_struct.py             +2, -0
python/sglang/srt/managers/scheduler.py             +19, -4
python/sglang/srt/managers/tokenizer_manager.py     +2, -5
test/srt/test_skip_tokenizer_init.py                +25, -11
python/sglang/srt/managers/detokenizer_manager.py

@@ -115,12 +115,9 @@ class DetokenizerManager:
             elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
                 self.send_to_tokenizer.send_pyobj(recv_obj)
                 continue
-            elif self.tokenizer is None:
-                # If the tokenizer is skipped, no detokenization is needed
-                self.send_to_tokenizer.send_pyobj(recv_obj)
-                continue
-            else:
-                assert isinstance(recv_obj, BatchTokenIDOut)
+            assert isinstance(recv_obj, BatchTokenIDOut)

             bs = len(recv_obj.rids)
             # Initialize decode status
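Why the special case could be deleted: with this commit the scheduler sends skip-tokenizer-init results straight back to the TokenizerManager (see the scheduler.py hunk below), so the detokenizer only ever receives batches it actually has to decode. A self-contained sketch of that invariant; the message classes here are stand-in stubs, not the real sglang types:

from dataclasses import dataclass
from typing import List

# Stand-in stubs for the real sglang message types (illustration only).
@dataclass
class GetMemPoolSizeReqOutput:
    size: int

@dataclass
class BatchTokenIDOut:
    rids: List[str]

def handle(recv_obj, send_to_tokenizer):
    """Route one message the way the fixed DetokenizerManager does."""
    if isinstance(recv_obj, GetMemPoolSizeReqOutput):
        send_to_tokenizer(recv_obj)  # control message: forward untouched
        return
    # Safe after the fix: skip-tokenizer-init traffic bypasses this
    # process entirely, so only decodable batches arrive here.
    assert isinstance(recv_obj, BatchTokenIDOut)
    print("decoding batch of size", len(recv_obj.rids))

handle(BatchTokenIDOut(rids=["req-0"]), send_to_tokenizer=print)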
python/sglang/srt/managers/io_struct.py

@@ -294,6 +294,8 @@ class BatchTokenIDOut:
     decoded_texts: List[str]
     decode_ids: List[int]
     read_offsets: List[int]
+    # Only used when `--skip-tokenizer-init`
+    output_ids: Optional[List[int]]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
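How the new field is meant to be read: output_ids stays None on the normal path and is only populated when the server runs with --skip-tokenizer-init, so downstream code can branch on it. A runnable sketch using an abridged copy of the dataclass (the real class carries more fields, and the sample values are made up):

from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class BatchTokenIDOut:  # abridged: the real class has more fields
    rids: List[str]
    decoded_texts: List[str]
    decode_ids: List[int]
    read_offsets: List[int]
    # Only used when `--skip-tokenizer-init`
    output_ids: Optional[List[int]]
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    meta_info: List[Dict]

# Normal path: detokenization happens downstream, so output_ids is unused.
out = BatchTokenIDOut(
    rids=["req-0"], decoded_texts=[""], decode_ids=[791], read_offsets=[0],
    output_ids=None, skip_special_tokens=[True],
    spaces_between_special_tokens=[True], meta_info=[{}],
)
assert out.output_ids is None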
python/sglang/srt/managers/scheduler.py

@@ -104,6 +104,7 @@ class Scheduler:
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = server_args.enable_overlap_schedule
+        self.skip_tokenizer_init = server_args.skip_tokenizer_init

         # Init inter-process communication
         context = zmq.Context(2)
@@ -112,8 +113,18 @@ class Scheduler:
             self.recv_from_tokenizer = context.socket(zmq.PULL)
             self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")
-            self.send_to_detokenizer = context.socket(zmq.PUSH)
-            self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
+            if server_args.skip_tokenizer_init:
+                # Directly send to the tokenizer/api
+                self.send_to_detokenizer = context.socket(zmq.PUSH)
+                self.send_to_detokenizer.connect(
+                    f"ipc://{port_args.tokenizer_ipc_name}"
+                )
+            else:
+                # Send to the detokenizer
+                self.send_to_detokenizer = context.socket(zmq.PUSH)
+                self.send_to_detokenizer.connect(
+                    f"ipc://{port_args.detokenizer_ipc_name}"
+                )
         else:
             self.recv_from_tokenizer = None
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
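This hunk is the core of the fix: under --skip-tokenizer-init the scheduler's output PUSH socket connects to the TokenizerManager's IPC endpoint instead of the detokenizer's, so raw token ids flow back to the API layer without a detokenization hop. A runnable toy showing the same PUSH/PULL rerouting with pyzmq; the endpoint paths here are made up:

import zmq

def make_output_socket(context: zmq.Context, skip_tokenizer_init: bool) -> zmq.Socket:
    """Toy version of the routing decision above (endpoint paths are made up)."""
    sock = context.socket(zmq.PUSH)
    if skip_tokenizer_init:
        sock.connect("ipc:///tmp/demo_tokenizer")    # straight back to the API layer
    else:
        sock.connect("ipc:///tmp/demo_detokenizer")  # normal detokenize-first path
    return sock

ctx = zmq.Context(2)
# Pretend to be the two possible receivers.
tokenizer_end = ctx.socket(zmq.PULL)
tokenizer_end.bind("ipc:///tmp/demo_tokenizer")
detokenizer_end = ctx.socket(zmq.PULL)
detokenizer_end.bind("ipc:///tmp/demo_detokenizer")

make_output_socket(ctx, skip_tokenizer_init=True).send_pyobj({"token_ids": [791, 6864]})
print(tokenizer_end.recv_pyobj())  # {'token_ids': [791, 6864]}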
@@ -734,7 +745,7 @@ class Scheduler:
                 )
         else:
             logits_output = None
-            if self.tokenizer is not None:
+            if self.skip_tokenizer_init:
                 next_token_ids = torch.full(
                     (batch.batch_size(),), self.tokenizer.eos_token_id
                 )
@@ -950,13 +961,14 @@ class Scheduler:
     def stream_output(self, reqs: List[Req]):
         """Stream the output to detokenizer."""
         output_rids = []
-        output_meta_info = []
+        output_meta_info: List[dict] = []
         output_finished_reason: List[BaseFinishReason] = []
         if self.is_generation:
             output_vids = []
             decoded_texts = []
             output_read_ids = []
             output_read_offsets = []
+            output_ids = []
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
             output_no_stop_trim = []
@@ -977,6 +989,8 @@ class Scheduler:
                 read_ids, read_offset = req.init_incremental_detokenize()
                 output_read_ids.append(read_ids)
                 output_read_offsets.append(read_offset)
+                if self.skip_tokenizer_init:
+                    output_ids.append(req.output_ids)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
@@ -1028,6 +1042,7 @@ class Scheduler:
                     decoded_texts,
                     output_read_ids,
                     output_read_offsets,
+                    output_ids,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
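Putting the scheduler hunks together: stream_output collects one output_ids entry per request, but only under --skip-tokenizer-init, and the list rides in the BatchTokenIDOut alongside the usual incremental-detokenize fields. A condensed, hypothetical rendering of that bookkeeping (Req is a stub with just the fields this path touches):

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Req:  # stub: only the fields this path touches
    rid: str
    output_ids: List[int] = field(default_factory=list)

def collect_output_ids(reqs: List[Req], skip_tokenizer_init: bool) -> Optional[List[List[int]]]:
    """Condensed version of the stream_output bookkeeping above."""
    if not skip_tokenizer_init:
        return None  # normal path: the detokenizer works from decode_ids
    return [req.output_ids for req in reqs]

reqs = [Req("a", [791, 6864]), Req("b", [315])]
print(collect_output_ids(reqs, skip_tokenizer_init=True))   # [[791, 6864], [315]]
print(collect_output_ids(reqs, skip_tokenizer_init=False))  # None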
python/sglang/srt/managers/tokenizer_manager.py

@@ -571,7 +571,7 @@ class TokenizerManager:
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
-            await asyncio.sleep(3)
+            await asyncio.sleep(1)
             if obj.is_single:
                 self.abort_request(obj.rid)
             else:
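create_abort_task schedules a background coroutine that waits briefly and then aborts the request if the client has already disconnected; this hunk shortens the grace period from 3 s to 1 s. A self-contained sketch of the same pattern; the names and the client_connected callback are illustrative, not the sglang API:

import asyncio

def create_abort_task(rid: str, client_connected) -> asyncio.Task:
    """Illustrative pattern only: delayed abort for a possibly-gone client."""
    async def abort_request():
        await asyncio.sleep(1)  # grace period, as in the new code
        if not client_connected():
            print(f"aborting {rid}")

    return asyncio.create_task(abort_request())

async def main():
    task = create_abort_task("req-0", client_connected=lambda: False)
    await task  # prints "aborting req-0" after ~1 s

asyncio.run(main())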
@@ -621,11 +621,8 @@ class TokenizerManager:
                 "meta_info": recv_obj.meta_info[i],
             }
         elif isinstance(recv_obj, BatchTokenIDOut):
-            read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1]
             out_dict = {
-                "token_ids": recv_obj.decode_ids[
-                    read_start : recv_obj.read_offsets[i]
-                ],
+                "token_ids": recv_obj.output_ids[i],
                 "meta_info": recv_obj.meta_info[i],
             }
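This is the user-visible part of the fix: instead of slicing the shared decode_ids buffer with incremental-detokenization read offsets, which is a fragile way to recover a request's full completion, the TokenizerManager now just indexes the per-request output_ids carried in the new field. A toy rendering of the new lookup (the stub and token ids are made up):

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class BatchTokenIDOut:  # abridged stub
    output_ids: List[List[int]]
    meta_info: List[Dict]

recv_obj = BatchTokenIDOut(
    output_ids=[[791, 6864], [315, 9822, 374]],  # made-up token ids
    meta_info=[{"completion_tokens": 2}, {"completion_tokens": 3}],
)

# One response dict per request, exactly as the new code builds it.
for i in range(len(recv_obj.output_ids)):
    out_dict = {
        "token_ids": recv_obj.output_ids[i],
        "meta_info": recv_obj.meta_info[i],
    }
    print(out_dict)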
test/srt/test_skip_tokenizer_init.py

@@ -29,21 +29,15 @@ class TestSkipTokenizerInit(unittest.TestCase):
         kill_child_process(cls.process.pid)

     def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
+        max_new_tokens = 32
+        input_ids = [128000, 791, 6864, 315, 9822, 374]  # The capital of France is
         response = requests.post(
             self.base_url + "/generate",
             json={
-                "input_ids": [
-                    119689,
-                    50650,
-                    18291,
-                    30061,
-                    5316,
-                    26951,
-                    119690,
-                ],  # The capital of France is
+                "input_ids": input_ids,
                 "sampling_params": {
                     "temperature": 0 if n == 1 else 0.5,
-                    "max_new_tokens": 32,
+                    "max_new_tokens": max_new_tokens,
                     "n": n,
                     "stop_token_ids": [119690],
                 },
@@ -53,7 +47,27 @@ class TestSkipTokenizerInit(unittest.TestCase):
                 "logprob_start_len": 0,
             },
         )
-        print(json.dumps(response.json()))
+        ret = response.json()
+        print(json.dumps(ret))
+
+        def assert_one_item(item):
+            assert len(item["token_ids"]) == item["meta_info"]["completion_tokens"]
+            assert len(item["token_ids"]) == max_new_tokens
+            assert item["meta_info"]["prompt_tokens"] == len(input_ids)
+
+            if return_logprob:
+                assert len(item["meta_info"]["input_token_logprobs"]) == len(
+                    input_ids
+                ), f'{len(item["meta_info"]["input_token_logprobs"])} vs. {len(input_ids)}'
+                assert len(item["meta_info"]["output_token_logprobs"]) == max_new_tokens
+
+        if n == 1:
+            assert_one_item(ret)
+        else:
+            assert len(ret) == n
+            for i in range(n):
+                assert_one_item(ret[i])
+
+        print("=" * 100)

     def test_simple_decode(self):
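For reference, the behavior this test exercises end to end: launch the server with --skip-tokenizer-init, POST pre-tokenized input ids to /generate, and receive raw token ids back, tokenizing and detokenizing on the client side. A minimal client sketch; the model path and port are placeholders, and the token ids are the Llama-3 style ids from the test:

# Server (in another shell); the model path is a placeholder:
#   python -m sglang.launch_server --model-path <llama-3-model> \
#       --skip-tokenizer-init --port 30000
import requests

input_ids = [128000, 791, 6864, 315, 9822, 374]  # "The capital of France is"
resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "input_ids": input_ids,
        "sampling_params": {"temperature": 0, "max_new_tokens": 32},
    },
)
ret = resp.json()
# With --skip-tokenizer-init the response carries token ids, not text;
# decode them client-side, e.g. with a Hugging Face tokenizer.
print(ret["token_ids"], ret["meta_info"]["completion_tokens"])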