Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
5020e1e8
Unverified
Commit
5020e1e8
authored
Jun 11, 2023
by
Zhuohan Li
Committed by
GitHub
Jun 10, 2023
Browse files
Non-streaming simple fastapi server (#144)
parent
42983742
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
61 additions
and
20 deletions
+61
-20
cacheflow/entrypoints/openai/openai_frontend.py
cacheflow/entrypoints/openai/openai_frontend.py
+1
-1
cacheflow/entrypoints/simple_fastapi_frontend.py
cacheflow/entrypoints/simple_fastapi_frontend.py
+29
-10
examples/simple_fastapi_client.py
examples/simple_fastapi_client.py
+31
-9
No files found.
cacheflow/entrypoints/openai/openai_frontend.py
View file @
5020e1e8
...
@@ -233,7 +233,7 @@ async def create_completion(raw_request: Request):
...
@@ -233,7 +233,7 @@ async def create_completion(raw_request: Request):
async
for
res
in
result_generator
:
async
for
res
in
result_generator
:
if
await
raw_request
.
is_disconnected
():
if
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
# Abort the request if the client disconnects.
await
server
.
abort
(
request
_id
)
await
abort
_
request
(
)
return
create_error_response
(
HTTPStatus
.
BAD_REQUEST
,
return
create_error_response
(
HTTPStatus
.
BAD_REQUEST
,
"Client disconnected"
)
"Client disconnected"
)
final_res
=
res
final_res
=
res
...
...
cacheflow/entrypoints/simple_fastapi_frontend.py
View file @
5020e1e8
...
@@ -3,7 +3,7 @@ import json
...
@@ -3,7 +3,7 @@ import json
from
typing
import
AsyncGenerator
from
typing
import
AsyncGenerator
from
fastapi
import
BackgroundTasks
,
FastAPI
,
Request
from
fastapi
import
BackgroundTasks
,
FastAPI
,
Request
from
fastapi.responses
import
StreamingResponse
from
fastapi.responses
import
Response
,
StreamingResponse
import
uvicorn
import
uvicorn
from
cacheflow.sampling_params
import
SamplingParams
from
cacheflow.sampling_params
import
SamplingParams
...
@@ -17,19 +17,22 @@ app = FastAPI()
...
@@ -17,19 +17,22 @@ app = FastAPI()
@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completions for the request, streaming or non-streaming.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).
    """
    request_dict = await request.json()
    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()
    results_generator = server.generate(prompt, sampling_params, request_id)

    async def stream_results() -> AsyncGenerator[bytes, None]:
        # One NUL-terminated JSON payload per engine step so the client
        # can split the byte stream on b"\0".
        async for request_output in results_generator:
            text_outputs = [
                request_output.prompt + output.text
                for output in request_output.outputs
            ]
            payload = {"text": text_outputs, "error": 0}
            yield (json.dumps(payload) + "\0").encode("utf-8")

    async def abort_request() -> None:
        await server.abort(request_id)

    # Streaming case
    if stream:
        background_tasks = BackgroundTasks()
        # Abort the request if the client disconnects.
        background_tasks.add_task(abort_request)
        return StreamingResponse(stream_results(), background=background_tasks)

    # Non-streaming case: drain the generator and return only the final state.
    final_output = None
    async for request_output in results_generator:
        if await request.is_disconnected():
            # Abort the request if the client disconnects.
            await server.abort(request_id)
            return Response(status_code=499)
        final_output = request_output

    assert final_output is not None
    text_outputs = [
        final_output.prompt + output.text for output in final_output.outputs
    ]
    return Response(content=json.dumps({"text": text_outputs}))
...
...
examples/simple_fastapi_client.py
View file @
5020e1e8
import
argparse
import
argparse
import
requests
import
json
import
json
import
requests
from
typing
import
Iterable
,
List
def clear_line(n: int = 1) -> None:
    """Erase the previous *n* terminal lines using ANSI escape codes."""
    line_up = '\033[1A'
    line_clear = '\x1b[2K'
    for _ in range(n):
        print(line_up, end=line_clear, flush=True)
def post_http_request(prompt: str,
                      api_url: str,
                      n: int = 1,
                      stream: bool = False) -> requests.Response:
    """POST a generation request to the server and return the raw response.

    The response is opened with ``stream=True`` so the caller may consume
    it incrementally (streaming) or read ``.content`` whole (non-streaming).
    """
    headers = {"User-Agent": "Test Client"}
    payload = {
        "prompt": prompt,
        "n": n,
        "use_beam_search": True,
        "temperature": 0.0,
        "max_tokens": 16,
        "stream": stream,
    }
    response = requests.post(api_url, headers=headers, json=payload,
                             stream=True)
    return response
def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
    """Yield the "text" candidate list from each NUL-delimited JSON chunk."""
    chunks = response.iter_lines(chunk_size=8192,
                                 decode_unicode=False,
                                 delimiter=b"\0")
    for chunk in chunks:
        if not chunk:
            continue
        data = json.loads(chunk.decode("utf-8"))
        yield data["text"]
def get_response(response: requests.Response) -> List[str]:
    """Parse a non-streaming response body and return its "text" list."""
    data = json.loads(response.content)
    return data["text"]
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8001)
    parser.add_argument("--n", type=int, default=4)
    parser.add_argument("--prompt", type=str, default="San Francisco is a")
    parser.add_argument("--stream", action="store_true")
    args = parser.parse_args()

    prompt = args.prompt
    api_url = f"http://{args.host}:{args.port}/generate"
    n = args.n
    stream = args.stream

    print(f"Prompt: {prompt}\n", flush=True)
    response = post_http_request(prompt, api_url, n, stream)

    if stream:
        # Re-print the beam candidates in place as each update arrives.
        num_printed_lines = 0
        for h in get_streaming_response(response):
            clear_line(num_printed_lines)
            num_printed_lines = 0
            for i, line in enumerate(h):
                num_printed_lines += 1
                print(f"Beam candidate {i}: {line}", flush=True)
    else:
        # Single final result: print each beam candidate once.
        output = get_response(response)
        for i, line in enumerate(output):
            print(f"Beam candidate {i}: {line}", flush=True)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment