Unverified commit 35bdb485, authored Dec 29, 2024 by Shi Shuai, committed by GitHub on Dec 29, 2024
[Feature] Get Token IDs with Engine.generate() (#2636)
Co-authored-by: Chayenne <zhaochen20@outlook.com>
Parent: b085e06b
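As a quick reference for what this change enables, here is a minimal Engine-side usage sketch, drawn from the test file added in this commit (the model path and sampling parameters are simply the ones that test happens to use):

# Minimal sketch of the feature added by this commit: Engine.generate()
# returning token IDs when return_token_ids=True is passed.
import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",  # example model from the test
    return_token_ids=True,  # new flag introduced by this commit
)

outputs = llm.generate(
    ["The capital of France is"],
    {"temperature": 0.8, "top_p": 0.95},
)

for out in outputs:
    # With return_token_ids=True, each result carries the prompt and completion
    # token IDs alongside the decoded text and the usual meta_info.
    print(out["text"])
    print(out["input_ids"])   # token IDs of the original prompt
    print(out["output_ids"])  # token IDs of the generated completion

llm.shutdown()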
Showing 7 changed files with 92 additions and 2 deletions.
python/sglang/srt/managers/detokenizer_manager.py   +2  -0
python/sglang/srt/managers/io_struct.py             +7  -1
python/sglang/srt/managers/scheduler.py             +9  -1
python/sglang/srt/managers/tokenizer_manager.py     +7  -0
python/sglang/srt/server_args.py                    +7  -0
test/srt/run_suite.py                               +1  -0
test/srt/test_engine_token_ids.py                   +59 -0
python/sglang/srt/managers/detokenizer_manager.py

@@ -181,6 +181,8 @@ class DetokenizerManager:
                 finished_reasons=recv_obj.finished_reasons,
                 output_strs=output_strs,
                 prompt_tokens=recv_obj.prompt_tokens,
+                origin_input_ids=recv_obj.origin_input_ids,
+                output_ids=recv_obj.output_ids,
                 completion_tokens=recv_obj.completion_tokens,
                 cached_tokens=recv_obj.cached_tokens,
                 input_token_logprobs_val=recv_obj.input_token_logprobs_val,
python/sglang/srt/managers/io_struct.py

@@ -323,7 +323,9 @@ class BatchTokenIDOut:
     decoded_texts: List[str]
     decode_ids: List[int]
     read_offsets: List[int]
-    # Only used when `--skip-tokenizer-init`
+    # Only used when `--return-token-ids` is set
+    origin_input_ids: Optional[List[int]]
+    # Only used when `--skip-tokenizer-init` or `--return-token-ids` is set
     output_ids: Optional[List[int]]
     # Detokenization configs
     skip_special_tokens: List[bool]

@@ -354,6 +356,10 @@ class BatchStrOut:
     # The output decoded strings
     output_strs: List[str]
+    # The token ids
+    origin_input_ids: Optional[List[int]]
+    output_ids: Optional[List[int]]
     # Token counts
     prompt_tokens: List[int]
     completion_tokens: List[int]
python/sglang/srt/managers/scheduler.py

@@ -1218,6 +1218,7 @@ class Scheduler:
         decode_ids_list = []
         read_offsets = []
         output_ids = []
+        origin_input_ids = []
         skip_special_tokens = []
         spaces_between_special_tokens = []

@@ -1266,8 +1267,14 @@ class Scheduler:
             decode_ids, read_offset = req.init_incremental_detokenize()
             decode_ids_list.append(decode_ids)
             read_offsets.append(read_offset)
-            if self.skip_tokenizer_init:
+            if self.skip_tokenizer_init or self.server_args.return_token_ids:
                 output_ids.append(req.output_ids)
+            else:
+                output_ids = None
+            if self.server_args.return_token_ids:
+                origin_input_ids.append(req.origin_input_ids)
+            else:
+                origin_input_ids = None
             skip_special_tokens.append(req.sampling_params.skip_special_tokens)
             spaces_between_special_tokens.append(
                 req.sampling_params.spaces_between_special_tokens

@@ -1299,6 +1306,7 @@ class Scheduler:
                 decoded_texts,
                 decode_ids_list,
                 read_offsets,
+                origin_input_ids,
                 output_ids,
                 skip_special_tokens,
                 spaces_between_special_tokens,
python/sglang/srt/managers/tokenizer_manager.py

@@ -663,6 +663,13 @@ class TokenizerManager:
                     "text": recv_obj.output_strs[i],
                     "meta_info": meta_info,
                 }
+                if self.server_args.return_token_ids:
+                    out_dict.update(
+                        {
+                            "input_ids": recv_obj.origin_input_ids[i],
+                            "output_ids": recv_obj.output_ids[i],
+                        }
+                    )
             elif isinstance(recv_obj, BatchTokenIDOut):
                 out_dict = {
                     "token_ids": recv_obj.output_ids[i],
python/sglang/srt/server_args.py

@@ -54,6 +54,7 @@ class ServerArgs:
     chat_template: Optional[str] = None
     is_embedding: bool = False
     revision: Optional[str] = None
+    return_token_ids: bool = False

     # Port for the HTTP server
     host: str = "127.0.0.1"

@@ -280,6 +281,12 @@ class ServerArgs:
             action="store_true",
             help="If set, skip init tokenizer and pass input_ids in generate request",
         )
+        parser.add_argument(
+            "--return-token-ids",
+            action="store_true",
+            default=ServerArgs.return_token_ids,
+            help="Whether to return token IDs in the output, this may introduce additional overhead.",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
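The same option is exposed on the server CLI as --return-token-ids. A hedged client-side sketch follows; the launch command and the /generate request shape are standard SGLang usage assumed here rather than shown in this diff, and only the flag itself plus the extra input_ids/output_ids response fields come from this commit:

# Sketch only: query a server launched with the new flag, e.g.
#   python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
#       --return-token-ids
# The launch command, default port, and /generate payload shape are assumptions
# based on standard SGLang usage, not part of this diff.
import requests

resp = requests.post(
    "http://127.0.0.1:30000/generate",  # default host/port assumed
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0.8, "top_p": 0.95},
    },
)
data = resp.json()
print(data["text"])
print(data.get("input_ids"))   # prompt token IDs when --return-token-ids is set
print(data.get("output_ids"))  # completion token IDs when --return-token-ids is set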
test/srt/run_suite.py

@@ -44,6 +44,7 @@ suites = {
         "test_vision_chunked_prefill.py",
         "test_vision_openai_server.py",
         "test_session_control.py",
+        "test_engine_token_ids.py",
     ],
     "nightly": [
         "test_nightly_gsm8k_eval.py",
test/srt/test_engine_token_ids.py (new file, 0 → 100644)

import unittest

from transformers import AutoTokenizer

import sglang as sgl


class TestEngineTokenIds(unittest.TestCase):
    def test_token_ids_in_generate(self):
        llm = sgl.Engine(
            model_path="meta-llama/Meta-Llama-3.1-8B-Instruct", return_token_ids=True
        )
        tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Meta-Llama-3.1-8B-Instruct"
        )

        prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        sampling_params = {"temperature": 0.8, "top_p": 0.95}

        outputs = llm.generate(prompts, sampling_params)

        # Hugging Face tokenizer has a start token in its output,
        # while SGLang only adds next_token_id in output_ids.
        # We remove start token in HF output for comparison.
        for prompt, output in zip(prompts, outputs):
            hf_input_ids = tokenizer.encode(prompt)
            self.assertEqual(
                output["input_ids"],
                hf_input_ids,
                f"Input token IDs mismatch for: {prompt}",
            )

            hf_output_ids = tokenizer.encode(output["text"])[1:]  # remove start token
            self.assertEqual(
                output["output_ids"],
                hf_output_ids,
                f"Output token IDs mismatch for: {output['text']}",
            )

            self.assertEqual(
                len(output["input_ids"]),
                output["meta_info"]["prompt_tokens"],
                "Prompt token count mismatch",
            )
            self.assertEqual(
                len(output["output_ids"]),
                output["meta_info"]["completion_tokens"],
                "Completion token count mismatch",
            )

        llm.shutdown()


if __name__ == "__main__":
    unittest.main()