Commit 27f1410d (Unverified)
Authored May 03, 2023 by Zhuohan Li; committed by GitHub on May 03, 2023

New weight loader without np copy (#52)

Parent: 4858f3bb
Showing 12 changed files with 289 additions and 357 deletions (+289, -357).
- benchmark/benchmark_latency.py (+4, -42)
- benchmark/benchmark_text_completion.py (+4, -43)
- cacheflow/http_frontend/fastapi_frontend.py (+10, -7)
- cacheflow/master/server.py (+58, -7)
- cacheflow/models/gpt_neox.py (+15, -59)
- cacheflow/models/llama.py (+51, -73)
- cacheflow/models/model_utils.py (+5, -6)
- cacheflow/models/opt.py (+34, -69)
- cacheflow/models/utils.py (+94, -1)
- cacheflow/worker/controller.py (+5, -3)
- cacheflow/worker/worker.py (+5, -3)
- simple_server.py (+4, -44)
benchmark/benchmark_latency.py

```diff
@@ -6,53 +6,15 @@ from tqdm import tqdm
 import numpy as np
 import torch
 
-from cacheflow.master.simple_frontend import SimpleFrontend
-from cacheflow.master.server import (Server, add_server_arguments,
-                                     process_server_arguments,
-                                     initialize_cluster)
+from cacheflow.master.server import (
+    add_server_arguments, process_server_arguments,
+    init_local_server_and_frontend_with_arguments)
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.utils import get_gpu_memory, get_cpu_memory
 
 
 def main(args: argparse.Namespace):
-    # TODO(zhuohan): Support pipeline parallelism.
-    assert args.pipeline_parallel_size == 1, (
-        'Pipeline parallelism is not supported yet.')
-
-    (num_nodes, num_devices_per_node, distributed_init_method,
-     all_stage_devices) = (
-        initialize_cluster(
-            use_ray=args.use_ray,
-            pipeline_parallel_size=args.pipeline_parallel_size,
-            tensor_parallel_size=args.tensor_parallel_size))
-
-    # Create a server.
-    server = Server(
-        model=args.model,
-        model_path=args.model_path,
-        use_dummy_weights=args.use_dummy_weights,
-        pipeline_parallel_size=args.pipeline_parallel_size,
-        tensor_parallel_size=args.tensor_parallel_size,
-        block_size=args.block_size,
-        dtype=args.dtype,
-        seed=args.seed,
-        swap_space=args.swap_space,
-        max_num_batched_tokens=args.max_num_batched_tokens,
-        max_num_sequences=args.max_num_sequences,
-        num_nodes=num_nodes,
-        num_devices_per_node=num_devices_per_node,
-        distributed_init_method=distributed_init_method,
-        all_stage_devices=all_stage_devices,
-        gpu_memory=get_gpu_memory(),
-        cpu_memory=get_cpu_memory(),
-        use_ray=args.use_ray,
-    )
-
-    # Create a frontend.
-    frontend = SimpleFrontend(
-        model_name=args.model,
-        block_size=args.block_size,
-    )
+    server, frontend = init_local_server_and_frontend_with_arguments(args)
 
     sampling_params_dict = {
         'n': args.n,
         'temperature': 0.0 if args.use_beam_search else 1.0,
```
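The same simplification is applied to benchmark_text_completion.py and simple_server.py below: the per-script cluster, server, and frontend setup collapses into a single call to init_local_server_and_frontend_with_arguments. A minimal sketch of the entry-point pattern these scripts now share (the surrounding script is illustrative, not part of the commit; running it needs a CacheFlow installation and a GPU):

```python
import argparse

from cacheflow.master.server import (
    add_server_arguments, process_server_arguments,
    init_local_server_and_frontend_with_arguments)

if __name__ == '__main__':
    # Build the shared server CLI (which now includes --cache-dir and
    # --use-np-cache) and let the helper construct both objects.
    parser = argparse.ArgumentParser(description='Minimal CacheFlow driver')
    parser = add_server_arguments(parser)
    args = process_server_arguments(parser.parse_args())
    server, frontend = init_local_server_and_frontend_with_arguments(args)
```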
benchmark/benchmark_text_completion.py

```diff
@@ -9,57 +9,18 @@ from tqdm import tqdm
 from transformers import AutoConfig
 
 from benchmark.trace import generate_text_completion_requests
-from cacheflow.master.simple_frontend import SimpleFrontend
-from cacheflow.master.server import (Server, add_server_arguments,
-                                     process_server_arguments,
-                                     initialize_cluster)
+from cacheflow.master.server import (
+    add_server_arguments, process_server_arguments,
+    init_local_server_and_frontend_with_arguments)
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.utils import get_gpu_memory, get_cpu_memory
 
 logger = logging.getLogger(__name__)
 
 
 def main(args: argparse.Namespace):
-    assert args.pipeline_parallel_size == 1, (
-        'Pipeline parallelism is not supported yet.')
-
-    (num_nodes, num_devices_per_node, distributed_init_method,
-     all_stage_devices) = (
-        initialize_cluster(
-            use_ray=args.use_ray,
-            pipeline_parallel_size=args.pipeline_parallel_size,
-            tensor_parallel_size=args.tensor_parallel_size))
-
-    # Create a server.
-    server = Server(
-        model=args.model,
-        model_path=args.model_path,
-        use_dummy_weights=args.use_dummy_weights,
-        pipeline_parallel_size=args.pipeline_parallel_size,
-        tensor_parallel_size=args.tensor_parallel_size,
-        block_size=args.block_size,
-        dtype=args.dtype,
-        seed=args.seed,
-        swap_space=args.swap_space,
-        max_num_batched_tokens=args.max_num_batched_tokens,
-        max_num_sequences=args.max_num_sequences,
-        num_nodes=num_nodes,
-        num_devices_per_node=num_devices_per_node,
-        distributed_init_method=distributed_init_method,
-        all_stage_devices=all_stage_devices,
-        gpu_memory=get_gpu_memory(),
-        cpu_memory=get_cpu_memory(),
-        use_ray=args.use_ray,
-        collect_stats=True,
-        do_memory_analysis=args.do_memory_analysis,
-    )
-
-    # Create a frontend.
-    frontend = SimpleFrontend(
-        model_name=args.model,
-        block_size=args.block_size,
-    )
+    server, frontend = init_local_server_and_frontend_with_arguments(args)
 
     # Generate requests.
     requests = generate_text_completion_requests(
         args.dataset,
```
cacheflow/http_frontend/fastapi_frontend.py

```diff
 import argparse
 import asyncio
 import time
-from typing import List, Dict
+from typing import List, Dict, Optional
 import json
 
 import ray
@@ -22,11 +22,12 @@ TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
 app = FastAPI()
 
 
-class FastAPIFrontend:
+class FastAPIServer:
     def __init__(
         self,
         model: str,
-        model_path: str,
+        cache_dir: Optional[str],
+        use_np_cache: bool,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
         block_size: int,
@@ -52,8 +53,9 @@ class FastAPIFrontend:
         remote_server_class = ray.remote(num_gpus=1)(Server)
         self.server = remote_server_class.remote(
             model=model,
-            model_path=model_path,
+            cache_dir=cache_dir,
             use_dummy_weights=False,
+            use_np_cache=use_np_cache,
             pipeline_parallel_size=pipeline_parallel_size,
             tensor_parallel_size=tensor_parallel_size,
             block_size=block_size,
@@ -148,7 +150,7 @@ class FastAPIFrontend:
 @app.post("/generate")
 async def generate_stream(request: Request):
     request_dict = await request.json()
-    return StreamingResponse(frontend.generate(request_dict))
+    return StreamingResponse(server.generate(request_dict))
 
 
 if __name__ == "__main__":
@@ -170,9 +172,10 @@ if __name__ == "__main__":
         pipeline_parallel_size=args.pipeline_parallel_size,
         tensor_parallel_size=args.tensor_parallel_size))
 
-    frontend = FastAPIFrontend(
+    server = FastAPIServer(
         model=args.model,
-        model_path=args.model_path,
+        cache_dir=args.cache_dir,
+        use_np_cache=args.use_np_cache,
         pipeline_parallel_size=args.pipeline_parallel_size,
         tensor_parallel_size=args.tensor_parallel_size,
         block_size=args.block_size,
```
cacheflow/master/server.py

```diff
@@ -9,18 +9,21 @@ except ImportError:
     ray = None
 
 from cacheflow.master.scheduler import Scheduler
+from cacheflow.master.simple_frontend import SimpleFrontend
 from cacheflow.models import get_memory_analyzer
 from cacheflow.worker.controller import Controller, DeviceID
 from cacheflow.sequence import SequenceGroup
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.utils import get_gpu_memory, get_cpu_memory
 
 
 class Server:
     def __init__(
         self,
         model: str,
-        model_path: str,
+        cache_dir: Optional[str],
         use_dummy_weights: bool,
+        use_np_cache: bool,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
         block_size: int,
@@ -78,8 +81,9 @@ class Server:
                 num_cpu_blocks=self.num_cpu_blocks,
                 dtype=dtype,
                 seed=seed,
-                model_path=model_path,
+                cache_dir=cache_dir,
                 use_dummy_weights=use_dummy_weights,
+                use_np_cache=use_np_cache,
                 max_num_batched_tokens=max_num_batched_tokens,
                 use_ray=use_ray,
             )
@@ -203,25 +207,72 @@ def initialize_cluster(
 def add_server_arguments(parser: argparse.ArgumentParser):
     # Model arguments
     parser.add_argument('--model', type=str, default='facebook/opt-125m',
                         help='model name')
-    parser.add_argument('--model-path', type=str,
-                        default='~/.cacheflow/model_weights',
-                        help='model path to download and load the weights')
+    parser.add_argument('--cache-dir', type=str, default=None,
+                        help='cache dir to download and load the weights, '
+                             'default to the default cache dir of huggingface')
+    parser.add_argument('--use-np-cache', action='store_true',
+                        help='save a numpy copy of model weights for faster loading')
+    parser.add_argument('--use-dummy-weights', action='store_true',
+                        help='use dummy values for model weights')
+    # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
+    parser.add_argument('--dtype', type=str, default='half', choices=['half'],
+                        help='data type')
    # Parallel arguments
     parser.add_argument('--use-ray', action='store_true',
                         help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1,
                         help='number of pipeline stages')
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1,
                         help='number of tensor parallel replicas')
     # KV cache arguments
     parser.add_argument('--block-size', type=int, default=16,
                         choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
                         help='token block size')
-    # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
-    parser.add_argument('--dtype', type=str, default='half', choices=['half'],
-                        help='data type')
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
     parser.add_argument('--seed', type=int, default=0, help='random seed')
     parser.add_argument('--swap-space', type=int, default=20,
                         help='CPU swap space size (GiB) per GPU')
     parser.add_argument('--max-num-batched-tokens', type=int, default=2560,
                         help='maximum number of batched tokens per iteration')
     parser.add_argument('--max-num-sequences', type=int, default=256,
                         help='maximum number of sequences per iteration')
-    parser.add_argument('--use-dummy-weights', action='store_true',
-                        help='use dummy values for model weights')
     return parser
 
 
 def process_server_arguments(args: argparse.Namespace):
     if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
         args.use_ray = True
     return args
 
 
+def init_local_server_and_frontend_with_arguments(args: argparse.Namespace):
+    # TODO(zhuohan): Support pipeline parallelism.
+    assert args.pipeline_parallel_size == 1, (
+        'Pipeline parallelism is not supported yet.')
+
+    (num_nodes, num_devices_per_node, distributed_init_method,
+     all_stage_devices) = (
+        initialize_cluster(
+            use_ray=args.use_ray,
+            pipeline_parallel_size=args.pipeline_parallel_size,
+            tensor_parallel_size=args.tensor_parallel_size))
+
+    # Create a server.
+    server = Server(
+        model=args.model,
+        cache_dir=args.cache_dir,
+        use_dummy_weights=args.use_dummy_weights,
+        use_np_cache=args.use_np_cache,
+        pipeline_parallel_size=args.pipeline_parallel_size,
+        tensor_parallel_size=args.tensor_parallel_size,
+        block_size=args.block_size,
+        dtype=args.dtype,
+        seed=args.seed,
+        swap_space=args.swap_space,
+        max_num_batched_tokens=args.max_num_batched_tokens,
+        max_num_sequences=args.max_num_sequences,
+        num_nodes=num_nodes,
+        num_devices_per_node=num_devices_per_node,
+        distributed_init_method=distributed_init_method,
+        all_stage_devices=all_stage_devices,
+        gpu_memory=get_gpu_memory(),
+        cpu_memory=get_cpu_memory(),
+        use_ray=args.use_ray,
+    )
+
+    # Create a frontend.
+    frontend = SimpleFrontend(
+        model_name=args.model,
+        block_size=args.block_size,
+    )
+    return server, frontend
```
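For reference, the weight-related flags introduced or kept here can be checked directly against add_server_arguments. A small sketch (assuming the cacheflow package is importable; the print is illustrative):

```python
import argparse

from cacheflow.master.server import add_server_arguments

parser = add_server_arguments(argparse.ArgumentParser())
args = parser.parse_args([])  # defaults only
# Expected with this commit: model='facebook/opt-125m', cache_dir=None,
# use_np_cache=False, use_dummy_weights=False.
print(args.model, args.cache_dir, args.use_np_cache, args.use_dummy_weights)
```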
cacheflow/models/gpt_neox.py

```diff
 """1D GPT-NeoX model compatible with HuggingFace weights."""
-import os
-import glob
-import filelock
-from tqdm import tqdm
 from typing import Dict, List, Optional, Tuple
 
-import numpy as np
 import torch
 from torch import nn
-from huggingface_hub import snapshot_download
 
 from cacheflow.models import InputMetadata
 from cacheflow.models.attention import GPTNeoXCacheFlowAttention
 from cacheflow.models.sample import Sampler
+from cacheflow.models.utils import (hf_model_weights_iterator,
+                                    load_tensor_parallel_weights)
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
@@ -196,17 +192,22 @@ class GPTNeoXForCausalLM(nn.Module):
     _column_parallel_weights = ["embed_in.weight", "embed_out.weight",
                                 "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
     _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
 
-    def load_weights(self, weights_path: str):
+    def load_weights(self, model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     use_np_cache: bool = False):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
-        for name, param in state_dict.items():
+
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, use_np_cache):
+            if ("attention.bias" in name or "attention.masked_bias" in name
+                    or "rotary_emb.inv_freq" in name):
+                continue
+            param = state_dict[name]
             if "query_key_value" in name:
                 # NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
                 # [num_heads * 3 * head_size, num_heads * head_size], while the
                 # required shape is [3 * num_heads * head_size, num_heads * head_size].
                 # Thus, we need weight conversion.
-                loaded_weight = torch.from_numpy(
-                    np.load(os.path.join(weights_path, name)))
                 shard_size = param.shape[0]
                 loaded_weight = loaded_weight[
                     shard_size * tensor_model_parallel_rank
                     :shard_size * (tensor_model_parallel_rank + 1)]
@@ -223,55 +224,10 @@ class GPTNeoXForCausalLM(nn.Module):
                     loaded_weight = loaded_weight.transpose(0, 1)
                     loaded_weight = loaded_weight.reshape(-1).contiguous()
                 else:
-                    assert False
-            else:
-                loaded_weight = torch.from_numpy(
-                    np.load(os.path.join(weights_path, name)))
-                for p in self._column_parallel_weights:
-                    if p in name:
-                        shard_size = param.shape[0]
-                        loaded_weight = loaded_weight[
-                            shard_size * tensor_model_parallel_rank
-                            :shard_size * (tensor_model_parallel_rank + 1)]
-                        break
-                for p in self._row_parallel_weights:
-                    if p in name:
-                        shard_size = param.shape[1]
-                        loaded_weight = loaded_weight[
-                            :,
-                            shard_size * tensor_model_parallel_rank
-                            :shard_size * (tensor_model_parallel_rank + 1)]
-                        break
-
-            assert param.shape == loaded_weight.shape
-            param.data.copy_(loaded_weight)
-
-    @staticmethod
-    def get_weights(model_name: str, path: str):
-        path = os.path.join(path, f"{model_name}-np")
-        path = os.path.abspath(os.path.expanduser(path))
-        os.makedirs(path, exist_ok=True)
-        lock_path = os.path.join(path, "file_lock")
-        lock = filelock.FileLock(lock_path)
-
-        with lock:
-            test_weight_path = os.path.join(path, "gpt_neox.embed_in.weight")
-            if os.path.exists(test_weight_path):
-                return path
-
-            folder = snapshot_download(model_name, allow_patterns="*.bin",
-                                       cache_dir=os.path.join(path, "cache"))
-            bin_files = glob.glob(os.path.join(folder, "*.bin"))
-
-            for bin_file in tqdm(bin_files, desc="Convert format"):
-                state = torch.load(bin_file, map_location="cpu")
-                for name, param in tqdm(state.items(), leave=False):
-                    param_path = os.path.join(path, name)
-                    with open(param_path, "wb") as f:
-                        np.save(f, param.cpu().detach().numpy())
-
-            return path
+                    raise ValueError(f"Unexpected weight name: {name}")
+            load_tensor_parallel_weights(param, loaded_weight, name,
+                                         self._column_parallel_weights,
+                                         self._row_parallel_weights)
 
     def initialize_dummy_weights(self) -> None:
         for param in self.state_dict().values():
```
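The NOTE above is the reason for the view/transpose/reshape dance: HuggingFace stores GPT-NeoX's fused QKV with the (q, k, v) rows of each head interleaved, while the CacheFlow attention kernel wants all q rows, then all k rows, then all v rows. A standalone sketch of that permutation on a dummy tensor (shapes are illustrative; the real code applies it to the sharded weight):

```python
import torch

num_heads, head_size, hidden_size = 4, 8, 32

# HuggingFace GPT-NeoX layout: per head, the q, k, v blocks are interleaved.
hf_qkv = torch.randn(num_heads * 3 * head_size, hidden_size)

# Regroup to [3 * num_heads * head_size, hidden_size]: q for all heads,
# then k for all heads, then v for all heads.
converted = (hf_qkv.view(num_heads, 3, head_size, hidden_size)
                   .transpose(0, 1)
                   .reshape(-1, hidden_size)
                   .contiguous())

# The first third of the converted matrix is now the q rows of every head.
q_rows = torch.cat([hf_qkv[h * 3 * head_size:h * 3 * head_size + head_size]
                    for h in range(num_heads)])
assert torch.equal(converted[:num_heads * head_size], q_rows)
```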
cacheflow/models/llama.py

```diff
 """1D LLaMA model compatible with HuggingFace weights."""
-import os
-import glob
-import filelock
-from tqdm import tqdm
 from typing import Dict, List, Optional, Tuple
 
-import numpy as np
 import torch
 from torch import nn
 from transformers import LlamaConfig
@@ -15,6 +10,8 @@ from cacheflow.models.activation import SiluAndMul
 from cacheflow.models.attention import LlamaCacheFlowAttention
 from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
+from cacheflow.models.utils import (hf_model_weights_iterator,
+                                    load_tensor_parallel_weights)
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
@@ -216,76 +213,57 @@ class LlamaForCausalLM(nn.Module):
                                 "up_proj.weight"]
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
 
-    def load_weights(self, weights_path: str):
+    def load_weights(self, model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     use_np_cache: bool = False):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
-        for name, param in state_dict.items():
-            if "qkv_proj" in name or "gate_up_proj" in name:
-                if "qkv_proj" in name:
-                    original_name = "qkv_proj"
-                    weight_names = ["q_proj", "k_proj", "v_proj"]
-                    shard_size = param.shape[0] // 3
-                else:
-                    original_name = "gate_up_proj"
-                    weight_names = ["gate_proj", "up_proj"]
-                    shard_size = param.shape[0] // 2
-                weights_to_concat = []
-                for weight_name in weight_names:
-                    weight = np.load(os.path.join(
-                        weights_path, name.replace(original_name, weight_name)))
-                    weights_to_concat.append(weight[
-                        shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)])
-                loaded_weight = torch.from_numpy(
-                    np.concatenate(weights_to_concat, axis=0))
-            else:
-                loaded_weight = torch.from_numpy(
-                    np.load(os.path.join(weights_path, name)))
-                for p in self._column_parallel_weights:
-                    if p in name:
-                        shard_size = param.shape[0]
-                        loaded_weight = loaded_weight[
-                            shard_size * tensor_model_parallel_rank
-                            :shard_size * (tensor_model_parallel_rank + 1)]
-                        break
-                for p in self._row_parallel_weights:
-                    if p in name:
-                        shard_size = param.shape[1]
-                        loaded_weight = loaded_weight[
-                            :,
-                            shard_size * tensor_model_parallel_rank
-                            :shard_size * (tensor_model_parallel_rank + 1)]
-                        break
-
-            assert param.shape == loaded_weight.shape
-            param.data.copy_(loaded_weight)
-
-    @staticmethod
-    def get_weights(model_name: str, path: str):
-        if not os.path.isfile(os.path.join(model_name, "config.json")):
-            raise ValueError("LLaMA model's model_name has to be a path"
-                             "to the huggingface model's directory.")
-        path = os.path.join(model_name, f"np")
-        path = os.path.abspath(os.path.expanduser(path))
-        os.makedirs(path, exist_ok=True)
-        lock_path = os.path.join(path, "file_lock")
-        lock = filelock.FileLock(lock_path)
-
-        with lock:
-            test_weight_path = os.path.join(path, "model.embed_tokens.weight")
-            if os.path.exists(test_weight_path):
-                return path
-
-            bin_files = glob.glob(os.path.join(model_name, "*.bin"))
-
-            for bin_file in tqdm(bin_files, desc="Convert format"):
-                state = torch.load(bin_file, map_location="cpu")
-                for name, param in tqdm(state.items(), leave=False):
-                    param_path = os.path.join(path, name)
-                    with open(param_path, "wb") as f:
-                        np.save(f, param.cpu().detach().numpy())
-
-            return path
+
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, use_np_cache):
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            is_attention_weight = False
+            for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
+                if att_weight_name not in name:
+                    continue
+                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
+                shard_size = param.shape[0] // 3
+                loaded_weight = loaded_weight[
+                    shard_size * tensor_model_parallel_rank
+                    :shard_size * (tensor_model_parallel_rank + 1)]
+                param_slice = param.data[shard_size * stride_id
+                                         :shard_size * (stride_id + 1)]
+                assert param_slice.shape == loaded_weight.shape
+                param_slice.copy_(loaded_weight)
+                is_attention_weight = True
+                break
+            if is_attention_weight:
+                continue
+
+            is_gate_up_weight = False
+            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
+                if weight_name not in name:
+                    continue
+                param = state_dict[name.replace(weight_name, "gate_up_proj")]
+                shard_size = param.shape[0] // 2
+                loaded_weight = loaded_weight[
+                    shard_size * tensor_model_parallel_rank
+                    :shard_size * (tensor_model_parallel_rank + 1)]
+                param_slice = param.data[shard_size * stride_id
+                                         :shard_size * (stride_id + 1)]
+                assert param_slice.shape == loaded_weight.shape
+                param_slice.copy_(loaded_weight)
+                is_gate_up_weight = True
+                break
+            if is_gate_up_weight:
+                continue
+
+            param = state_dict[name]
+            load_tensor_parallel_weights(param, loaded_weight, name,
+                                         self._column_parallel_weights,
+                                         self._row_parallel_weights)
 
     def initialize_dummy_weights(self) -> None:
         for param in self.state_dict().values():
```
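The new LLaMA loader no longer builds a concatenated numpy array for the fused projections; each HuggingFace shard (q_proj, k_proj, v_proj, or gate_proj/up_proj) is copied directly into its stride of the fused parameter. A minimal sketch of that in-place pattern with dummy tensors (dimensions are illustrative):

```python
import torch

hidden_size = 16
# Fused parameter as the model creates it: [3 * hidden, hidden] for q, k, v.
qkv_proj = torch.empty(3 * hidden_size, hidden_size)

# Stand-ins for the per-projection weights yielded by the checkpoint iterator.
checkpoint = {name: torch.randn(hidden_size, hidden_size)
              for name in ("q_proj", "k_proj", "v_proj")}

shard_size = qkv_proj.shape[0] // 3
for stride_id, name in enumerate(("q_proj", "k_proj", "v_proj")):
    loaded_weight = checkpoint[name]
    # Copy straight into the matching third of the fused tensor, avoiding the
    # old np.concatenate + torch.from_numpy round trip.
    param_slice = qkv_proj.data[shard_size * stride_id:shard_size * (stride_id + 1)]
    assert param_slice.shape == loaded_weight.shape
    param_slice.copy_(loaded_weight)
```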
cacheflow/models/model_utils.py

```diff
-from typing import Union
+from typing import Union, Optional
 
 import torch
 import torch.nn as nn
@@ -32,8 +32,9 @@ _MEMORY_ANALYZERS = {
 def get_model(
     model_name: str,
     dtype: Union[torch.dtype, str],
-    path: str,
+    cache_dir: Optional[str],
     use_dummy_weights: bool,
+    use_np_cache: bool,
 ) -> nn.Module:
     torch_dtype = get_torch_dtype(dtype)
     torch.set_default_dtype(torch_dtype)
@@ -46,15 +47,13 @@ def get_model(
             model = model_class(config)
             model = model.cuda()
             # NOTE(woosuk): For precise performance evaluation, we assign
-            # random values to the weights.
+            # random values to the weights.
             model.initialize_dummy_weights()
         else:
-            # Download model weights if it's not cached.
-            weights_dir = model_class.get_weights(model_name, path=path)
             # Create a model instance.
             model = model_class(config)
             # Load the weights from the cached or downloaded files.
-            model.load_weights(weights_dir)
+            model.load_weights(model_name, cache_dir, use_np_cache)
             model = model.cuda()
         return model.eval(), torch_dtype
     raise ValueError(f'Unsupported model name: {model_name}')
```
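With the per-model get_weights staticmethods gone, get_model simply forwards the new options to the model's own load_weights. A sketch of the call as a Worker now makes it (the import path is inferred from the file layout and the model name is just an example; running this needs a CUDA device because the model is moved to the GPU):

```python
from cacheflow.models.model_utils import get_model

# Download (or reuse) the HuggingFace checkpoint, optionally writing a numpy
# copy for faster later startups, and get back (model.eval(), torch_dtype).
model, torch_dtype = get_model(
    model_name='facebook/opt-125m',
    dtype='half',
    cache_dir=None,            # None -> default HuggingFace cache dir
    use_dummy_weights=False,
    use_np_cache=True,
)
```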
cacheflow/models/opt.py

```diff
 """1D OPT model compatible with HuggingFace weights."""
-import os
-import glob
-import filelock
-from tqdm import tqdm
 from typing import Dict, List, Optional, Tuple
 
-import numpy as np
 import torch
 from torch import nn
 from transformers import OPTConfig
-from huggingface_hub import snapshot_download
 
 from cacheflow.models import InputMetadata
 from cacheflow.models.attention import OPTCacheFlowAttention
 from cacheflow.models.sample import Sampler
+from cacheflow.models.utils import (hf_model_weights_iterator,
+                                    load_tensor_parallel_weights)
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
@@ -257,73 +253,42 @@ class OPTForCausalLM(nn.Module):
     _column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
     _row_parallel_weights = ["out_proj.weight", "fc2.weight"]
 
-    def load_weights(self, weights_path: str):
+    def load_weights(self, model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     use_np_cache: bool = False):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
-        for name, param in state_dict.items():
-            if "lm_head_weight" in name:
-                continue
-            if "qkv_proj" in name:
-                shard_size = param.shape[0] // 3
-                weights_to_concat = []
-                for weight_name in ["q_proj", "k_proj", "v_proj"]:
-                    weight = np.load(os.path.join(
-                        weights_path, name.replace("qkv_proj", weight_name)))
-                    weights_to_concat.append(weight[
-                        shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)])
-                loaded_weight = torch.from_numpy(
-                    np.concatenate(weights_to_concat, axis=0))
-            else:
-                loaded_weight = torch.from_numpy(
-                    np.load(os.path.join(weights_path, name)))
-                for p in self._column_parallel_weights:
-                    if p in name:
-                        shard_size = param.shape[0]
-                        loaded_weight = loaded_weight[
-                            shard_size * tensor_model_parallel_rank
-                            :shard_size * (tensor_model_parallel_rank + 1)]
-                        break
-                for p in self._row_parallel_weights:
-                    if p in name:
-                        shard_size = param.shape[1]
-                        loaded_weight = loaded_weight[
-                            :,
-                            shard_size * tensor_model_parallel_rank
-                            :shard_size * (tensor_model_parallel_rank + 1)]
-                        break
-
-            assert param.shape == loaded_weight.shape
-            param.data.copy_(loaded_weight)
-
-    @staticmethod
-    def get_weights(model_name: str, path: str):
-        path = os.path.join(path, f"{model_name}-np")
-        path = os.path.abspath(os.path.expanduser(path))
-        os.makedirs(path, exist_ok=True)
-        lock_path = os.path.join(path, "file_lock")
-        lock = filelock.FileLock(lock_path)
-
-        with lock:
-            test_weight_path = os.path.join(
-                path, "model.decoder.embed_positions.weight")
-            if os.path.exists(test_weight_path):
-                return path
-
-            folder = snapshot_download(model_name, allow_patterns="*.bin",
-                                       cache_dir=os.path.join(path, "cache"))
-            bin_files = glob.glob(os.path.join(folder, "*.bin"))
-
-            for bin_file in tqdm(bin_files, desc="Convert format"):
-                state = torch.load(bin_file, map_location="cpu")
-                for name, param in tqdm(state.items(), leave=False):
-                    if name.startswith("decoder."):
-                        name = "model." + name
-                    param_path = os.path.join(path, name)
-                    with open(param_path, "wb") as f:
-                        np.save(f, param.cpu().detach().numpy())
-
-            return path
+
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, use_np_cache):
+            if "lm_head.weight" in name:
+                continue
+
+            if name.startswith("decoder."):
+                name = "model." + name
+
+            is_attention_weight = False
+            for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
+                if att_weight_name not in name:
+                    continue
+                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
+                shard_size = param.shape[0] // 3
+                loaded_weight = loaded_weight[
+                    shard_size * tensor_model_parallel_rank
+                    :shard_size * (tensor_model_parallel_rank + 1)]
+                param_slice = param.data[shard_size * stride_id
+                                         :shard_size * (stride_id + 1)]
+                assert param_slice.shape == loaded_weight.shape
+                param_slice.copy_(loaded_weight)
+                is_attention_weight = True
+                break
+            if is_attention_weight:
+                continue
+
+            param = state_dict[name]
+            load_tensor_parallel_weights(param, loaded_weight, name,
+                                         self._column_parallel_weights,
+                                         self._row_parallel_weights)
 
     def initialize_dummy_weights(self) -> None:
         for param in self.state_dict().values():
```
cacheflow/models/utils.py

```diff
-from typing import Union
+import os
+import glob
+import json
+import filelock
+from typing import Union, Optional
 
+import numpy as np
 import torch
+from tqdm.auto import tqdm
+from huggingface_hub import snapshot_download
+
+from cacheflow.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_rank)
 
 _STR_DTYPE_TO_TORCH_DTYPE = {
     'half': torch.half,
@@ -22,3 +32,86 @@ def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
     torch_dtype = get_torch_dtype(dtype)
     return torch.tensor([], dtype=torch_dtype).element_size()
+
+
+class Disabledtqdm(tqdm):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs, disable=True)
+
+
+def hf_model_weights_iterator(model_name_or_path: str,
+                              cache_dir: Optional[str] = None,
+                              use_np_cache: bool = False):
+    # Prepare file lock directory to prevent multiple processes from
+    # downloading the same model weights at the same time.
+    lock_dir = cache_dir if cache_dir is not None else "/tmp"
+    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
+    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
+
+    # Download model weights from huggingface.
+    is_local = os.path.isdir(model_name_or_path)
+    if not is_local:
+        with lock:
+            hf_folder = snapshot_download(model_name_or_path,
+                                          allow_patterns="*.bin",
+                                          cache_dir=cache_dir,
+                                          tqdm_class=Disabledtqdm)
+    else:
+        hf_folder = model_name_or_path
+
+    hf_bin_files = glob.glob(os.path.join(hf_folder, "*.bin"))
+
+    if use_np_cache:
+        # Convert the model weights from torch tensors to numpy arrays for
+        # faster loading.
+        np_folder = os.path.join(hf_folder, 'np')
+        os.makedirs(np_folder, exist_ok=True)
+        weight_names_file = os.path.join(np_folder, 'weight_names.json')
+        with lock:
+            if not os.path.exists(weight_names_file):
+                weight_names = []
+                for bin_file in hf_bin_files:
+                    state = torch.load(bin_file, map_location="cpu")
+                    for name, param in state.items():
+                        param_path = os.path.join(np_folder, name)
+                        with open(param_path, "wb") as f:
+                            np.save(f, param.cpu().detach().numpy())
+                        weight_names.append(name)
+                with open(weight_names_file, 'w') as f:
+                    json.dump(weight_names, f)
+
+        with open(weight_names_file, 'r') as f:
+            weight_names = json.load(f)
+
+        for name in weight_names:
+            param_path = os.path.join(np_folder, name)
+            with open(param_path, "rb") as f:
+                param = np.load(f)
+            yield name, torch.from_numpy(param)
+    else:
+        for bin_file in hf_bin_files:
+            state = torch.load(bin_file, map_location="cpu")
+            for name, param in state.items():
+                yield name, param
+
+
+def load_tensor_parallel_weights(param, loaded_weight, param_name,
+                                 column_parallel_weight_names,
+                                 row_parallel_weight_names):
+    tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+    for p in column_parallel_weight_names:
+        if p in param_name:
+            shard_size = param.shape[0]
+            loaded_weight = loaded_weight[
+                shard_size * tensor_model_parallel_rank
+                :shard_size * (tensor_model_parallel_rank + 1)]
+            break
+    for p in row_parallel_weight_names:
+        if p in param_name:
+            shard_size = param.shape[1]
+            loaded_weight = loaded_weight[
+                :,
+                shard_size * tensor_model_parallel_rank
+                :shard_size * (tensor_model_parallel_rank + 1)]
+            break
+    assert param.shape == loaded_weight.shape
+    param.data.copy_(loaded_weight)
```
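These two helpers are what every model's load_weights is now built on: hf_model_weights_iterator streams (name, tensor) pairs either directly from the HuggingFace *.bin shards or from an optional numpy cache, and load_tensor_parallel_weights slices each tensor for the current tensor-parallel rank before copying it into the parameter. A hedged sketch of the generic loading loop the model classes follow (load_weights_generic is an illustrative stand-in, not a function in this commit; fused-weight special cases are omitted, and the tensor-parallel process group is assumed to be initialized):

```python
from cacheflow.models.utils import (hf_model_weights_iterator,
                                    load_tensor_parallel_weights)

def load_weights_generic(model, model_name_or_path, cache_dir=None,
                         use_np_cache=False):
    # Stream checkpoint tensors one by one; nothing is converted into a
    # separate numpy copy unless use_np_cache explicitly asks for it.
    state_dict = model.state_dict()
    for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, use_np_cache):
        if name not in state_dict:
            continue  # e.g. weights a particular model chooses to skip
        param = state_dict[name]
        # Shard column-/row-parallel weights for this tensor-parallel rank,
        # then copy in place (shape equality is asserted inside the helper).
        load_tensor_parallel_weights(param, loaded_weight, name,
                                     model._column_parallel_weights,
                                     model._row_parallel_weights)
```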
cacheflow/worker/controller.py

```diff
-from typing import Dict, List, Union, Tuple
+from typing import Dict, List, Union, Tuple, Optional
 
 try:
     import ray
@@ -29,8 +29,9 @@ class Controller:
         num_cpu_blocks: int,
         dtype: str,
         seed: int,
-        model_path: str,
+        cache_dir: Optional[str],
         use_dummy_weights: bool,
+        use_np_cache: bool,
         max_num_batched_tokens: int,
         use_ray: bool,
     ) -> None:
@@ -66,8 +67,9 @@ class Controller:
                 world_size=world_size,
                 tensor_parallel_size=tensor_parallel_size,
                 pipeline_parallel_size=pipeline_parallel_size,
-                model_path=model_path,
+                cache_dir=cache_dir,
                 use_dummy_weights=use_dummy_weights,
+                use_np_cache=use_np_cache,
                 max_num_batched_tokens=max_num_batched_tokens,
             )
             self.workers.append(worker)
```
cacheflow/worker/worker.py

```diff
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional
 
 import torch
@@ -28,8 +28,9 @@ class Worker:
         distributed_init_method: str,
         rank: int,
         world_size: int,
-        model_path: str,
+        cache_dir: Optional[str],
         use_dummy_weights: bool,
+        use_np_cache: bool,
         max_num_batched_tokens: int,
         tensor_parallel_size: int = 1,
         pipeline_parallel_size: int = 1,
@@ -45,7 +46,8 @@ class Worker:
         # Initialize the model.
         self.model, self.dtype = get_model(
-            model_name, dtype=dtype, path=model_path,
-            use_dummy_weights=use_dummy_weights)
+            model_name, dtype=dtype, cache_dir=cache_dir,
+            use_dummy_weights=use_dummy_weights, use_np_cache=use_np_cache)
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         initialize_all_reduce_launcher(
```
simple_server.py

```diff
 import argparse
 from typing import List
 
-from cacheflow.master.simple_frontend import SimpleFrontend
-from cacheflow.master.server import (Server, add_server_arguments,
-                                     process_server_arguments,
-                                     initialize_cluster)
+from cacheflow.master.server import (
+    add_server_arguments, process_server_arguments,
+    init_local_server_and_frontend_with_arguments)
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.utils import get_gpu_memory, get_cpu_memory
 
 
 def main(args: argparse.Namespace):
-    # TODO(zhuohan): Support pipeline parallelism.
-    assert args.pipeline_parallel_size == 1, (
-        'Pipeline parallelism is not supported yet.')
-
-    (num_nodes, num_devices_per_node, distributed_init_method,
-     all_stage_devices) = (
-        initialize_cluster(
-            use_ray=args.use_ray,
-            pipeline_parallel_size=args.pipeline_parallel_size,
-            tensor_parallel_size=args.tensor_parallel_size))
-
-    # Create a server.
-    server = Server(
-        model=args.model,
-        model_path=args.model_path,
-        use_dummy_weights=args.use_dummy_weights,
-        pipeline_parallel_size=args.pipeline_parallel_size,
-        tensor_parallel_size=args.tensor_parallel_size,
-        block_size=args.block_size,
-        dtype=args.dtype,
-        seed=args.seed,
-        swap_space=args.swap_space,
-        max_num_batched_tokens=args.max_num_batched_tokens,
-        max_num_sequences=args.max_num_sequences,
-        num_nodes=num_nodes,
-        num_devices_per_node=num_devices_per_node,
-        distributed_init_method=distributed_init_method,
-        all_stage_devices=all_stage_devices,
-        gpu_memory=get_gpu_memory(),
-        cpu_memory=get_cpu_memory(),
-        use_ray=args.use_ray,
-    )
-
-    # Create a frontend.
-    frontend = SimpleFrontend(
-        model_name=args.model,
-        block_size=args.block_size,
-    )
+    server, frontend = init_local_server_and_frontend_with_arguments(args)
 
     # Test the following inputs.
     test_inputs = [
         ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
```