Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
04c0b214
"git@developer.sourcefind.cn:change/sglang.git" did not exist on "9e426466af9595b6289e0e1bd2a269e42b989fdb"
Unverified
Commit
04c0b214
authored
May 12, 2024
by
Shannon Shen
Committed by
GitHub
May 12, 2024
Browse files
Allow `input_ids` in the input of the `/generate` endpoint (#363)
parent
6e09cf6a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
43 additions
and
10 deletions
+43
-10
benchmark/latency_throughput/test_latency.py
benchmark/latency_throughput/test_latency.py
+1
-1
docs/sampling_params.md
docs/sampling_params.md
+2
-0
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+15
-3
python/sglang/srt/managers/router/infer_batch.py
python/sglang/srt/managers/router/infer_batch.py
+3
-0
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+22
-6
No files found.
benchmark/latency_throughput/test_latency.py
View file @
04c0b214
...
...
@@ -30,7 +30,7 @@ if __name__ == "__main__":
response
=
requests
.
post
(
url
+
"/generate"
,
json
=
{
"
text"
:
f
"
{
a
}
, "
,
"
input_ids"
:
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
,
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
max_new_tokens
,
...
...
docs/sampling_params.md
View file @
04c0b214
...
...
@@ -8,6 +8,8 @@ The `/generate` endpoint accepts the following arguments in the JSON format.
class
GenerateReqInput
:
# The input prompt
text
:
Union
[
List
[
str
],
str
]
# The token ids for text; one can either specify text or input_ids
input_ids
:
Optional
[
Union
[
List
[
List
[
int
]],
List
[
int
]]]
=
None
# The image input
image_data
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
# The sampling_params
...
...
python/sglang/srt/managers/io_struct.py
View file @
04c0b214
...
...
@@ -8,7 +8,9 @@ from sglang.srt.sampling_params import SamplingParams
@
dataclass
class
GenerateReqInput
:
# The input prompt
text
:
Union
[
List
[
str
],
str
]
text
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
# The token ids for text; one can either specify text or input_ids
input_ids
:
Optional
[
Union
[
List
[
List
[
int
]],
List
[
int
]]]
=
None
# The image input
image_data
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
# The sampling_params
...
...
@@ -28,7 +30,17 @@ class GenerateReqInput:
# TODO: make all parameters a Union[List[T], T] to allow for batched requests
def
post_init
(
self
):
is_single
=
isinstance
(
self
.
text
,
str
)
if
self
.
text
is
None
:
assert
self
.
input_ids
is
not
None
,
"Either text or input_ids should be provided"
else
:
assert
self
.
input_ids
is
None
,
"Either text or input_ids should be provided"
if
self
.
text
is
not
None
:
is_single
=
isinstance
(
self
.
text
,
str
)
else
:
is_single
=
isinstance
(
self
.
input_ids
[
0
],
int
)
self
.
is_single
=
is_single
if
is_single
:
if
self
.
sampling_params
is
None
:
...
...
@@ -42,7 +54,7 @@ class GenerateReqInput:
if
self
.
top_logprobs_num
is
None
:
self
.
top_logprobs_num
=
0
else
:
num
=
len
(
self
.
text
)
num
=
len
(
self
.
text
)
if
self
.
text
is
not
None
else
len
(
self
.
input_ids
)
if
self
.
image_data
is
None
:
self
.
image_data
=
[
None
]
*
num
...
...
python/sglang/srt/managers/router/infer_batch.py
View file @
04c0b214
...
...
@@ -85,6 +85,9 @@ class Req:
)
if
first_token
.
startswith
(
"▁"
):
old_output_str
=
" "
+
old_output_str
if
self
.
input_text
is
None
:
# TODO(lmzheng): This can be wrong. Check with Liangsheng.
self
.
input_text
=
self
.
tokenizer
.
decode
(
self
.
input_ids
)
new_input_string
=
(
self
.
input_text
+
self
.
output_and_jump_forward_str
...
...
python/sglang/srt/managers/tokenizer_manager.py
View file @
04c0b214
...
...
@@ -147,11 +147,15 @@ class TokenizerManager:
if
self
.
to_create_loop
:
await
self
.
create_handle_loop
()
is_single
=
isinstance
(
obj
.
text
,
str
)
is_single
=
obj
.
is_single
if
is_single
:
rid
=
obj
.
rid
input_ids
=
self
.
tokenizer
.
encode
(
obj
.
text
)
if
obj
.
input_ids
is
None
:
input_ids
=
self
.
tokenizer
.
encode
(
obj
.
text
)
else
:
input_ids
=
obj
.
input_ids
sampling_params
=
SamplingParams
(
**
obj
.
sampling_params
)
if
sampling_params
.
max_new_tokens
!=
0
:
sampling_params
.
normalize
(
self
.
tokenizer
)
...
...
@@ -204,10 +208,22 @@ class TokenizerManager:
event
.
clear
()
else
:
assert
obj
.
stream
is
False
bs
=
len
(
obj
.
text
)
if
obj
.
input_ids
is
None
:
bs
=
len
(
obj
.
text
)
else
:
bs
=
len
(
obj
.
input_ids
)
for
i
in
range
(
bs
):
rid
=
obj
.
rid
[
i
]
input_ids
=
self
.
tokenizer
.
encode
(
obj
.
text
[
i
])
if
obj
.
input_ids
is
None
:
input_text
=
obj
.
text
[
i
]
input_ids
=
self
.
tokenizer
.
encode
(
obj
.
text
[
i
])
else
:
input_text
=
None
input_ids
=
obj
.
input_ids
[
i
]
sampling_params
=
SamplingParams
(
**
obj
.
sampling_params
[
i
])
if
sampling_params
.
max_new_tokens
!=
0
:
sampling_params
.
normalize
(
self
.
tokenizer
)
...
...
@@ -220,7 +236,7 @@ class TokenizerManager:
)
tokenized_obj
=
TokenizedGenerateReqInput
(
rid
=
rid
,
input_text
=
obj
.
text
[
i
]
,
input_text
=
input_
text
,
input_ids
=
input_ids
,
pixel_values
=
pixel_values
,
image_hash
=
image_hash
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment