norm / vllm · Commit 4858f3bb (unverified)

Authored Apr 30, 2023 by Zhuohan Li; committed by GitHub on Apr 30, 2023.
Add an option to launch cacheflow without ray (#51)
Parent: a96d63c2

Showing 7 changed files with 102 additions and 28 deletions (+102, -28):

.gitignore (+3, -0)
benchmark/benchmark_latency.py (+8, -4)
benchmark/benchmark_text_completion.py (+10, -6)
cacheflow/http_frontend/fastapi_frontend.py (+15, -3)
cacheflow/master/server.py (+39, -4)
cacheflow/worker/controller.py (+21, -9)
simple_server.py (+6, -2)
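In short, this commit replaces the Ray-only entry point initialize_ray_cluster(address=...) with initialize_cluster(use_ray=...), threads a use_ray flag through Server and Controller, and adds a process_server_arguments helper that switches Ray on automatically whenever more than one GPU is requested. A minimal sketch of the new call pattern, mirroring the scripts changed below (all names come from this diff; model-selection flags are defined elsewhere in add_server_arguments):

    import argparse
    from cacheflow.master.server import (add_server_arguments,
                                         process_server_arguments,
                                         initialize_cluster)

    parser = argparse.ArgumentParser(description='CacheFlow simple server.')
    parser = add_server_arguments(parser)
    args = parser.parse_args([])  # empty argv; the parallelism flags default to 1
    args = process_server_arguments(args)  # forces use_ray=True if pp * tp > 1

    # Without Ray this returns one local node and a tcp://localhost:<port>
    # init method; with Ray it connects to (or starts) a Ray cluster.
    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = initialize_cluster(
        use_ray=args.use_ray,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size)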
.gitignore

@@ -3,8 +3,11 @@
 *.egg-info/
 *.eggs/
 *.so
 *.log
 *.csv
 build/

+*.pkl
+*.png
+**/log.txt
 .vscode/
benchmark/benchmark_latency.py

@@ -8,7 +8,8 @@ import torch
 from cacheflow.master.simple_frontend import SimpleFrontend
-from cacheflow.master.server import (Server, add_server_arguments,
-                                     initialize_ray_cluster)
+from cacheflow.master.server import (Server, add_server_arguments,
+                                     process_server_arguments,
+                                     initialize_cluster)
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.utils import get_gpu_memory, get_cpu_memory

@@ -20,8 +21,8 @@ def main(args: argparse.Namespace):
     (num_nodes, num_devices_per_node, distributed_init_method,
      all_stage_devices) = (
-        initialize_ray_cluster(
-            address='local',
+        initialize_cluster(
+            use_ray=args.use_ray,
             pipeline_parallel_size=args.pipeline_parallel_size,
             tensor_parallel_size=args.tensor_parallel_size))

@@ -44,6 +45,7 @@ def main(args: argparse.Namespace):
         all_stage_devices=all_stage_devices,
         gpu_memory=get_gpu_memory(),
         cpu_memory=get_cpu_memory(),
+        use_ray=args.use_ray,
     )

     # Create a frontend.

@@ -91,7 +93,8 @@ def main(args: argparse.Namespace):
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='CacheFlow simple server.')
+    parser = argparse.ArgumentParser(
+        description='Benchmark the latency of decoding a single sentence.')
     parser = add_server_arguments(parser)
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)

@@ -99,6 +102,7 @@ if __name__ == '__main__':
     parser.add_argument('--n', type=int, default=1)
     parser.add_argument('--use-beam-search', action='store_true')
     args = parser.parse_args()
+    args = process_server_arguments(args)
     args.max_num_batched_tokens = max(
         args.max_num_batched_tokens, args.batch_size * args.input_len)
     print(args)
benchmark/benchmark_text_completion.py

@@ -11,7 +11,8 @@ from transformers import AutoConfig
 from benchmark.trace import generate_text_completion_requests
 from cacheflow.master.simple_frontend import SimpleFrontend
-from cacheflow.master.server import (Server, add_server_arguments,
-                                     initialize_ray_cluster)
+from cacheflow.master.server import (Server, add_server_arguments,
+                                     process_server_arguments,
+                                     initialize_cluster)
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.utils import get_gpu_memory, get_cpu_memory

@@ -25,8 +26,8 @@ def main(args: argparse.Namespace):
     (num_nodes, num_devices_per_node, distributed_init_method,
      all_stage_devices) = (
-        initialize_ray_cluster(
-            address='local',
+        initialize_cluster(
+            use_ray=args.use_ray,
             pipeline_parallel_size=args.pipeline_parallel_size,
             tensor_parallel_size=args.tensor_parallel_size))

@@ -49,6 +50,7 @@ def main(args: argparse.Namespace):
         all_stage_devices=all_stage_devices,
         gpu_memory=get_gpu_memory(),
         cpu_memory=get_cpu_memory(),
+        use_ray=args.use_ray,
         collect_stats=True,
         do_memory_analysis=args.do_memory_analysis,
     )

@@ -134,7 +136,7 @@ def main(args: argparse.Namespace):
             finished.append({
                 'group_id': seq_group.group_id,
                 'seq_id': seq.seq_id,
-                'arrival_time': arrival_time,
+                'arrival_time': arrival_time,
                 'finish_time': finish_time,
                 'prompt_len': seq.prompt_len,
                 'output_len': output_len,

(The 'arrival_time' entry above appears on both the old and new side of the hunk; in the rendered side-by-side view this change appears to be whitespace-only.)

@@ -225,8 +227,9 @@ def get_sampling_dir_name(
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='CacheFlow simple server.')
-    parser = add_server_arguments(parser)
+    parser = argparse.ArgumentParser(
+        description='Benchmark the performance on a series of requests.')
+    parser = add_server_arguments(parser)
     parser.add_argument('--output-dir', type=str,
                         help='path to output directory', default=None)
     parser.add_argument('--dataset', type=str,
                         help='path to dataset', required=True)

@@ -246,6 +249,7 @@ if __name__ == '__main__':
     parser.add_argument('--n6-beam', type=float,
                         help='ratio of requests with n=6 & beam search', default=0.0)
     parser.add_argument('--n8-beam', type=float,
                         help='ratio of requests with n=8 & beam search', default=0.0)
     args = parser.parse_args()
+    args = process_server_arguments(args)
     if (args.n1 + args.n2 + args.n3 + args.n4 + args.n6 +
             args.n2_beam + args.n4_beam + args.n6_beam + args.n8_beam != 1.0):
         raise ValueError('The ratios of requests must sum to 1.')
cacheflow/http_frontend/fastapi_frontend.py

@@ -13,7 +13,8 @@ import uvicorn
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
-from cacheflow.master.server import (Server, add_server_arguments,
-                                     initialize_ray_cluster)
+from cacheflow.master.server import (Server, add_server_arguments,
+                                     process_server_arguments,
+                                     initialize_cluster)
 from cacheflow.worker.controller import DeviceID
 from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory

@@ -33,17 +34,22 @@ class FastAPIFrontend:
         seed: int,
         swap_space: int,
         max_num_batched_tokens: int,
         max_num_sequences: int,
         num_nodes: int,
         num_devices_per_node: int,
         distributed_init_method: str,
         all_stage_devices: List[List[DeviceID]],
+        server_use_ray: bool,
     ):
         self.block_size = block_size
         self.tokenizer = AutoTokenizer.from_pretrained(model)
         self.seq_group_counter = Counter()
         self.seq_counter = Counter()
-        remote_server_class = ray.remote(num_cpus=0)(Server)
+        if server_use_ray:
+            remote_server_class = ray.remote(num_cpus=0)(Server)
+        else:
+            remote_server_class = ray.remote(num_gpus=1)(Server)
         self.server = remote_server_class.remote(
             model=model,
             model_path=model_path,

@@ -55,12 +61,14 @@ class FastAPIFrontend:
             seed=seed,
             swap_space=swap_space,
             max_num_batched_tokens=max_num_batched_tokens,
             max_num_sequences=max_num_sequences,
             num_nodes=num_nodes,
             num_devices_per_node=num_devices_per_node,
             distributed_init_method=distributed_init_method,
             all_stage_devices=all_stage_devices,
             gpu_memory=get_gpu_memory(),
             cpu_memory=get_cpu_memory(),
+            use_ray=server_use_ray,
         )

         self.running_seq_groups: Dict[int, SequenceGroup] = {}

@@ -149,6 +157,7 @@ if __name__ == "__main__":
     parser.add_argument("--port", type=int, default=10002)
     parser = add_server_arguments(parser)
     args = parser.parse_args()
+    args = process_server_arguments(args)

     # TODO(zhuohan): Support pipeline parallelism.
     assert args.pipeline_parallel_size == 1, (

@@ -156,7 +165,8 @@ if __name__ == "__main__":
     (num_nodes, num_devices_per_node, distributed_init_method,
      all_stage_devices) = (
-        initialize_ray_cluster(
+        initialize_cluster(
+            use_ray=True,
             pipeline_parallel_size=args.pipeline_parallel_size,
             tensor_parallel_size=args.tensor_parallel_size))

@@ -170,10 +180,12 @@ if __name__ == "__main__":
         seed=args.seed,
         swap_space=args.swap_space,
         max_num_batched_tokens=args.max_num_batched_tokens,
         max_num_sequences=args.max_num_sequences,
         num_nodes=num_nodes,
         num_devices_per_node=num_devices_per_node,
         distributed_init_method=distributed_init_method,
         all_stage_devices=all_stage_devices,
+        server_use_ray=args.use_ray,
     )

     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
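One subtlety worth flagging: even with server_use_ray=False, the FastAPI frontend still wraps Server in a Ray actor; the else branch simply reserves the GPU on the actor itself (num_gpus=1) rather than on Ray-launched workers, and this script still calls initialize_cluster with use_ray=True unconditionally. The fully Ray-free path in this commit therefore applies to simple_server.py and the benchmark scripts, not to the HTTP frontend.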
cacheflow/master/server.py

 import argparse
-from typing import List, Tuple
+from typing import List, Tuple, Optional
+import random

-import ray
 import torch

+try:
+    import ray
+except ImportError:
+    ray = None

 from cacheflow.master.scheduler import Scheduler
 from cacheflow.models import get_memory_analyzer

@@ -31,6 +35,7 @@ class Server:
         all_stage_devices: List[List[DeviceID]],
         gpu_memory: int,
         cpu_memory: int,
+        use_ray: bool,
         collect_stats: bool = False,
         do_memory_analysis: bool = False,
     ):

@@ -38,6 +43,10 @@ class Server:
         self.num_devices_per_node = num_devices_per_node
         self.world_size = pipeline_parallel_size * tensor_parallel_size
+        if not use_ray:
+            assert self.world_size == 1, (
+                "Only support single GPU without Ray.")
+
         self.memory_analyzer = get_memory_analyzer(
             model_name=model,
             block_size=block_size,

@@ -72,6 +81,7 @@ class Server:
                 model_path=model_path,
                 use_dummy_weights=use_dummy_weights,
                 max_num_batched_tokens=max_num_batched_tokens,
+                use_ray=use_ray,
             )
             self.controllers.append(controller)

@@ -105,11 +115,30 @@ class Server:
             self.scheduler.swapped)


-def initialize_ray_cluster(
-        address: str = 'auto',
+def initialize_cluster(
+        use_ray: bool = False,
+        address: Optional[str] = None,
         pipeline_parallel_size: int = 1,
         tensor_parallel_size: int = 1,
 ) -> Tuple[int, int, str, List[List[DeviceID]]]:
+    # Initialize cluster locally.
+    if not use_ray:
+        assert pipeline_parallel_size * tensor_parallel_size == 1, (
+            "Only support single GPU without Ray.")
+        num_nodes = 1
+        num_devices_per_node = torch.cuda.device_count()
+        port = random.randint(10000, 20000)
+        # We need to setup the distributed init method to make sure
+        # the distributed megatron code (e.g., get world size) works correctly.
+        distributed_init_method = f"tcp://localhost:{port}"
+        all_stage_devices = [[(0, None, 0)]]
+        return (num_nodes, num_devices_per_node, distributed_init_method,
+                all_stage_devices)
+
+    assert ray is not None, (
+        "Ray is not installed. Please install Ray to use distributed "
+        "serving.")
     # Connect to a ray cluster.
     ray.init(address=address)

@@ -177,6 +206,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
                         help='model path to download and load the weights')
     # Parallel arguments
+    parser.add_argument('--use-ray', action='store_true',
+                        help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1,
                         help='number of pipeline stages')
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1,
                         help='number of tensor parallel replicas')
     # KV cache arguments

@@ -190,3 +220,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--max-num-sequences', type=int, default=256,
                         help='maximum number of sequences per iteration')
     parser.add_argument('--use-dummy-weights', action='store_true',
                         help='use dummy values for model weights')
     return parser
+
+
+def process_server_arguments(args: argparse.Namespace):
+    if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
+        args.use_ray = True
+    return args
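Taken together, initialize_cluster and process_server_arguments encode the dispatch rule: Ray is optional for a single GPU and required otherwise. A standalone sketch of the promotion rule, using a plain argparse.Namespace in place of parsed server arguments:

    import argparse

    # Stand-in for parse_args() output; the real flags come from
    # add_server_arguments above.
    args = argparse.Namespace(
        use_ray=False, pipeline_parallel_size=2, tensor_parallel_size=1)

    # Same rule as process_server_arguments: more than one GPU implies Ray,
    # regardless of what --use-ray said.
    if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
        args.use_ray = True

    assert args.use_ray  # 2 * 1 > 1, so Ray is forced on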
cacheflow/worker/controller.py

 from typing import Dict, List, Union, Tuple

-import ray
+try:
+    import ray
+except ImportError:
+    ray = None

 from cacheflow.master.scheduler import Scheduler
 from cacheflow.sequence import SequenceGroupInputs

@@ -29,6 +32,7 @@ class Controller:
         model_path: str,
         use_dummy_weights: bool,
         max_num_batched_tokens: int,
+        use_ray: bool,
     ) -> None:
         self.stage_id = stage_id
         self.stage_devices = stage_devices

@@ -36,6 +40,7 @@ class Controller:
         self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
         self.num_cpu_blocks = num_cpu_blocks
+        self.use_ray = use_ray

         # Which pipeline stage is this node assigned to?
         self.is_first_stage = stage_id == 0

@@ -43,10 +48,13 @@ class Controller:
         self.workers: List[Worker] = []
         for rank, node_resource, device_id in stage_devices:
-            worker_cls = ray.remote(num_cpus=0,
-                                    num_gpus=1,
-                                    resources={node_resource: 1e-5})(Worker)
-            worker = worker_cls.remote(
+            if self.use_ray:
+                worker_cls = ray.remote(num_cpus=0,
+                                        num_gpus=1,
+                                        resources={node_resource: 1e-5})(Worker).remote
+            else:
+                worker_cls = Worker
+            worker = worker_cls(
                 model_name=model_name,
                 block_size=block_size,
                 num_gpu_blocks=num_gpu_blocks,

@@ -78,17 +86,21 @@ class Controller:
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
     ) -> None:
-        futures = []
+        all_outputs = []
         for worker in self.workers:
-            future = worker.execute_stage.remote(
+            executor = (worker.execute_stage.remote
+                        if self.use_ray else worker.execute_stage)
+            output = executor(
                 input_seq_groups,
                 blocks_to_swap_in,
                 blocks_to_swap_out,
                 blocks_to_copy,
             )
-            futures.append(future)
+            all_outputs.append(output)

-        all_outputs = ray.get(futures)
+        if self.use_ray:
+            all_outputs = ray.get(all_outputs)

         # Make sure all workers have the same results.
         output = all_outputs[0]
         for other_output in all_outputs[1:]:
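The execute_stage rewrite is a standard remote-or-local dispatch: pick either the Ray actor method (which returns an object ref) or the plain bound method, collect the results, and call ray.get only when Ray is in play. A self-contained sketch of the pattern, with a toy stand-in for the cacheflow Worker:

    try:
        import ray  # optional, mirroring the guarded import above
    except ImportError:
        ray = None


    class ToyWorker:
        # Illustrative stand-in for the cacheflow Worker, not the real class.
        def execute_stage(self, x: int) -> int:
            return x * 2


    def run_stage(workers, x, use_ray=False):
        # With Ray, `workers` would be actor handles whose method calls
        # return object refs; without Ray, the calls run inline.
        all_outputs = []
        for worker in workers:
            executor = (worker.execute_stage.remote
                        if use_ray else worker.execute_stage)
            all_outputs.append(executor(x))
        if use_ray:
            all_outputs = ray.get(all_outputs)  # resolve refs to values
        return all_outputs


    print(run_stage([ToyWorker(), ToyWorker()], 3))  # -> [6, 6]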
simple_server.py

@@ -3,7 +3,8 @@ from typing import List
 from cacheflow.master.simple_frontend import SimpleFrontend
-from cacheflow.master.server import (Server, add_server_arguments,
-                                     initialize_ray_cluster)
+from cacheflow.master.server import (Server, add_server_arguments,
+                                     process_server_arguments,
+                                     initialize_cluster)
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.utils import get_gpu_memory, get_cpu_memory

@@ -14,7 +15,8 @@ def main(args: argparse.Namespace):
     (num_nodes, num_devices_per_node, distributed_init_method,
      all_stage_devices) = (
-        initialize_ray_cluster(
+        initialize_cluster(
+            use_ray=args.use_ray,
             pipeline_parallel_size=args.pipeline_parallel_size,
             tensor_parallel_size=args.tensor_parallel_size))

@@ -37,6 +39,7 @@ def main(args: argparse.Namespace):
         all_stage_devices=all_stage_devices,
         gpu_memory=get_gpu_memory(),
         cpu_memory=get_cpu_memory(),
+        use_ray=args.use_ray,
     )

     # Create a frontend.

@@ -70,4 +73,5 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='CacheFlow simple server.')
     parser = add_server_arguments(parser)
     args = parser.parse_args()
+    args = process_server_arguments(args)
     main(args)
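End to end, the single-GPU path now runs without a Ray installation: use_ray defaults to False, the guarded imports leave ray as None, and initialize_cluster never touches it. Any configuration with pipeline_parallel_size * tensor_parallel_size > 1 has process_server_arguments force use_ray=True, at which point a missing Ray trips the "Ray is not installed. Please install Ray to use distributed serving." assertion.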