sglang / Commits / d8476818

Commit d8476818 (unverified), authored Aug 20, 2024 by Juwan Yoo, committed by GitHub on Aug 20, 2024.
Parent: df191254

feat: allow streaming for multi-prompt and/or parallel sampling (#1134)

Showing 4 changed files with 210 additions and 85 deletions:

    python/sglang/srt/managers/tokenizer_manager.py   +53  -40
    python/sglang/srt/openai_api/adapter.py           +81  -25
    test/srt/test_openai_server.py                    +58  -17
    test/srt/test_srt_endpoint.py                     +18   -3
python/sglang/srt/managers/tokenizer_manager.py (+53, -40)

@@ -153,9 +153,6 @@ class TokenizerManager:
             async for response in self._handle_single_request(obj, request):
                 yield response
         else:
-            if hasattr(obj, "stream") and obj.stream:
-                raise ValueError("Do not support stream for batch mode.")
-
             async for response in self._handle_batch_request(obj, request):
                 yield response

@@ -311,6 +308,7 @@ class TokenizerManager:
            parallel_sample_num = 1

        # First send out all requests
+        generators = []
        for i in range(batch_size):
            for j in range(parallel_sample_num):
                if j == 0 and parallel_sample_num != 1:

@@ -371,42 +369,48 @@ class TokenizerManager:
                state = ReqState([], False, event)
                self.rid_to_state[rid] = state

-        # Then wait for all responses
+                generators.append(
+                    self._wait_for_response(
+                        event,
+                        state,
+                        obj,
+                        rid,
+                        request,
+                        index=index,
+                        response_index=len(generators),
+                    )
+                )
+
+        # Then process the responses based on streaming option
+        is_stream = hasattr(obj, "stream") and obj.stream
+
+        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
        output_list = []
-        for i in range(batch_size):
-            for j in range(parallel_sample_num):
-                if j == 0 and parallel_sample_num != 1:
-                    continue
-                index = i * parallel_sample_num + j
-                if parallel_sample_num != 1:
-                    index += batch_size - 1 - i
-                rid = obj.rid[index]
-                state = self.rid_to_state[rid]
-
-                while True:
-                    try:
-                        await asyncio.wait_for(state.event.wait(), timeout=4)
-                        break
-                    except asyncio.TimeoutError:
-                        if request is not None and await request.is_disconnected():
-                            for rid in obj.rid:
-                                self.abort_request(rid)
-                            raise ValueError(f"Abort request {rid}")
-                        continue
-
-                if self.is_generation:
-                    output_list.append(
-                        self.convert_logprob_style(
-                            state.out_list[-1],
-                            obj.return_logprob[index],
-                            obj.top_logprobs_num[index],
-                            obj.return_text_in_logprobs,
-                        )
-                    )
-                else:
-                    output_list.append(state.out_list[-1])
-                assert state.finished
-                del self.rid_to_state[rid]
-
-        yield output_list
+        while tasks:
+            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+            for task in done:
+                gen_index = tasks.index(task)
+                try:
+                    result = task.result()
+                    if is_stream:
+                        yield result
+                    else:
+                        output_list.append(result)
+                    tasks[gen_index] = asyncio.create_task(
+                        generators[gen_index].__anext__()
+                    )
+                except StopAsyncIteration:
+                    del generators[gen_index]
+                    del tasks[gen_index]
+
+        if not is_stream:
+            yield output_list

    def _validate_input_length(self, input_ids: List[int]):
        if len(input_ids) >= self.context_len:

@@ -437,26 +441,35 @@ class TokenizerManager:
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        rid: str,
        request,
+        index: int = None,
+        response_index: int = 0,
    ):
        while True:
            try:
                await asyncio.wait_for(event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
-                    self.abort_request(rid)
+                    for rid in [obj.rid] if obj.is_single else obj.rid:
+                        self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

            if self.is_generation:
                out = self.convert_logprob_style(
                    state.out_list[-1],
-                    obj.return_logprob,
-                    obj.top_logprobs_num,
+                    obj.return_logprob if index is None else obj.return_logprob[index],
+                    (
+                        obj.top_logprobs_num
+                        if index is None
+                        else obj.top_logprobs_num[index]
+                    ),
                    obj.return_text_in_logprobs,
                )
            else:  # isinstance(obj, EmbeddingReqInput)
                out = state.out_list[-1]
+
+            out["index"] = response_index

            # Log requests
            if self.server_args.log_requests and state.finished:
                if obj.text is None:
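Note: the heart of the new `_handle_batch_request` is the fan-in loop above: one pending `__anext__()` task per per-request generator, `asyncio.wait(..., return_when=FIRST_COMPLETED)` to pick up whichever response arrives first, and re-arming that generator's task until it is exhausted. A minimal, self-contained sketch of the same pattern (toy generators and names, not sglang code):

import asyncio

async def numbered(tag: str, n: int, delay: float):
    # Stand-in for one per-request response generator.
    for i in range(n):
        await asyncio.sleep(delay)
        yield f"{tag}-{i}"

async def merge(generators):
    # Drive all generators concurrently and yield items as soon as any is
    # ready, mirroring the while-tasks loop added to _handle_batch_request.
    tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
    while tasks:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            gen_index = tasks.index(task)
            try:
                yield task.result()
                # Re-arm this generator so it stays in the race.
                tasks[gen_index] = asyncio.create_task(generators[gen_index].__anext__())
            except StopAsyncIteration:
                # This generator is exhausted; drop it and its task slot.
                del generators[gen_index]
                del tasks[gen_index]

async def main():
    gens = [numbered("a", 3, 0.05), numbered("b", 2, 0.03)]
    async for item in merge(gens):
        print(item)

asyncio.run(main())

When `is_stream` is false, the commit collects the results into `output_list` instead of yielding them one by one, which preserves the previous non-streaming batch behavior.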
python/sglang/srt/openai_api/adapter.py (+81, -25)

@@ -277,6 +277,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
                request_data = json.loads(line)
                file_request_list.append(request_data)
                body = request_data["body"]
+
+                # Although streaming is supported for standalone completions, it is not supported in
+                # batch mode (multiple completions in single request).
+                if body.get("stream", False):
+                    raise ValueError("Streaming requests are not supported in batch mode")
+
                if end_point == "/v1/chat/completions":
                    all_requests.append(ChatCompletionRequest(**body))
                elif end_point == "/v1/completions":

@@ -592,27 +598,45 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
    if adapted_request.stream:

        async def generate_stream_resp():
-            stream_buffer = ""
-            n_prev_token = 0
+            stream_buffers = {}
+            n_prev_tokens = {}
+            prompt_tokens = {}
+            completion_tokens = {}
            try:
                async for content in tokenizer_manager.generate_request(
                    adapted_request, raw_request
                ):
+                    index = content["index"]
+
+                    stream_buffer = stream_buffers.get(index, "")
+                    n_prev_token = n_prev_tokens.get(index, 0)
+
                    text = content["text"]
-                    prompt_tokens = content["meta_info"]["prompt_tokens"]
-                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                    completion_tokens[index] = content["meta_info"]["completion_tokens"]

                    if not stream_buffer:  # The first chunk
                        if request.echo:
                            if isinstance(request.prompt, str):
+                                # for the case of single str prompts
                                prompts = request.prompt
-                            elif isinstance(request.prompt, list) and isinstance(
-                                request.prompt[0], int
-                            ):
-                                prompts = tokenizer_manager.tokenizer.decode(
-                                    request.prompt, skip_special_tokens=True
-                                )
+                            elif isinstance(request.prompt, list):
+                                if isinstance(request.prompt[0], str):
+                                    # for the case of multiple str prompts
+                                    prompts = request.prompt[index // request.n]
+                                elif isinstance(request.prompt[0], int):
+                                    # for the case of single token ids prompt
+                                    prompts = tokenizer_manager.tokenizer.decode(
+                                        request.prompt, skip_special_tokens=True
+                                    )
+                                elif isinstance(request.prompt[0], list) and isinstance(
+                                    request.prompt[0][0], int
+                                ):
+                                    # for the case of multiple token ids prompts
+                                    prompts = tokenizer_manager.tokenizer.decode(
+                                        request.prompt[index // request.n],
+                                        skip_special_tokens=True,
+                                    )

                            # Prepend prompt in response text.
                            text = prompts + text
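Note: with a list of prompts and `n` parallel samples per prompt, the streamed choices carry indices `0 .. len(prompts) * n - 1`, so the echo path above recovers the originating prompt as `request.prompt[index // request.n]`. A tiny illustration with hypothetical values:

# Hypothetical prompts and n, only to show the index -> prompt mapping
# used by the echo logic above.
prompts = ["The capital of France is", "The capital of Germany is"]
n = 2

for index in range(len(prompts) * n):
    print(f"choice {index} echoes prompt {index // n}: {prompts[index // n]!r}")
# choice 0 -> prompt 0, choice 1 -> prompt 0,
# choice 2 -> prompt 1, choice 3 -> prompt 1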
@@ -649,7 +673,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                    delta = text[len(stream_buffer) :]
                    stream_buffer = stream_buffer + delta
                    choice_data = CompletionResponseStreamChoice(
-                        index=0,
+                        index=index,
                        text=delta,
                        logprobs=logprobs,
                        finish_reason=format_finish_reason(

@@ -662,12 +686,24 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                        choices=[choice_data],
                        model=request.model,
                    )
+
+                    stream_buffers[index] = stream_buffer
+                    n_prev_tokens[index] = n_prev_token
+
                    yield f"data: {chunk.model_dump_json()}\n\n"
                if request.stream_options and request.stream_options.include_usage:
+                    total_prompt_tokens = sum(
+                        tokens
+                        for i, tokens in prompt_tokens.items()
+                        if i % request.n == 0
+                    )
+                    total_completion_tokens = sum(
+                        tokens for tokens in completion_tokens.values()
+                    )
                    usage = UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
+                        prompt_tokens=total_prompt_tokens,
+                        completion_tokens=total_completion_tokens,
+                        total_tokens=total_prompt_tokens + total_completion_tokens,
                    )

                    final_usage_chunk = CompletionStreamResponse(
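Note: because every parallel sample of the same prompt reports the same prompt-token count, the usage block above counts each prompt only once (indices with `i % request.n == 0`) while summing completion tokens over every choice. A small worked example with made-up numbers:

# Made-up per-choice counters (choice index -> token count) for two prompts
# with n = 2 samples each, mirroring the aggregation above.
n = 2
prompt_tokens = {0: 5, 1: 5, 2: 7, 3: 7}
completion_tokens = {0: 12, 1: 9, 2: 15, 3: 11}

total_prompt_tokens = sum(t for i, t in prompt_tokens.items() if i % n == 0)
total_completion_tokens = sum(completion_tokens.values())

assert total_prompt_tokens == 12       # 5 + 7: each prompt counted once
assert total_completion_tokens == 47   # 12 + 9 + 15 + 11: every sample counted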
@@ -914,16 +950,23 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
    if adapted_request.stream:

        async def generate_stream_resp():
-            is_first = True
-
-            stream_buffer = ""
-            n_prev_token = 0
+            is_firsts = {}
+            stream_buffers = {}
+            n_prev_tokens = {}
+            prompt_tokens = {}
+            completion_tokens = {}
            try:
                async for content in tokenizer_manager.generate_request(
                    adapted_request, raw_request
                ):
-                    prompt_tokens = content["meta_info"]["prompt_tokens"]
-                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    index = content["index"]
+
+                    is_first = is_firsts.get(index, True)
+                    stream_buffer = stream_buffers.get(index, "")
+                    n_prev_token = n_prev_tokens.get(index, 0)
+
+                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
                    if request.logprobs:
                        logprobs = to_openai_style_logprobs(
                            output_token_logprobs=content["meta_info"][

@@ -973,7 +1016,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                        # First chunk with role
                        is_first = False
                        choice_data = ChatCompletionResponseStreamChoice(
-                            index=0,
+                            index=index,
                            delta=DeltaMessage(role="assistant"),
                            finish_reason=format_finish_reason(
                                content["meta_info"]["finish_reason"]

@@ -991,7 +1034,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                    delta = text[len(stream_buffer) :]
                    stream_buffer = stream_buffer + delta
                    choice_data = ChatCompletionResponseStreamChoice(
-                        index=0,
+                        index=index,
                        delta=DeltaMessage(content=delta),
                        finish_reason=format_finish_reason(
                            content["meta_info"]["finish_reason"]

@@ -1003,12 +1046,25 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                        choices=[choice_data],
                        model=request.model,
                    )
+
+                    is_firsts[index] = is_first
+                    stream_buffers[index] = stream_buffer
+                    n_prev_tokens[index] = n_prev_token
+
                    yield f"data: {chunk.model_dump_json()}\n\n"
                if request.stream_options and request.stream_options.include_usage:
+                    total_prompt_tokens = sum(
+                        tokens
+                        for i, tokens in prompt_tokens.items()
+                        if i % request.n == 0
+                    )
+                    total_completion_tokens = sum(
+                        tokens for tokens in completion_tokens.values()
+                    )
                    usage = UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
+                        prompt_tokens=total_prompt_tokens,
+                        completion_tokens=total_completion_tokens,
+                        total_tokens=total_prompt_tokens + total_completion_tokens,
                    )

                    final_usage_chunk = ChatCompletionStreamResponse(
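Note: with these adapter changes, a client can stream a multi-prompt and/or `n > 1` completion request and demultiplex chunks by `choices[0].index`. A hedged client-side sketch (assumes an sglang OpenAI-compatible server on `http://localhost:30000/v1` and the `openai` Python package; the model name is illustrative):

import openai

client = openai.Client(api_key="EMPTY", base_url="http://localhost:30000/v1")

stream = client.completions.create(
    model="default",
    prompt=["The capital of France is", "The capital of Germany is"],
    n=2,  # two samples per prompt -> four streamed choices in total
    max_tokens=32,
    stream=True,
    stream_options={"include_usage": True},
)

texts = {}
for chunk in stream:
    if chunk.usage is not None:
        # Final usage-only chunk produced when include_usage is set.
        print("usage:", chunk.usage)
        continue
    choice = chunk.choices[0]
    # Each chunk now reports its real choice index instead of a hard-coded 0.
    texts[choice.index] = texts.get(choice.index, "") + choice.text

for index in sorted(texts):
    print(index, repr(texts[index]))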
test/srt/test_openai_server.py (+58, -17)

@@ -85,13 +85,26 @@ class TestOpenAIServer(unittest.TestCase):
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

-    def run_completion_stream(self, echo, logprobs, token_input):
+    def run_completion_stream(
+        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
+    ):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"

        if token_input:
-            prompt_arg = self.tokenizer.encode(prompt)
-            num_prompt_tokens = len(prompt_arg)
+            prompt_input = self.tokenizer.encode(prompt)
+            num_prompt_tokens = len(prompt_input)
        else:
-            prompt_arg = prompt
+            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

+        if use_list_input:
+            prompt_arg = [prompt_input, prompt_input]
+            num_choices = len(prompt_arg)
+            num_prompt_tokens *= 2
+        else:
+            prompt_arg = prompt_input
+            num_choices = 1
+
        generator = client.completions.create(
            model=self.model,
            prompt=prompt_arg,

@@ -101,9 +114,10 @@ class TestOpenAIServer(unittest.TestCase):
            logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
+            n=parallel_sample_num,
        )

-        first = True
+        is_firsts = {}
        for response in generator:
            usage = response.usage
            if usage is not None:

@@ -111,10 +125,14 @@ class TestOpenAIServer(unittest.TestCase):
                assert usage.completion_tokens > 0
                assert usage.total_tokens > 0
                continue

+            index = response.choices[0].index
+            is_first = is_firsts.get(index, True)
+
            if logprobs:
                assert response.choices[0].logprobs
                assert isinstance(response.choices[0].logprobs.tokens[0], str)
-                if not (first and echo):
+                if not (is_first and echo):
                    assert isinstance(
                        response.choices[0].logprobs.top_logprobs[0], dict
                    )
@@ -125,15 +143,20 @@ class TestOpenAIServer(unittest.TestCase):
                # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
                assert ret_num_top_logprobs > 0

-            if first:
+            if is_first:
                if echo:
                    assert response.choices[0].text.startswith(
                        prompt
-                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {first}"
-                first = False
+                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
+                is_firsts[index] = False
            assert response.id
            assert response.created

+        for index in [i for i in range(parallel_sample_num * num_choices)]:
+            assert not is_firsts.get(
+                index, True
+            ), f"index {index} is not found in the response"
+
    def run_chat_completion(self, logprobs, parallel_sample_num):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(

@@ -172,7 +195,7 @@ class TestOpenAIServer(unittest.TestCase):
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

-    def run_chat_completion_stream(self, logprobs):
+    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        generator = client.chat.completions.create(
            model=self.model,

@@ -185,9 +208,10 @@ class TestOpenAIServer(unittest.TestCase):
            top_logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
+            n=parallel_sample_num,
        )

-        is_first = True
+        is_firsts = {}
        for response in generator:
            usage = response.usage
            if usage is not None:

@@ -196,11 +220,12 @@ class TestOpenAIServer(unittest.TestCase):
                assert usage.total_tokens > 0
                continue

+            index = response.choices[0].index
            data = response.choices[0].delta
-            if is_first:
-                data.role == "assistant"
-                is_first = False
+            if is_firsts.get(index, True):
+                assert data.role == "assistant"
+                is_firsts[index] = False
                continue

            if logprobs:

@@ -222,6 +247,11 @@ class TestOpenAIServer(unittest.TestCase):
            assert response.id
            assert response.created

+        for index in [i for i in range(parallel_sample_num)]:
+            assert not is_firsts.get(
+                index, True
+            ), f"index {index} is not found in the response"
+
    def run_batch(self, mode):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        if mode == "completion":

@@ -320,7 +350,9 @@ class TestOpenAIServer(unittest.TestCase):
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            batch_job = client.batches.retrieve(batch_job.id)
-        assert batch_job.status == "completed"
+        assert (
+            batch_job.status == "completed"
+        ), f"Batch job status is not completed: {batch_job.status}"
        assert batch_job.request_counts.completed == len(content)
        assert batch_job.request_counts.failed == 0
        assert batch_job.request_counts.total == len(content)

@@ -353,8 +385,16 @@ class TestOpenAIServer(unittest.TestCase):
        # parallel sampling adn list input are not supported in streaming mode
        for echo in [False, True]:
            for logprobs in [None, 5]:
-                for token_input in [False, True]:
-                    self.run_completion_stream(echo, logprobs, token_input)
+                for use_list_input in [True, False]:
+                    for parallel_sample_num in [1, 2]:
+                        for token_input in [False, True]:
+                            self.run_completion_stream(
+                                echo,
+                                logprobs,
+                                use_list_input,
+                                parallel_sample_num,
+                                token_input,
+                            )

    def test_chat_completion(self):
        for logprobs in [None, 5]:

@@ -363,7 +403,8 @@ class TestOpenAIServer(unittest.TestCase):
    def test_chat_completion_stream(self):
        for logprobs in [None, 5]:
-            self.run_chat_completion_stream(logprobs)
+            for parallel_sample_num in [1, 2]:
+                self.run_chat_completion_stream(logprobs, parallel_sample_num)

    def test_batch(self):
        for mode in ["completion", "chat"]:
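Note: the streaming tests above now key their bookkeeping on the reported choice index rather than assuming a single choice. A condensed sketch of that pattern (helper name and structure are illustrative, not part of the test suite):

def collect_chat_stream(generator, parallel_sample_num):
    # Track, per choice index, whether the first chunk (which must carry
    # role="assistant") has been seen, then accumulate content deltas.
    is_firsts = {}
    contents = {}
    for response in generator:
        if response.usage is not None:
            continue  # trailing usage-only chunk
        index = response.choices[0].index
        delta = response.choices[0].delta
        if is_firsts.get(index, True):
            assert delta.role == "assistant"
            is_firsts[index] = False
            continue
        contents[index] = contents.get(index, "") + (delta.content or "")
    # Every expected choice must have appeared at least once.
    for index in range(parallel_sample_num):
        assert not is_firsts.get(index, True), f"index {index} is not found in the response"
    return contents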
test/srt/test_srt_endpoint.py (+18, -3)

@@ -23,7 +23,12 @@ class TestSRTEndpoint(unittest.TestCase):
        kill_child_process(cls.process.pid)

    def run_decode(
-        self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
+        self,
+        return_logprob=False,
+        top_logprobs_num=0,
+        return_text=False,
+        n=1,
+        stream=False,
    ):
        response = requests.post(
            self.base_url + "/generate",

@@ -34,14 +39,21 @@ class TestSRTEndpoint(unittest.TestCase):
                    "max_new_tokens": 32,
                    "n": n,
                },
-                "stream": False,
+                "stream": stream,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "return_text_in_logprobs": return_text,
                "logprob_start_len": 0,
            },
        )
-        print(json.dumps(response.json()))
+
+        if not stream:
+            response_json = response.json()
+        else:
+            response_json = []
+            for line in response.iter_lines():
+                if line.startswith(b"data: ") and line[6:] != b"[DONE]":
+                    response_json.append(json.loads(line[6:]))
+
+        print(json.dumps(response_json))
        print("=" * 100)

    def test_simple_decode(self):

@@ -50,6 +62,9 @@ class TestSRTEndpoint(unittest.TestCase):
    def test_parallel_sample(self):
        self.run_decode(n=3)

+    def test_parallel_sample_stream(self):
+        self.run_decode(n=3, stream=True)
+
    def test_logprob(self):
        for top_logprobs_num in [0, 3]:
            for return_text in [True, False]:
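Note: the new `test_parallel_sample_stream` exercises the native `/generate` endpoint with `n > 1` and `stream=True`. A hedged sketch of the same call outside the test harness (assumes a server on `http://localhost:30000`; the request body mirrors the test's `run_decode` helper, with field names assumed from it):

import json

import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 32, "n": 3},
        "stream": True,
    },
    stream=True,  # let requests hand us the SSE lines as they arrive
)

chunks = []
for line in response.iter_lines():
    # Each event is "data: <json>"; the stream ends with "data: [DONE]".
    if line.startswith(b"data: ") and line[6:] != b"[DONE]":
        chunks.append(json.loads(line[6:]))

print(json.dumps(chunks, indent=2))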