Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
a490aafa
Unverified
Commit
a490aafa
authored
Apr 06, 2023
by
Zhuohan Li
Committed by
GitHub
Apr 06, 2023
Browse files
Fix potential bugs in FastAPI frontend and add comments (#28)
parent
12659a0b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
30 additions
and
5 deletions
+30
-5
cacheflow/http_frontend/fastapi_frontend.py
cacheflow/http_frontend/fastapi_frontend.py
+30
-5
No files found.
cacheflow/http_frontend/fastapi_frontend.py
View file @
a490aafa
...
...
@@ -17,8 +17,10 @@ from cacheflow.master.server import (Server, add_server_arguments,
from
cacheflow.worker.controller
import
DeviceID
from
cacheflow.utils
import
Counter
,
get_gpu_memory
,
get_cpu_memory
TIMEOUT_TO_PREVENT_DEADLOCK
=
1
# seconds
app
=
FastAPI
()
class
FastAPIFrontend
:
def
__init__
(
self
,
...
...
@@ -30,7 +32,7 @@ class FastAPIFrontend:
dtype
:
str
,
seed
:
int
,
swap_space
:
int
,
max_batch
_size
:
int
,
max_
num_
batch
ed_tokens
:
int
,
num_nodes
:
int
,
num_devices_per_node
:
int
,
distributed_init_method
:
str
,
...
...
@@ -51,7 +53,7 @@ class FastAPIFrontend:
dtype
=
dtype
,
seed
=
seed
,
swap_space
=
swap_space
,
max_batch
_size
=
max_batch
_size
,
max_
num_
batch
ed_tokens
=
max_
num_
batch
ed_tokens
,
num_nodes
=
num_nodes
,
num_devices_per_node
=
num_devices_per_node
,
distributed_init_method
=
distributed_init_method
,
...
...
@@ -68,12 +70,14 @@ class FastAPIFrontend:
self
.
is_server_running
=
True
updated_seq_groups
=
await
self
.
server
.
step
.
remote
()
self
.
is_server_running
=
False
# Notify the waiting coroutines that there new outputs ready.
for
seq_group
in
updated_seq_groups
:
group_id
=
seq_group
.
group_id
self
.
running_seq_groups
[
group_id
]
=
seq_group
self
.
sequence_group_events
[
group_id
].
set
()
async
def
generate
(
self
,
request_dict
:
Dict
):
# Preprocess the request.
prompt
=
request_dict
[
"prompt"
]
sampling_params
=
SamplingParams
.
from_dict
(
request_dict
)
sampling_params
.
stop_token_ids
.
add
(
self
.
tokenizer
.
eos_token_id
)
...
...
@@ -87,15 +91,27 @@ class FastAPIFrontend:
arrival_time
=
time
.
time
()
group_id
=
next
(
self
.
seq_group_counter
)
seq_group
=
SequenceGroup
(
group_id
,
seqs
,
arrival_time
)
# Create an event to notify us that there is new output from the
# cacheflow server.
group_event
=
asyncio
.
Event
()
self
.
running_seq_groups
[
group_id
]
=
seq_group
self
.
sequence_group_events
[
group_id
]
=
group_event
# Add the request into the cacheflow server's waiting queue.
await
self
.
server
.
add_sequence_groups
.
remote
([(
seq_group
,
sampling_params
)])
# The cacheflow server does not have a background loop that keeps
# processing incoming requests. Therefore, we need to keep kicking
# the server to process the requests.
while
True
:
# Kick the server if the server is not running.
if
not
self
.
is_server_running
:
await
self
.
server_step
()
# Wait for new output. Add a 1s timeout to prevent dead lock.
await
asyncio
.
wait_for
(
group_event
.
wait
(),
timeout
=
1
)
# Wait for new output. The group_event will be set in server_step
# when there is new output available for the sequence group.
# Added a timeout to prevent deadlock.
await
asyncio
.
wait_for
(
group_event
.
wait
(),
timeout
=
TIMEOUT_TO_PREVENT_DEADLOCK
)
# Reset the event to wait for the next output.
group_event
.
clear
()
# Decode and return new outputs
seq_group
=
self
.
running_seq_groups
[
group_id
]
all_outputs
=
[]
for
seq
in
seq_group
.
seqs
:
...
...
@@ -107,7 +123,16 @@ class FastAPIFrontend:
"error"
:
0
,
}
yield
(
json
.
dumps
(
ret
)
+
"
\0
"
).
encode
(
"utf-8"
)
# Once finished, release the resources of the sequence group.
if
seq_group
.
is_finished
():
del
self
.
running_seq_groups
[
group_id
]
del
self
.
sequence_group_events
[
group_id
]
# Kick the server if the server is not running. This is to
# prevent that there are still requests in server's waiting
# queue to be executed.
if
not
self
.
is_server_running
:
await
self
.
server_step
()
break
...
...
@@ -143,7 +168,7 @@ if __name__ == "__main__":
dtype
=
args
.
dtype
,
seed
=
args
.
seed
,
swap_space
=
args
.
swap_space
,
max_batch
_size
=
args
.
max_batch
_size
,
max_
num_
batch
ed_tokens
=
args
.
max_
num_
batch
ed_tokens
,
num_nodes
=
num_nodes
,
num_devices_per_node
=
num_devices_per_node
,
distributed_init_method
=
distributed_init_method
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment