Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
kecinstone
2024pra-vllm
Commits
a490aafa
Unverified
Commit
a490aafa
authored
Apr 06, 2023
by
Zhuohan Li
Committed by
GitHub
Apr 06, 2023
Browse files
Fix potential bugs in FastAPI frontend and add comments (#28)
parent
12659a0b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
30 additions
and
5 deletions
+30
-5
cacheflow/http_frontend/fastapi_frontend.py
cacheflow/http_frontend/fastapi_frontend.py
+30
-5
No files found.
cacheflow/http_frontend/fastapi_frontend.py
View file @
a490aafa
...
...
@@ -17,8 +17,10 @@ from cacheflow.master.server import (Server, add_server_arguments,
from
cacheflow.worker.controller
import
DeviceID
from
cacheflow.utils
import
Counter
,
get_gpu_memory
,
get_cpu_memory
TIMEOUT_TO_PREVENT_DEADLOCK
=
1
# seconds
app
=
FastAPI
()
class
FastAPIFrontend
:
def
__init__
(
self
,
...
...
@@ -30,7 +32,7 @@ class FastAPIFrontend:
dtype
:
str
,
seed
:
int
,
swap_space
:
int
,
max_batch
_size
:
int
,
max_
num_
batch
ed_tokens
:
int
,
num_nodes
:
int
,
num_devices_per_node
:
int
,
distributed_init_method
:
str
,
...
...
@@ -51,7 +53,7 @@ class FastAPIFrontend:
dtype
=
dtype
,
seed
=
seed
,
swap_space
=
swap_space
,
max_batch
_size
=
max_batch
_size
,
max_
num_
batch
ed_tokens
=
max_
num_
batch
ed_tokens
,
num_nodes
=
num_nodes
,
num_devices_per_node
=
num_devices_per_node
,
distributed_init_method
=
distributed_init_method
,
...
...
@@ -68,12 +70,14 @@ class FastAPIFrontend:
self
.
is_server_running
=
True
updated_seq_groups
=
await
self
.
server
.
step
.
remote
()
self
.
is_server_running
=
False
# Notify the waiting coroutines that there are new outputs ready.
for
seq_group
in
updated_seq_groups
:
group_id
=
seq_group
.
group_id
self
.
running_seq_groups
[
group_id
]
=
seq_group
self
.
sequence_group_events
[
group_id
].
set
()
async
def
generate
(
self
,
request_dict
:
Dict
):
# Preprocess the request.
prompt
=
request_dict
[
"prompt"
]
sampling_params
=
SamplingParams
.
from_dict
(
request_dict
)
sampling_params
.
stop_token_ids
.
add
(
self
.
tokenizer
.
eos_token_id
)
...
...
@@ -87,15 +91,27 @@ class FastAPIFrontend:
arrival_time
=
time
.
time
()
group_id
=
next
(
self
.
seq_group_counter
)
seq_group
=
SequenceGroup
(
group_id
,
seqs
,
arrival_time
)
# Create an event to notify us that there is new output from the
# cacheflow server.
group_event
=
asyncio
.
Event
()
self
.
running_seq_groups
[
group_id
]
=
seq_group
self
.
sequence_group_events
[
group_id
]
=
group_event
# Add the request into the cacheflow server's waiting queue.
await
self
.
server
.
add_sequence_groups
.
remote
([(
seq_group
,
sampling_params
)])
# The cacheflow server does not have a background loop that keeps
# processing incoming requests. Therefore, we need to keep kicking
# the server to process the requests.
while
True
:
# Kick the server if the server is not running.
if
not
self
.
is_server_running
:
await
self
.
server_step
()
# Wait for new output. Add a 1s timeout to prevent dead lock.
await
asyncio
.
wait_for
(
group_event
.
wait
(),
timeout
=
1
)
# Wait for new output. The group_event will be set in server_step
# when there is new output available for the sequence group.
# Added a timeout to prevent deadlock.
await
asyncio
.
wait_for
(
group_event
.
wait
(),
timeout
=
TIMEOUT_TO_PREVENT_DEADLOCK
)
# Reset the event to wait for the next output.
group_event
.
clear
()
# Decode and return new outputs
seq_group
=
self
.
running_seq_groups
[
group_id
]
all_outputs
=
[]
for
seq
in
seq_group
.
seqs
:
...
...
@@ -107,7 +123,16 @@ class FastAPIFrontend:
"error"
:
0
,
}
yield
(
json
.
dumps
(
ret
)
+
"
\0
"
).
encode
(
"utf-8"
)
# Once finished, release the resources of the sequence group.
if
seq_group
.
is_finished
():
del
self
.
running_seq_groups
[
group_id
]
del
self
.
sequence_group_events
[
group_id
]
# Kick the server if the server is not running. This prevents
# requests remaining in the server's waiting queue from being
# left unexecuted.
if
not
self
.
is_server_running
:
await
self
.
server_step
()
break
...
...
@@ -143,7 +168,7 @@ if __name__ == "__main__":
dtype
=
args
.
dtype
,
seed
=
args
.
seed
,
swap_space
=
args
.
swap_space
,
max_batch
_size
=
args
.
max_batch
_size
,
max_
num_
batch
ed_tokens
=
args
.
max_
num_
batch
ed_tokens
,
num_nodes
=
num_nodes
,
num_devices_per_node
=
num_devices_per_node
,
distributed_init_method
=
distributed_init_method
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment