norm / vllm · Commit 1a956e13 (unverified)

Fix various issues of async servers (#135)

Authored Jun 05, 2023 by Zhuohan Li; committed via GitHub on Jun 05, 2023
Parent commit: 8274ca23

Showing 11 changed files with 289 additions and 121 deletions
benchmarks/benchmark_async_llm_server.py          +58   -0
cacheflow/config.py                                +3   -3
cacheflow/core/block_manager.py                    +6   -3
cacheflow/core/scheduler.py                       +13   -1
cacheflow/entrypoints/openai/openai_frontend.py   +23   -7
cacheflow/entrypoints/simple_fastapi_frontend.py  +17   -8
cacheflow/sequence.py                              +9   -1
cacheflow/server/arg_utils.py                     +72  -59
cacheflow/server/async_llm_server.py              +71  -22
cacheflow/server/llm_server.py                     +7   -9
cacheflow/server/ray_utils.py                     +10   -8
benchmarks/benchmark_async_llm_server.py (new file, mode 0 → 100644)

import argparse
import json
import threading
import time

import requests


def main(args: argparse.Namespace):
    prompts = [f"Tell me a story with more than {''.join([str(i + 1)] * 5)} words"
               for i in range(args.n_threads)]

    headers = {"User-Agent": "CacheFlow Benchmark Client"}
    ploads = [{
        "prompt": p,
        "max_tokens": args.max_tokens,
        "temperature": 0.0,
        "ignore_eos": True,
    } for p in prompts]

    def send_request(results, i):
        response = requests.post(args.api_url, headers=headers,
                                 json=ploads[i], stream=True)
        results[i] = response

    # use args.n_threads to prompt the backend
    tik = time.time()
    threads = []
    results = [None] * args.n_threads
    for i in range(args.n_threads):
        t = threading.Thread(target=send_request, args=(results, i))
        t.start()
        threads.append(t)
    for t in threads:
        t.join()
    print(f"Time (POST): {time.time() - tik} s")

    n_words = 0
    for i, response in enumerate(results):
        k = list(response.iter_lines(chunk_size=8192, decode_unicode=False,
                                     delimiter=b"\0"))
        response_new_words = json.loads(k[-2].decode("utf-8"))["text"][0]
        n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" "))

    time_seconds = time.time() - tik
    print(f"Time (total): {time_seconds:.3f} s to finish, n_threads: {args.n_threads}, "
          f"throughput: {n_words / time_seconds} words/s.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--api-url", type=str,
                        default="http://localhost:8001/generate")
    parser.add_argument("--max-tokens", type=int, default=128)
    parser.add_argument("--n-threads", type=int, default=128)
    args = parser.parse_args()

    main(args)
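For context, a minimal client-side sketch (not part of the commit) of how one streamed response from the frontend can be consumed incrementally; the endpoint URL and payload fields mirror the benchmark's assumptions above, and a server is assumed to be running locally.

import json

import requests

# Assumes the simple FastAPI frontend from this commit is serving /generate.
response = requests.post(
    "http://localhost:8001/generate",
    headers={"User-Agent": "CacheFlow Benchmark Client"},
    json={"prompt": "Tell me a story", "max_tokens": 32, "temperature": 0.0},
    stream=True)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                 delimiter=b"\0"):
    if chunk:
        # Each NUL-delimited chunk is a JSON object whose "text" list holds
        # the cumulative text (prompt plus generation so far).
        print(json.loads(chunk.decode("utf-8"))["text"][0])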
cacheflow/config.py

@@ -116,15 +116,15 @@ class ParallelConfig:
         self,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
-        use_ray: bool,
+        worker_use_ray: bool,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         self.tensor_parallel_size = tensor_parallel_size
-        self.use_ray = use_ray
+        self.worker_use_ray = worker_use_ray

         self.world_size = pipeline_parallel_size * tensor_parallel_size
         if self.world_size > 1:
-            self.use_ray = True
+            self.worker_use_ray = True
         self._verify_args()

     def _verify_args(self) -> None:
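A small illustration of the renamed flag's behavior, assuming the constructor shown above (this snippet is not part of the commit):

from cacheflow.config import ParallelConfig

# With tensor_parallel_size=2 the world size exceeds 1, so the config
# force-enables Ray for the workers even though the caller passed False.
parallel_config = ParallelConfig(pipeline_parallel_size=1,
                                 tensor_parallel_size=2,
                                 worker_use_ray=False)
assert parallel_config.worker_use_ray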
cacheflow/core/block_manager.py

@@ -148,7 +148,7 @@ class BlockSpaceManager:
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             block_table = self.block_tables[seq.seq_id]
             for block in block_table:
@@ -169,7 +169,7 @@ class BlockSpaceManager:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -200,7 +200,7 @@ class BlockSpaceManager:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -231,6 +231,9 @@ class BlockSpaceManager:
             self.cpu_allocator.free(block)

     def free(self, seq: Sequence) -> None:
+        if seq.seq_id not in self.block_tables:
+            # Already freed or haven't been scheduled yet.
+            return
         block_table = self.block_tables[seq.seq_id]
         self._free_block_table(block_table)
         del self.block_tables[seq.seq_id]
cacheflow/core/scheduler.py

@@ -12,7 +12,7 @@ from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
 logger = init_logger(__name__)

-_LOGGING_INTERVAL_SEC = 10
+_LOGGING_INTERVAL_SEC = 5


 class PreemptionMode(enum.Enum):
@@ -84,6 +84,18 @@ class Scheduler:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)

+    def abort_seq_group(self, request_id: str) -> None:
+        for state_queue in [self.waiting, self.running, self.swapped]:
+            for seq_group in state_queue:
+                if seq_group.request_id == request_id:
+                    # Remove the sequence group from the state queue.
+                    state_queue.remove(seq_group)
+                    for seq in seq_group.seqs:
+                        if seq.is_finished():
+                            continue
+                        self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
+                    return
+
     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped
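A brief usage note (illustrative only, not part of the commit; assumes `scheduler` is an existing Scheduler instance and `request_id` a string): the new abort path is safe to call more than once for the same request.

# Unknown or already-removed request ids fall through the queues without
# error; live sequences are freed with SequenceStatus.FINISHED_ABORTED.
scheduler.abort_seq_group(request_id)
scheduler.abort_seq_group(request_id)  # second call finds nothing and is a no-op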
cacheflow/entrypoints/openai/openai_frontend.py

@@ -7,13 +7,14 @@ import time
 from typing import AsyncGenerator, Dict, List, Optional

 import fastapi
+from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn

 from cacheflow.outputs import RequestOutput
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.logger import init_logger
@@ -33,6 +34,7 @@ from cacheflow.entrypoints.openai.protocol import (
     UsageInfo,
 )

+TIMEOUT_KEEP_ALIVE = 5  # seconds

 logger = init_logger(__name__)

 served_model = None
@@ -93,7 +95,8 @@ def create_logprobs(token_ids: List[int],
 @app.post("/v1/completions")
-async def create_completion(request: CompletionRequest):
+async def create_completion(raw_request: Request):
+    request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")

     error_check_ret = await check_model(request)
@@ -139,7 +142,7 @@ async def create_completion(request: CompletionRequest):
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

     result_generator = server.generate(prompt, sampling_params,
-                                       request_id=request_id)
+                                       request_id)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
@@ -147,6 +150,9 @@ async def create_completion(request: CompletionRequest):
         (request.best_of is None or request.n == request.best_of)
         and not request.use_beam_search)

+    async def abort_request() -> None:
+        await server.abort(request_id)
+
     def create_stream_response_json(index: int,
                                     text: str,
                                     logprobs: Optional[LogProbs] = None,
@@ -203,12 +209,21 @@ async def create_completion(request: CompletionRequest):
     # Streaming response
     if stream:
+        background_tasks = BackgroundTasks()
+        # Abort the request if the client disconnects.
+        background_tasks.add_task(abort_request)
         return StreamingResponse(completion_stream_generator(),
-                                 media_type="text/event-stream")
+                                 media_type="text/event-stream",
+                                 background=background_tasks)

     # Non-streaming response
     final_res: RequestOutput = None
     async for res in result_generator:
+        if await raw_request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await server.abort(request_id)
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "Client disconnected")
         final_res = res
     assert final_res is not None
     choices = []
@@ -276,7 +291,7 @@ if __name__ == "__main__":
                         help="The model name used in the API. If not specified, "
                              "the model name will be the same as the "
                              "huggingface name.")
-    parser = ServerArgs.add_cli_args(parser)
+    parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()

     app.add_middleware(
@@ -291,10 +306,11 @@ if __name__ == "__main__":
     served_model = args.served_model_name or args.model

-    server_args = ServerArgs.from_cli_args(args)
+    server_args = AsyncServerArgs.from_cli_args(args)
     server = AsyncLLMServer.from_server_args(server_args)

     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(args.model)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
cacheflow/entrypoints/simple_fastapi_frontend.py

@@ -2,15 +2,16 @@ import argparse
 import json
 from typing import AsyncGenerator

-from fastapi import FastAPI, Request
+from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import StreamingResponse
 import uvicorn

 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
-from cacheflow.server.ray_utils import initialize_cluster
+from cacheflow.utils import random_uuid

+TIMEOUT_KEEP_ALIVE = 5  # seconds.
 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
 app = FastAPI()
@@ -20,7 +21,8 @@ async def generate_stream(request: Request) -> StreamingResponse:
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
     sampling_params = SamplingParams(**request_dict)
-    results_generator = server.generate(prompt, sampling_params)
+    request_id = random_uuid()
+    results_generator = server.generate(prompt, sampling_params, request_id)

     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
@@ -35,17 +37,24 @@ async def generate_stream(request: Request) -> StreamingResponse:
             }
             yield (json.dumps(ret) + "\0").encode("utf-8")

-    return StreamingResponse(stream_results())
+    async def abort_request() -> None:
+        await server.abort(request_id)
+
+    background_tasks = BackgroundTasks()
+    # Abort the request if the client disconnects.
+    background_tasks.add_task(abort_request)
+    return StreamingResponse(stream_results(), background=background_tasks)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8001)
-    parser = ServerArgs.add_cli_args(parser)
+    parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()

-    server_args = ServerArgs.from_cli_args(args)
+    server_args = AsyncServerArgs.from_cli_args(args)
     server = AsyncLLMServer.from_server_args(server_args)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
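To illustrate what the background task buys, a hypothetical client (not part of the commit): dropping the streamed connection early now aborts the request on the server instead of leaving it running to completion.

import requests

# Start a long generation against the frontend above, then walk away early.
resp = requests.post("http://localhost:8001/generate",
                     json={"prompt": "Tell me a story", "max_tokens": 2048},
                     stream=True)
for i, chunk in enumerate(resp.iter_lines(delimiter=b"\0")):
    if i >= 3:
        # Closing the connection lets the StreamingResponse's background
        # task run, which calls server.abort(request_id) and frees the
        # request's sequences on the backend.
        resp.close()
        break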
cacheflow/sequence.py

@@ -12,12 +12,14 @@ class SequenceStatus(enum.Enum):
     SWAPPED = enum.auto()
     FINISHED_STOPPED = enum.auto()
     FINISHED_LENGTH_CAPPED = enum.auto()
+    FINISHED_ABORTED = enum.auto()

     @staticmethod
     def is_finished(status: "SequenceStatus") -> bool:
         return status in [
             SequenceStatus.FINISHED_STOPPED,
             SequenceStatus.FINISHED_LENGTH_CAPPED,
+            SequenceStatus.FINISHED_ABORTED,
         ]

     @staticmethod
@@ -26,10 +28,13 @@ class SequenceStatus(enum.Enum):
             finish_reason = "stop"
         elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
             finish_reason = "length"
+        elif status == SequenceStatus.FINISHED_ABORTED:
+            finish_reason = "abort"
         else:
             finish_reason = None
         return finish_reason


 class SequenceData:

     def __init__(
@@ -137,6 +142,9 @@ class Sequence:
     def get_cumulative_logprob(self) -> float:
         return self.data.cumulative_logprob

+    def is_finished(self) -> bool:
+        return SequenceStatus.is_finished(self.status)
+
     def fork(self, child_seq: 'Sequence') -> None:
         child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
@@ -182,7 +190,7 @@ class SequenceGroup:
             raise ValueError(f'Sequence {seq_id} not found.')

     def is_finished(self) -> bool:
-        return all(SequenceStatus.is_finished(seq.status) for seq in self.seqs)
+        return all(seq.is_finished() for seq in self.seqs)

     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
cacheflow/server/arg_utils.py

@@ -15,7 +15,7 @@ class ServerArgs:
     use_dummy_weights: bool = False
     dtype: str = "default"
     seed: int = 0
-    use_ray: bool = False
+    worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
     block_size: int = 16
@@ -32,36 +32,6 @@ class ServerArgs:
     def add_cli_args(
             parser: argparse.ArgumentParser,
     ) -> argparse.ArgumentParser:
-        return _add_server_arguments(parser)
-
-    @classmethod
-    def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
-        # Get the list of attributes of this dataclass.
-        attrs = [attr.name for attr in dataclasses.fields(cls)]
-        # Set the attributes from the parsed arguments.
-        server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
-        return server_args
-
-    def create_server_configs(
-        self,
-    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
-        # Initialize the configs.
-        model_config = ModelConfig(
-            self.model, self.download_dir, self.use_np_weights,
-            self.use_dummy_weights, self.dtype, self.seed)
-        cache_config = CacheConfig(self.block_size,
-                                   self.gpu_memory_utilization,
-                                   self.swap_space)
-        parallel_config = ParallelConfig(self.pipeline_parallel_size,
-                                         self.tensor_parallel_size,
-                                         self.use_ray)
-        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs)
-        return model_config, cache_config, parallel_config, scheduler_config
-
-
-def _add_server_arguments(
-    parser: argparse.ArgumentParser,
-) -> argparse.ArgumentParser:
         """Shared CLI arguments for CacheFlow servers."""
         # Model arguments
         parser.add_argument('--model', type=str, default='facebook/opt-125m',
@@ -69,22 +39,23 @@ def _add_server_arguments(
         parser.add_argument('--download-dir', type=str,
                             default=ServerArgs.download_dir,
                             help='directory to download and load the weights, '
-                                 'default to the default cache dir of huggingface')
+                                 'default to the default cache dir of '
+                                 'huggingface')
         parser.add_argument('--use-np-weights', action='store_true',
-                            help='save a numpy copy of model weights for faster '
-                                 'loading. This can increase the disk usage by up '
-                                 'to 2x.')
+                            help='save a numpy copy of model weights for '
+                                 'faster loading. This can increase the disk '
+                                 'usage by up to 2x.')
         parser.add_argument('--use-dummy-weights', action='store_true',
                             help='use dummy values for model weights')
         # TODO(woosuk): Support FP32.
         parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
                             choices=['default', 'half', 'bfloat16'],
-                            help=('data type for model weights and activations. '
-                                  'The "default" option will use FP16 precision '
-                                  'for FP32 and FP16 models, and BF16 precision '
-                                  'for BF16 models.'))
+                            help='data type for model weights and activations. '
+                                 'The "default" option will use FP16 precision '
+                                 'for FP32 and FP16 models, and BF16 precision '
+                                 'for BF16 models.')
         # Parallel arguments
-        parser.add_argument('--use-ray', action='store_true',
+        parser.add_argument('--worker-use-ray', action='store_true',
                             help='use Ray for distributed serving, will be '
                                  'automatically set when using more than 1 GPU')
         parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
@@ -94,24 +65,66 @@ def _add_server_arguments(
                             default=ServerArgs.tensor_parallel_size,
                             help='number of tensor parallel replicas')
         # KV cache arguments
-        parser.add_argument('--block-size', type=int, default=ServerArgs.block_size,
+        parser.add_argument('--block-size', type=int,
+                            default=ServerArgs.block_size,
                             choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
                             help='token block size')
         # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
         parser.add_argument('--seed', type=int, default=ServerArgs.seed,
                             help='random seed')
-        parser.add_argument('--swap-space', type=int, default=ServerArgs.swap_space,
+        parser.add_argument('--swap-space', type=int,
+                            default=ServerArgs.swap_space,
                             help='CPU swap space size (GiB) per GPU')
         parser.add_argument('--gpu-memory-utilization', type=float,
                             default=ServerArgs.gpu_memory_utilization,
-                            help='the percentage of GPU memory to be used for the '
-                                 'model executor')
+                            help='the percentage of GPU memory to be used for'
+                                 'the model executor')
         parser.add_argument('--max-num-batched-tokens', type=int,
                             default=ServerArgs.max_num_batched_tokens,
-                            help='maximum number of batched tokens per iteration')
+                            help='maximum number of batched tokens per '
+                                 'iteration')
         parser.add_argument('--max-num-seqs', type=int,
                             default=ServerArgs.max_num_seqs,
                             help='maximum number of sequences per iteration')
         parser.add_argument('--disable-log-stats', action='store_true',
                             help='disable logging statistics')
         return parser
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
+        # Get the list of attributes of this dataclass.
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        # Set the attributes from the parsed arguments.
+        server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
+        return server_args
+
+    def create_server_configs(
+        self,
+    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
+        # Initialize the configs.
+        model_config = ModelConfig(
+            self.model, self.download_dir, self.use_np_weights,
+            self.use_dummy_weights, self.dtype, self.seed)
+        cache_config = CacheConfig(self.block_size,
+                                   self.gpu_memory_utilization,
+                                   self.swap_space)
+        parallel_config = ParallelConfig(self.pipeline_parallel_size,
+                                         self.tensor_parallel_size,
+                                         self.worker_use_ray)
+        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
+                                           self.max_num_seqs)
+        return model_config, cache_config, parallel_config, scheduler_config
+
+
+@dataclass
+class AsyncServerArgs(ServerArgs):
+    server_use_ray: bool = False
+
+    @staticmethod
+    def add_cli_args(
+            parser: argparse.ArgumentParser,
+    ) -> argparse.ArgumentParser:
+        parser = ServerArgs.add_cli_args(parser)
+        parser.add_argument('--server-use-ray', action='store_true',
+                            help='use Ray to start the LLM server in a '
+                                 'separate process as the web server process.')
+        return parser
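A hedged sketch of the new argument plumbing (the class names, flags, and methods are taken from the diff above; the example wiring itself is illustrative, not committed code):

import argparse

from cacheflow.server.arg_utils import AsyncServerArgs

parser = argparse.ArgumentParser()
parser = AsyncServerArgs.add_cli_args(parser)
args = parser.parse_args(["--model", "facebook/opt-125m",
                          "--tensor-parallel-size", "2",
                          "--server-use-ray"])
server_args = AsyncServerArgs.from_cli_args(args)
# worker_use_ray is forced on by ParallelConfig once the world size exceeds 1.
model_cfg, cache_cfg, parallel_cfg, sched_cfg = server_args.create_server_configs()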
cacheflow/server/async_llm_server.py

@@ -2,37 +2,52 @@ import asyncio
 import time
 from typing import Dict, Optional

-import ray
-
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.llm_server import LLMServer
-from cacheflow.server.ray_utils import initialize_cluster
-from cacheflow.utils import random_uuid
+from cacheflow.server.ray_utils import ray, initialize_cluster

 logger = init_logger(__name__)

 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds


 class AsyncLLMServer:

-    def __init__(self, server_use_ray: bool, *args, **kwargs) -> None:
-        if server_use_ray:
-            remote_server_class = ray.remote(num_cpus=0)(LLMServer)
-        else:
-            remote_server_class = ray.remote(num_gpus=1)(LLMServer)
-        self.server = remote_server_class.remote(*args, **kwargs)
+    def __init__(self, worker_use_ray: bool, server_use_ray: bool,
+                 *args, **kwargs) -> None:
+        self.worker_use_ray = worker_use_ray
+        self.server_use_ray = server_use_ray
+        if not self.server_use_ray:
+            server_class = LLMServer
+        elif self.worker_use_ray:
+            server_class = ray.remote(num_cpus=0)(LLMServer).remote
+        else:
+            server_class = ray.remote(num_gpus=1)(LLMServer).remote
+        self.server = server_class(*args, **kwargs)
         # Request id -> request output.
         self.request_outputs: Dict[str, RequestOutput] = {}
         # Request id -> event to notify that there is new output.
         self.request_events: Dict[str, asyncio.Event] = {}
         self.is_server_running = False
+        self.kicking_request_id: Optional[str] = None

-    async def server_step(self):
+    async def server_step(self, kicking_request_id: Optional[str] = None):
         self.is_server_running = True
+        self.kicking_request_id = kicking_request_id
         if self.server_use_ray:
             request_outputs = await self.server.step.remote()
+        else:
+            # Yield to the event loop to allow other coroutines to run
+            # while is_server_running is True. This let the server to add new
+            # requests into the queue.
+            await asyncio.sleep(0)
+            request_outputs = self.server.step()
         self.is_server_running = False
+        self.kicking_request_id = None

         # Notify the waiting coroutines that there are new outputs ready.
         for request_output in request_outputs:
             request_id = request_output.request_id
@@ -40,20 +55,26 @@ class AsyncLLMServer:
                 self.request_events[request_id].set()

     async def generate(self, prompt: str, sampling_params: SamplingParams,
-                       request_id: Optional[str] = None) -> RequestOutput:
+                       request_id: str) -> RequestOutput:
         # Preprocess the request.
         arrival_time = time.time()

         # Create an event to notify us that there is new output from the
         # cacheflow server.
-        if request_id is None:
-            request_id = random_uuid()
         request_event = asyncio.Event()
         self.request_events[request_id] = request_event

+        logger.info(f"Received request {request_id}: "
+                    f"prompt: {prompt!r}, "
+                    f"sampling params: {sampling_params}.")
+
         # Add the request into the cacheflow server's waiting queue.
-        await self.server.add_request.remote(
-            request_id, prompt, sampling_params, arrival_time=arrival_time)
+        if self.server_use_ray:
+            await self.server.add_request.remote(
+                request_id, prompt, sampling_params,
+                arrival_time=arrival_time)
+        else:
+            self.server.add_request(
+                request_id, prompt, sampling_params,
+                arrival_time=arrival_time)

         # The cacheflow server does not have a background loop that keeps
         # processing incoming requests. Therefore, we need to keep kicking
@@ -61,7 +82,7 @@ class AsyncLLMServer:
         while True:
             # Kick the server if the server is not running.
             if not self.is_server_running:
-                await self.server_step()
+                await self.server_step(request_id)

             # Wait for new output. The group_event will be set in server_step
             # when there is new output available for the sequence group.
@@ -80,6 +101,8 @@ class AsyncLLMServer:
             # Once finished, release the resources of the sequence group.
             if request_output.finished():
+                logger.info(f"Finished request {request_id}.")
+
                 del self.request_outputs[request_id]
                 del self.request_events[request_id]
                 # Kick the server if the server is not running. This is to
@@ -89,15 +112,41 @@ class AsyncLLMServer:
                     await self.server_step()
                 break

+    async def abort(self, request_id: str) -> None:
+        if request_id not in self.request_events:
+            # The request has already finished or been aborted.
+            return
+
+        logger.info(f"Aborted request {request_id}.")
+
+        if self.server_use_ray:
+            await self.server.abort_request.remote(request_id)
+        else:
+            self.server.abort_request(request_id)
+
+        if request_id in self.request_events:
+            del self.request_events[request_id]
+        if request_id in self.request_outputs:
+            del self.request_outputs[request_id]
+
+        # To prevent deadlock when a request is aborted while the server is
+        # running.
+        if self.kicking_request_id == request_id:
+            self.is_server_running = False
+            self.kicking_request_id = None
+
     @classmethod
-    def from_server_args(cls, server_args: ServerArgs) -> "AsyncLLMServer":
+    def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
         # Create the server configs.
         server_configs = server_args.create_server_configs()
         parallel_config = server_configs[2]
         # Initialize the cluster.
-        distributed_init_method, devices = initialize_cluster(parallel_config)
+        distributed_init_method, devices = initialize_cluster(
+            parallel_config, server_args.server_use_ray)
         # Create the LLM server.
-        server = cls(server_args.use_ray, *server_configs,
+        server = cls(server_args.worker_use_ray,
+                     server_args.server_use_ray,
+                     *server_configs,
                      distributed_init_method, devices,
                      log_stats=not server_args.disable_log_stats)
         return server
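For orientation, a hedged sketch (not committed code) of how a caller drives the reworked AsyncLLMServer: the request id is now supplied by the caller, and abort() releases a request that is no longer wanted. The model name and sampling values below are placeholders.

import asyncio

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.utils import random_uuid


async def run_one_request(server: AsyncLLMServer, prompt: str) -> None:
    # The caller now owns the request id and passes it to generate().
    request_id = random_uuid()
    sampling_params = SamplingParams(temperature=0.0, max_tokens=16)
    try:
        async for request_output in server.generate(prompt, sampling_params,
                                                    request_id):
            print(request_output)
    except asyncio.CancelledError:
        # If the caller goes away mid-stream, free the request server-side.
        await server.abort(request_id)
        raise


if __name__ == "__main__":
    server_args = AsyncServerArgs(model="facebook/opt-125m")
    server = AsyncLLMServer.from_server_args(server_args)
    asyncio.run(run_one_request(server, "Hello, my name is"))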
cacheflow/server/llm_server.py

 import time
 from typing import Any, List, Optional

-try:
-    import ray
-except ImportError:
-    ray = None
-
 from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                               SchedulerConfig)
 from cacheflow.core.scheduler import Scheduler
@@ -13,7 +8,7 @@ from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.ray_utils import initialize_cluster
+from cacheflow.server.ray_utils import ray, initialize_cluster
 from cacheflow.server.tokenizer_utils import (get_tokenizer,
                                               detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
@@ -62,7 +57,7 @@ class LLMServer:
         assert len(stage_devices) == 1, "Only support one stage for now."
         for rank, node_resource, _ in stage_devices[0]:
             worker_cls = Worker
-            if self.parallel_config.use_ray:
+            if self.parallel_config.worker_use_ray:
                 worker_cls = ray.remote(num_cpus=0,
                                         num_gpus=1,
@@ -152,6 +147,9 @@ class LLMServer:
         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)

+    def abort_request(self, request_id: str) -> None:
+        self.scheduler.abort_seq_group(request_id)
+
     def get_num_unfinished_requests(self) -> int:
         return self.scheduler.get_num_unfinished_seq_groups()
@@ -243,13 +241,13 @@ class LLMServer:
         all_outputs = []
         for worker in self.workers:
             executor = getattr(worker, method)
-            if self.parallel_config.use_ray:
+            if self.parallel_config.worker_use_ray:
                 executor = executor.remote

             output = executor(*args, **kwargs)
             all_outputs.append(output)

-        if self.parallel_config.use_ray:
+        if self.parallel_config.worker_use_ray:
             all_outputs = ray.get(all_outputs)

         if get_all_outputs:
cacheflow/server/ray_utils.py

@@ -13,9 +13,18 @@ DeviceID = Tuple[int, Optional[str], int]  # rank, node resource (node IP), device id
 def initialize_cluster(
     parallel_config: ParallelConfig,
+    server_use_ray: bool = False,
     address: Optional[str] = None,
 ) -> Tuple[str, List[List[DeviceID]]]:
-    if not parallel_config.use_ray:
+    if parallel_config.worker_use_ray or server_use_ray:
+        if ray is None:
+            raise ImportError(
+                "Ray is not installed. Please install Ray to use distributed "
+                "serving.")
+        # Connect to a ray cluster.
+        ray.init(address=address)
+
+    if not parallel_config.worker_use_ray:
         # Initialize cluster locally.
         port = random.randint(10000, 20000)
         # We need to setup the distributed init method to make sure
@@ -24,13 +33,6 @@ def initialize_cluster(
         all_stage_devices = [[(0, None, 0)]]
         return distributed_init_method, all_stage_devices

-    if ray is None:
-        raise ImportError(
-            "Ray is not installed. Please install Ray to use distributed "
-            "serving.")
-    # Connect to a ray cluster.
-    ray.init(address=address)
-
     # Assume we have a uniform cluster that each node has the same number of
     # GPUs for now.
     valid_node_resources = []