Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
6fcd6d7d
"vscode:/vscode.git/clone" did not exist on "4a38166afeea498e8d333d40264e778ac5b16d81"
Unverified
Commit
6fcd6d7d
authored
Oct 27, 2024
by
Byron Hsu
Committed by
GitHub
Oct 27, 2024
Browse files
Support token ids in `engine.generate` (#1820)
parent
c77762d5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
72 additions
and
4 deletions
+72
-4
examples/runtime/engine/input_ids.py
examples/runtime/engine/input_ids.py
+39
-0
python/sglang/srt/server.py
python/sglang/srt/server.py
+10
-4
test/srt/test_srt_engine.py
test/srt/test_srt_engine.py
+23
-0
No files found.
examples/runtime/engine/input_ids.py
0 → 100644
View file @
6fcd6d7d
"""
This example demonstrates how to provide tokenized ids as input instead of text prompt
"""
import
sglang
as
sgl
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
MODEL_PATH
=
"meta-llama/Llama-3.1-8B-Instruct"
def main():
    """Demonstrate batch generation from pre-tokenized ids instead of text prompts."""
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Create a sampling params object.
    sampling_params = {"temperature": 0.8, "top_p": 0.95}

    # Tokenize inputs up front so we can hand the engine raw token ids.
    tokenizer = get_tokenizer(MODEL_PATH)
    token_ids_list = []
    for text in prompts:
        token_ids_list.append(tokenizer.encode(text))

    # Create an LLM.
    # You can also specify `skip_tokenizer_init=True`, but it requires explicit detokenization at the end
    llm = sgl.Engine(model_path=MODEL_PATH)

    outputs = llm.generate(input_ids=token_ids_list, sampling_params=sampling_params)

    # Print the outputs, pairing each original prompt with its completion.
    for prompt, output in zip(prompts, outputs):
        print("===============================")
        print(f"Prompt: {prompt}\nGenerated Text: {output['text']}")
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
    main()
\ No newline at end of file
python/sglang/srt/server.py
View file @
6fcd6d7d
...
...
@@ -742,18 +742,20 @@ class Engine:
def
generate
(
self
,
prompt
:
Union
[
str
,
List
[
str
]],
# The input prompt. It can be a single prompt or a batch of prompts.
prompt
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
,
sampling_params
:
Optional
[
Dict
]
=
None
,
# The token ids for text; one can either specify text or input_ids.
input_ids
:
Optional
[
Union
[
List
[
List
[
int
]],
List
[
int
]]]
=
None
,
return_logprob
:
Optional
[
Union
[
List
[
bool
],
bool
]]
=
False
,
logprob_start_len
:
Optional
[
Union
[
List
[
int
],
int
]]
=
None
,
top_logprobs_num
:
Optional
[
Union
[
List
[
int
],
int
]]
=
None
,
lora_path
:
Optional
[
List
[
Optional
[
str
]]]
=
None
,
stream
:
bool
=
False
,
):
# TODO (ByronHsu): refactor to reduce the duplicated code
obj
=
GenerateReqInput
(
text
=
prompt
,
input_ids
=
input_ids
,
sampling_params
=
sampling_params
,
return_logprob
=
return_logprob
,
logprob_start_len
=
logprob_start_len
,
...
...
@@ -791,8 +793,11 @@ class Engine:
async
def
async_generate
(
self
,
prompt
:
Union
[
str
,
List
[
str
]],
# The input prompt. It can be a single prompt or a batch of prompts.
prompt
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
,
sampling_params
:
Optional
[
Dict
]
=
None
,
# The token ids for text; one can either specify text or input_ids.
input_ids
:
Optional
[
Union
[
List
[
List
[
int
]],
List
[
int
]]]
=
None
,
return_logprob
:
Optional
[
Union
[
List
[
bool
],
bool
]]
=
False
,
logprob_start_len
:
Optional
[
Union
[
List
[
int
],
int
]]
=
None
,
top_logprobs_num
:
Optional
[
Union
[
List
[
int
],
int
]]
=
None
,
...
...
@@ -801,6 +806,7 @@ class Engine:
):
obj
=
GenerateReqInput
(
text
=
prompt
,
input_ids
=
input_ids
,
sampling_params
=
sampling_params
,
return_logprob
=
return_logprob
,
logprob_start_len
=
logprob_start_len
,
...
...
test/srt/test_srt_engine.py
View file @
6fcd6d7d
...
...
@@ -9,6 +9,7 @@ import unittest
from
types
import
SimpleNamespace
import
sglang
as
sgl
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
from
sglang.test.few_shot_gsm8k_engine
import
run_eval
from
sglang.test.test_utils
import
DEFAULT_MODEL_NAME_FOR_TEST
...
...
@@ -106,6 +107,28 @@ class TestSRTEngine(unittest.TestCase):
metrics
=
run_eval
(
args
)
assert
metrics
[
"accuracy"
]
>
0.7
def test_5_prompt_input_ids_consistency(self):
    """Greedy decoding from a text prompt and from its encoded token ids must match."""
    query = "The capital of UK is"
    model = DEFAULT_MODEL_NAME_FOR_TEST
    llm = sgl.Engine(model_path=model, random_seed=42, log_level="error")

    # temperature 0 => deterministic output, so the two runs are comparable.
    params = {"temperature": 0, "max_new_tokens": 8}
    text_answer = llm.generate(query, params)["text"]

    # Second run: feed the same prompt as pre-tokenized ids.
    token_ids = get_tokenizer(model).encode(query)
    ids_answer = llm.generate(input_ids=token_ids, sampling_params=params)["text"]

    llm.shutdown()

    print("==== Answer 1 ====")
    print(text_answer)
    print("==== Answer 2 ====")
    print(ids_answer)
    assert text_answer == ids_answer, f"{text_answer} != {ids_answer}"
# Allow running this test module directly (e.g. `python test_srt_engine.py`).
if __name__ == "__main__":
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment