Commit 2f49f155 (unverified), authored Mar 22, 2023 by Zhuohan Li, committed by GitHub Mar 21, 2023
Parent: cfae35b8

Support tensor parallel (#2)

Showing 4 changed files with 212 additions and 70 deletions (+212 −70):
cacheflow/worker/cache_engine.py   +4   −6
cacheflow/worker/controller.py     +42  −19
cacheflow/worker/worker.py         +51  −22
server.py                          +115 −23
cacheflow/worker/cache_engine.py

@@ -11,7 +11,6 @@ class CacheEngine:
     def __init__(
         self,
         worker_id: int,
-        gpu_id: int,
         num_layers: int,
         num_heads: int,
         head_size: int,
@@ -25,7 +24,6 @@ class CacheEngine:
                 f'head_size ({head_size}) must be a multiple of 16.')

         self.worker_id = worker_id
-        self.gpu_id = gpu_id
         self.num_layers = num_layers
         self.num_heads = num_heads
         self.head_size = head_size
@@ -39,8 +37,8 @@ class CacheEngine:
         self.cpu_cache = self.allocate_cpu_cache()

         # Initialize the stream for caching operations.
-        self.cache_stream = torch.cuda.Stream(device=gpu_id)
-        assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
+        self.cache_stream = torch.cuda.Stream()
+        assert self.cache_stream != torch.cuda.current_stream()
         # Initialize the events for stream synchronization.
         self.events = [torch.cuda.Event() for _ in range(num_layers)]
@@ -69,12 +67,12 @@ class CacheEngine:
             key_blocks = torch.empty(
                 size=(self.num_gpu_blocks, *key_block_shape),
                 dtype=self.dtype,
-                device=self.gpu_id,
+                device="cuda",
             )
             value_blocks = torch.empty(
                 size=(self.num_gpu_blocks, *value_block_shape),
                 dtype=self.dtype,
-                device=self.gpu_id,
+                device="cuda",
             )
             gpu_cache.append((key_blocks, value_blocks))
         return gpu_cache
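The thread running through this file is that the explicit gpu_id is gone and allocations land on the default CUDA device. That works because each worker now runs as a Ray actor created with num_gpus=1 (see controller.py below), so the actor process sees exactly one GPU and "cuda" always resolves to it. A minimal sketch of the same allocation pattern, assuming a machine with at least one visible CUDA GPU (the function name and shapes are illustrative, not from the commit):

import torch

def allocate_kv_block(num_blocks: int, block_shape: tuple) -> torch.Tensor:
    # With a single visible GPU (e.g. Ray exports CUDA_VISIBLE_DEVICES for
    # an actor created with num_gpus=1), "cuda" always means that device,
    # so no gpu_id has to be threaded through the call chain.
    return torch.empty(size=(num_blocks, *block_shape),
                       dtype=torch.float16, device="cuda")

# A side stream for cache operations, created on the current device, plus
# a sanity check that it differs from the default compute stream.
cache_stream = torch.cuda.Stream()
assert cache_stream != torch.cuda.current_stream()

with torch.cuda.stream(cache_stream):
    blocks = allocate_kv_block(num_blocks=16, block_shape=(8, 16, 128))
torch.cuda.current_stream().wait_stream(cache_stream)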
cacheflow/worker/controller.py

-from typing import Dict, List, Union
+from typing import Dict, List, Union, Tuple

+import ray

 from cacheflow.master.scheduler import Scheduler
 from cacheflow.sequence import SequenceGroupInputs
 from cacheflow.worker.worker import Worker

+DeviceID = Tuple[int, str, int]  # rank, node resource (node IP), device id

 class Controller:

     def __init__(
         self,
-        node_id: int,
-        num_workers: int,
+        stage_id: int,
+        stage_devices: List[DeviceID],
+        world_size: int,
+        tensor_parallel_size: int,
+        pipeline_parallel_size: int,
+        distributed_init_method: str,
         model_name: str,
         block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
         dtype: str,
         seed: int,
+        model_path: str,
     ) -> None:
-        self.node_id = node_id
-        self.num_workers = num_workers
+        self.stage_id = stage_id
+        self.stage_devices = stage_devices
         self.model_name = model_name
         self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
         self.num_cpu_blocks = num_cpu_blocks

         # Which pipeline stage is this node assigned to?
-        self.is_first_stage = node_id == 0
+        self.is_first_stage = stage_id == 0
         self.is_last_stage = False

         self.workers: List[Worker] = []
-        for i in range(num_workers):
-            worker = Worker(
-                worker_id=node_id + i,
-                gpu_id=i,
+        for rank, node_resource, device_id in stage_devices:
+            worker_cls = ray.remote(num_cpus=0,
+                                    num_gpus=1,
+                                    resources={node_resource: 1e-5})(Worker)
+            worker = worker_cls.remote(
                 model_name=model_name,
                 block_size=block_size,
                 num_gpu_blocks=num_gpu_blocks,
                 num_cpu_blocks=num_cpu_blocks,
                 dtype=dtype,
                 seed=seed,
+                distributed_init_method=distributed_init_method,
+                rank=rank,
+                world_size=world_size,
+                tensor_parallel_size=tensor_parallel_size,
+                pipeline_parallel_size=pipeline_parallel_size,
+                model_path=model_path,
             )
             self.workers.append(worker)
@@ -57,15 +74,21 @@ class Controller:
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
     ) -> None:
-        # FIXME: Support tensor parallelism.
-        assert len(self.workers) == 1
-        worker = self.workers[0]
-        output = worker.execute_stage(
-            input_seq_groups,
-            blocks_to_swap_in,
-            blocks_to_swap_out,
-            blocks_to_copy,
-        )
+        futures = []
+        for worker in self.workers:
+            future = worker.execute_stage.remote(
+                input_seq_groups,
+                blocks_to_swap_in,
+                blocks_to_swap_out,
+                blocks_to_copy,
+            )
+            futures.append(future)
+
+        all_outputs = ray.get(futures)
+        # Make sure all workers have the same results.
+        output = all_outputs[0]
+        for other_output in all_outputs[1:]:
+            assert output == other_output

         if self.is_last_stage:
             self.next_node.post_step(output)
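execute_stage now fans out to every tensor-parallel worker and gathers with ray.get; because tensor-parallel ranks produce replicated outputs, the driver can assert they all agree. The same scatter/gather shape in isolation, runnable on a CPU-only machine (execute_stage here is a hypothetical stand-in task, not the commit's method):

import ray

ray.init()

@ray.remote
def execute_stage(rank: int, step: int) -> list:
    # Each tensor-parallel rank returns the full, replicated output.
    return [step, step + 1]

# Fan out one call per worker, then block on all of them at once.
futures = [execute_stage.remote(rank, step=3) for rank in range(4)]
all_outputs = ray.get(futures)

# Make sure all workers produced the same results.
output = all_outputs[0]
for other_output in all_outputs[1:]:
    assert output == other_output
print(output)  # [3, 4]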
cacheflow/worker/worker.py

@@ -3,49 +3,58 @@ from typing import Dict, List, Tuple
 import torch

 from cacheflow.models import get_model
-from cacheflow.models import set_seed
 from cacheflow.models import InputMetadata
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import SequenceGroupInputs
 from cacheflow.sequence import SequenceOutputs
 from cacheflow.worker.cache_engine import CacheEngine
+from cacheflow.parallel_utils.parallel_state import (
+    initialize_model_parallel, get_tensor_model_parallel_world_size)
+from cacheflow.utils import set_random_seed


 class Worker:

     def __init__(
         self,
-        worker_id: int,
-        gpu_id: int,
         model_name: str,
         block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
         dtype: str,
         seed: int,
+        distributed_init_method: str,
+        rank: int,
+        world_size: int,
+        model_path: str,
+        tensor_parallel_size: int = 1,
+        pipeline_parallel_size: int = 1,
     ) -> None:
-        self.worker_id = worker_id
-        self.gpu_id = gpu_id
+        self.init_distributed_environment(distributed_init_method, rank,
+                                          world_size, tensor_parallel_size,
+                                          pipeline_parallel_size)
+        self.worker_id = rank
         self.block_size = block_size
+        set_random_seed(seed)

-        self.device = torch.device('cuda', index=gpu_id)
-
         # Initialize the model.
-        self.model = get_model(model_name, dtype=dtype).to(device=self.device)
+        # FIXME(woosuk): This is a hack.
+        self.model, self.dtype = get_model(model_name, dtype=dtype, path=model_path)
+        self.model = self.model.cuda()
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         self.num_layers = self.model.config.num_hidden_layers
-        self.num_heads = self.model.config.num_attention_heads
-        self.head_size = self.model.config.hidden_size // self.num_heads
-        self.dtype = self.model.dtype
+        assert self.model.config.num_attention_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = self.model.config.num_attention_heads // tensor_model_parallel_world_size
+        self.head_size = self.model.config.hidden_size // (
+            self.num_heads * tensor_model_parallel_world_size)

-        # Set the seed.
-        # We set the seed after initializing the model to ensure that
+        # We reset the seed after initializing the model to ensure that
         # the random state is not affected by the model initialization.
-        set_seed(seed)
+        set_random_seed(seed)

         self.cache_engine = CacheEngine(
-            worker_id=worker_id,
-            gpu_id=gpu_id,
+            worker_id=self.worker_id,
             num_layers=self.num_layers,
             num_heads=self.num_heads,
             head_size=self.head_size,
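The head bookkeeping above is the core of the tensor-parallel split: each rank keeps num_attention_heads // world_size heads, while head_size is still derived from the full hidden size over the total head count, so each rank's KV cache shrinks proportionally. A worked example with OPT-13B-like dimensions (the concrete numbers are illustrative):

# hidden_size=5120 and 40 attention heads, split over 4 tensor-parallel ranks.
hidden_size = 5120
total_num_heads = 40
tensor_model_parallel_world_size = 4

# The head count must divide evenly across ranks.
assert total_num_heads % tensor_model_parallel_world_size == 0

# Heads owned by this rank, and the (global) per-head dimension.
num_heads = total_num_heads // tensor_model_parallel_world_size
head_size = hidden_size // (num_heads * tensor_model_parallel_world_size)

print(num_heads, head_size)  # 10 128 -> each rank caches 10 of 40 heads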
@@ -57,6 +66,26 @@ class Worker:
         self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache

+    def init_distributed_environment(self,
+                                     distributed_init_method: str,
+                                     rank: int,
+                                     world_size: int,
+                                     tensor_parallel_size: int = 1,
+                                     pipeline_parallel_size: int = 1) -> None:
+        """Initialize the distributed environment."""
+        torch.distributed.init_process_group(
+            backend='nccl',
+            init_method=distributed_init_method,
+            world_size=world_size,
+            rank=rank,
+        )
+        # A small all_reduce for warmup.
+        torch.distributed.all_reduce(torch.zeros(1).cuda())
+        initialize_model_parallel(tensor_parallel_size,
+                                  pipeline_parallel_size)
+
     def prepare_inputs(
         self,
         input_seq_groups: List[SequenceGroupInputs],
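The new init_distributed_environment helper joins the NCCL process group before the model is constructed, warms the communicator with a throwaway all_reduce, and only then carves the world into tensor- and pipeline-parallel groups via the Megatron-style initialize_model_parallel. A minimal sketch of the same flow outside the Worker class; the init_method address is made up, and initialize_model_parallel (cacheflow's own helper) is only referenced in a comment:

import torch
import torch.distributed

def init_distributed(rank: int, world_size: int,
                     init_method: str = 'tcp://127.0.0.1:12355') -> None:
    # Join the NCCL process group; every rank must call this with the same
    # init_method and world_size.
    torch.distributed.init_process_group(backend='nccl',
                                         init_method=init_method,
                                         world_size=world_size,
                                         rank=rank)
    # A small all_reduce for warmup, so the first real collective does not
    # also pay the communicator setup cost.
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    # The commit then calls initialize_model_parallel(tp_size, pp_size)
    # to set up the tensor- and pipeline-parallel process groups.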
@@ -142,18 +171,18 @@ class Worker:
         # Convert to tensors.
         tokens_tensor = torch.tensor(
-            input_tokens, dtype=torch.long, device=self.device)
+            input_tokens, dtype=torch.long, device='cuda')
         positions_tensor = torch.tensor(
-            input_positions, dtype=torch.long, device=self.device)
+            input_positions, dtype=torch.long, device='cuda')
         slot_mapping_tensor = torch.tensor(
-            slot_mapping, dtype=torch.int, device=self.device)
+            slot_mapping, dtype=torch.int, device='cuda')
         context_lens_tensor = torch.tensor(
-            context_lens, dtype=torch.int, device=self.device)
+            context_lens, dtype=torch.int, device='cuda')
         padded_block_tables = [
             _pad_to_max(block_table, max_num_blocks_per_seq)
             for block_table in generation_block_tables]
         block_tables_tensor = torch.tensor(
-            padded_block_tables, dtype=torch.int, device=self.device)
+            padded_block_tables, dtype=torch.int, device='cuda')

         input_metadata = InputMetadata(
             seq_groups=seq_groups,
server.py

 import argparse
-from typing import List
+import random
+from typing import List, Tuple, Dict
+
+import ray

 from cacheflow.master.frontend import Frontend
 from cacheflow.master.scheduler import Scheduler
 from cacheflow.models import get_memory_analyzer
-from cacheflow.worker.controller import Controller
+from cacheflow.worker.controller import Controller, DeviceID

-parser = argparse.ArgumentParser(description='CacheFlow server')
-parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
-parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes')
-parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node')
-parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
-# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
-parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
-# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-parser.add_argument('--seed', type=int, default=0, help='random seed')
-parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
-parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens')
-args = parser.parse_args()

+def initialize_ray_cluster(
+    address: str = 'auto',
+    pipeline_parallel_size: int = 1,
+    tensor_parallel_size: int = 1,
+) -> Tuple[int, int, str, List[List[DeviceID]]]:
+    # Connect to a ray cluster.
+    ray.init(address=address)
+
+    # Assume we have a uniform cluster that each node has the same number of
+    # GPUs for now.
+    valid_node_resources = []
+    num_devices_per_node = None
+    for node in ray.nodes():
+        if (not node['Alive']) or node['Resources']['GPU'] <= 0:
+            continue
+        if num_devices_per_node is None:
+            num_devices_per_node = node['Resources']['GPU']
+        else:
+            assert num_devices_per_node == node['Resources']['GPU'], (
+                "The number of GPUs per node is not uniform.")
+        for key in node['Resources']:
+            if key.startswith('node:'):
+                valid_node_resources.append(key)
+    num_nodes = len(valid_node_resources)
+
+    assert (pipeline_parallel_size * tensor_parallel_size
+            <= num_nodes * num_devices_per_node), (
+        "The number of required GPUs exceeds the total number of "
+        "available GPUs.")
+    if tensor_parallel_size >= num_devices_per_node:
+        assert tensor_parallel_size % num_devices_per_node == 0, (
+            "The number of tensor parallelism is not divisible by the "
+            "number of GPUs per node.")
+    else:
+        assert num_devices_per_node % tensor_parallel_size == 0, (
+            "The number of GPUs per node is not divisible by the number "
+            "of tensor parallelism.")
+
+    # Assign GPUs to pipeline stages.
+    rank = 0
+    current_node_id = 0
+    current_device_id = 0
+    distributed_init_method = None
+    all_stage_devices = []
+
+    for i in range(pipeline_parallel_size):
+        stage_devices = []
+        for j in range(tensor_parallel_size):
+            node_resource = valid_node_resources[current_node_id]
+            stage_devices.append((rank, node_resource, current_device_id))
+            if distributed_init_method is None:
+                ip = node_resource.split("node:")[-1]
+                port = random.randint(10000, 20000)
+                distributed_init_method = f"tcp://{ip}:{port}"
+            rank += 1
+            current_device_id += 1
+            if current_device_id >= num_devices_per_node:
+                current_node_id += 1
+                current_device_id = 0
+        all_stage_devices.append(stage_devices)
+
+    return (num_nodes, num_devices_per_node, distributed_init_method,
+            all_stage_devices)
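initialize_ray_cluster walks the cluster's ray.nodes() inventory and lays ranks out stage by stage, filling each node's GPUs in order. A dry run of just the assignment loop for a hypothetical cluster of 2 nodes with 2 GPUs each, pipeline_parallel_size=2 and tensor_parallel_size=2 (the node addresses are made up):

# Pure-Python replay of the GPU-assignment loop above.
valid_node_resources = ["node:10.0.0.1", "node:10.0.0.2"]
num_devices_per_node = 2

rank = 0
current_node_id = 0
current_device_id = 0
all_stage_devices = []
for _ in range(2):                      # pipeline stages
    stage_devices = []
    for _ in range(2):                  # tensor-parallel ranks per stage
        node_resource = valid_node_resources[current_node_id]
        stage_devices.append((rank, node_resource, current_device_id))
        rank += 1
        current_device_id += 1
        if current_device_id >= num_devices_per_node:
            current_node_id += 1
            current_device_id = 0
    all_stage_devices.append(stage_devices)

print(all_stage_devices)
# [[(0, 'node:10.0.0.1', 0), (1, 'node:10.0.0.1', 1)],
#  [(2, 'node:10.0.0.2', 0), (3, 'node:10.0.0.2', 1)]]

Each stage's tensor-parallel group ends up packed onto one node, which keeps the latency-sensitive all_reduce traffic off the network.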
-def main():
+def main(args: argparse.Namespace):
+    # TODO(zhuohan): Support pipeline parallelism.
+    assert args.pipeline_parallel_size == 1, (
+        'Pipeline parallelism is not supported yet.')
+
+    (num_nodes, num_devices_per_node, distributed_init_method,
+     all_stage_devices) = (
+        initialize_ray_cluster(
+            pipeline_parallel_size=args.pipeline_parallel_size,
+            tensor_parallel_size=args.tensor_parallel_size))
+
+    world_size = args.pipeline_parallel_size * args.tensor_parallel_size
+
     memory_analyzer = get_memory_analyzer(
         model_name=args.model,
         block_size=args.block_size,
         dtype=args.dtype,
+        tensor_parallel_size=args.tensor_parallel_size,
     )
     num_gpu_blocks = memory_analyzer.get_max_num_gpu_blocks(
         max_num_batched_tokens=args.max_batch_size)
@@ -32,18 +101,23 @@ def main():
         swap_space=args.swap_space)
     print(f'# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}')

-    # Create a controller for each node.
+    # Create a controller for each pipeline stage.
     controllers: List[Controller] = []
-    for i in range(args.num_nodes):
+    for i in range(args.pipeline_parallel_size):
         controller = Controller(
-            node_id=i,
-            num_workers=args.num_workers,
+            stage_id=i,
+            stage_devices=all_stage_devices[i],
+            world_size=world_size,
+            pipeline_parallel_size=args.pipeline_parallel_size,
+            tensor_parallel_size=args.tensor_parallel_size,
+            distributed_init_method=distributed_init_method,
             model_name=args.model,
             block_size=args.block_size,
             num_gpu_blocks=num_gpu_blocks,
             num_cpu_blocks=num_cpu_blocks,
             dtype=args.dtype,
             seed=args.seed,
+            model_path=args.model_path,
         )
         controllers.append(controller)
@@ -83,4 +157,22 @@ def main():
 if __name__ == '__main__':
-    main()
+    parser = argparse.ArgumentParser(description='CacheFlow server')
+    # Model arguments
+    parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
+    parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
+                        help='model path to download and load the weights')
+    # Parallel arguments
+    parser.add_argument('--pipeline-parallel-size', type=int, default=1, help='number of pipeline stages')
+    parser.add_argument('--tensor-parallel-size', type=int, default=1, help='number of tensor parallel replicas')
+    # KV cache arguments
+    parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
+    # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
+    parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
+    # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
+    parser.add_argument('--seed', type=int, default=0, help='random seed')
+    parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
+    parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens')
+    args = parser.parse_args()
+    main(args)
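With argument parsing moved under __main__ and the new parallel flags in place, a two-way tensor-parallel server on one machine would be launched along these lines (hypothetical invocation; ray.init(address='auto') inside initialize_ray_cluster expects a Ray cluster to already be running):

ray start --head          # bring up a local Ray cluster first
python server.py --model facebook/opt-125m --tensor-parallel-size 2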