norm / vllm · Commits · 0d93f156

Unverified commit 0d93f156, authored Aug 30, 2023 by JFDuan, committed by GitHub on Aug 30, 2023.

Accelerate LLaMA model loading (#234)

Parent: becd7a56

Showing 8 changed files with 191 additions and 113 deletions:
  vllm/model_executor/models/aquila.py        +9    -14
  vllm/model_executor/models/baichuan.py      +12   -17
  vllm/model_executor/models/gpt2.py          +7    -11
  vllm/model_executor/models/gpt_bigcode.py   +7    -11
  vllm/model_executor/models/internlm.py      +7    -14
  vllm/model_executor/models/llama.py         +13   -16
  vllm/model_executor/models/qwen.py          +8    -11
  vllm/model_executor/weight_utils.py         +128  -19
vllm/model_executor/models/aquila.py

@@ -34,8 +34,9 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+                                              load_padded_tensor_parallel_vocab,
+                                              load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -280,8 +281,7 @@ class AquilaForCausalLM(nn.Module):
         return next_tokens

     _column_parallel_weights = [
-        "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
-        "gate_proj.weight", "up_proj.weight"
+        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
     ]
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
@@ -309,16 +309,6 @@ class AquilaForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue
-            if "embed_tokens" in name or "lm_head" in name:
-                param = state_dict[name]
-                # Consider padding in the vocab size.
-                padded_vocab_size = (param.shape[0] * tp_size)
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
-
             is_attention_weight = False
             for weight_name, shard_size, offset in attention_weight_specs:
                 if weight_name not in name:
@@ -356,6 +346,11 @@ class AquilaForCausalLM(nn.Module):
                 continue

             param = state_dict[name]
+            if "embed_tokens" in name or "lm_head" in name:
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tensor_model_parallel_rank)
+                continue
+
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
                                          self._row_parallel_weights,
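Every model file in this commit repeats the same refactor: the inline vocab-padding block, which materialized the full padded embedding by concatenating empty rows onto the checkpoint tensor, is replaced by one call to the new load_padded_tensor_parallel_vocab helper (defined in weight_utils.py below), which copies only the rows this tensor-parallel rank owns. A condensed before/after of the pattern, assembled from the hunks above:

    # Before: build the full padded tensor, then let the generic sharding code split it.
    padded_vocab_size = param.shape[0] * tp_size
    num_extra_rows = padded_vocab_size - self.config.vocab_size
    extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1]).to(loaded_weight)
    loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

    # After: copy this rank's shard directly; the padding rows are never allocated,
    # so "embed_tokens.weight" / "lm_head.weight" also drop out of
    # _column_parallel_weights.
    load_padded_tensor_parallel_vocab(param, loaded_weight,
                                      tensor_model_parallel_rank)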
vllm/model_executor/models/baichuan.py

@@ -32,10 +32,12 @@ from vllm.sequence import SequenceOutputs
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, PagedAttentionWithALiBi
+from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
+                                                  PagedAttentionWithALiBi)
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+                                              load_padded_tensor_parallel_vocab,
+                                              load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -295,10 +297,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
                                    input_metadata)
         return next_tokens

-    _column_parallel_weights = [
-        "embed_tokens.weight",
-        "lm_head.weight",
-    ]
+    _column_parallel_weights = []
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

     def load_weights(self,
@@ -314,16 +313,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue
-            if "embed_tokens" in name or "lm_head" in name:
-                # Consider padding in the vocab size.
-                param = state_dict[name]
-                padded_vocab_size = param.shape[0] * tp_world_size
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
-
             if "W_pack" in name:
                 total_num_heads = self.config.num_attention_heads
                 hidden_size = self.config.hidden_size
@@ -355,6 +344,12 @@ class BaiChuanBaseForCausalLM(nn.Module):
                 continue

             param = state_dict[name]
+            if "embed_tokens" in name or "lm_head" in name:
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tp_rank)
+                continue
+
             load_tensor_parallel_weights(param, loaded_weight,
vllm/model_executor/models/gpt2.py

@@ -31,8 +31,9 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+                                              load_padded_tensor_parallel_vocab,
+                                              load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -224,7 +225,7 @@ class GPT2LMHeadModel(nn.Module):
                                    input_metadata)
         return next_tokens

-    _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
+    _column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
     _row_parallel_weights = ["c_proj.weight"]

     def load_weights(self,
@@ -261,14 +262,9 @@ class GPT2LMHeadModel(nn.Module):
             param = state_dict[name]

             if name == "transformer.wte.weight":
-                # Consider padding in the vocab size.
-                padded_vocab_size = (param.shape[0] *
-                                     tensor_model_parallel_world_size)
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tensor_model_parallel_rank)
+                continue

             # For the fused QKV linear layer, manually shard the weights.
             if "c_attn" in name:
vllm/model_executor/models/gpt_bigcode.py

@@ -32,8 +32,9 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+                                              load_padded_tensor_parallel_vocab,
+                                              load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -252,7 +253,7 @@ class GPTBigCodeForCausalLM(nn.Module):
                                    input_metadata)
         return next_tokens

-    _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
+    _column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
     _row_parallel_weights = ["c_proj.weight"]

     def load_weights(self,
@@ -328,14 +329,9 @@ class GPTBigCodeForCausalLM(nn.Module):
             param = state_dict[name]

             if name == "transformer.wte.weight":
-                # Consider padding in the vocab size.
-                padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tensor_model_parallel_rank)
+                continue

             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
vllm/model_executor/models/internlm.py

@@ -14,8 +14,9 @@ from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
     ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+                                              load_padded_tensor_parallel_vocab,
+                                              load_tensor_parallel_weights)
 from vllm.sequence import SequenceOutputs

 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -225,8 +226,7 @@ class InternLMForCausalLM(nn.Module):
         return next_tokens

     _column_parallel_weights = [
-        "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
-        "gate_proj.weight", "up_proj.weight"
+        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
     ]
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
@@ -234,8 +234,6 @@ class InternLMForCausalLM(nn.Module):
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
@@ -246,14 +244,9 @@ class InternLMForCausalLM(nn.Module):
             if "embed_tokens" in name or "lm_head" in name:
                 param = state_dict[name]
-                # Consider padding in the vocab size.
-                padded_vocab_size = (param.shape[0] *
-                                     tensor_model_parallel_world_size)
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tensor_model_parallel_rank)
+                continue

             is_attention_weight = False
             for stride_id, att_weight_name in enumerate(
vllm/model_executor/models/llama.py

@@ -36,8 +36,9 @@ from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
-                                              load_tensor_parallel_weights)
+from vllm.model_executor.weight_utils import (load_tensor_parallel_weights,
+                                              load_padded_tensor_parallel_vocab,
+                                              hf_model_weights_iterator)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -263,15 +264,15 @@ class LlamaForCausalLM(nn.Module):
         return next_tokens

     _column_parallel_weights = [
-        "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
-        "gate_proj.weight", "up_proj.weight"
+        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
     ]
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     use_np_cache: bool = False,
+                     use_safetensor: bool = True):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -288,20 +289,10 @@ class LlamaForCausalLM(nn.Module):
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, use_np_cache,
+                use_safetensor):
             if "rotary_emb.inv_freq" in name:
                 continue
-            if "embed_tokens" in name or "lm_head" in name:
-                param = state_dict[name]
-                # Consider padding in the vocab size.
-                padded_vocab_size = (param.shape[0] * tp_size)
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
-
             is_attention_weight = False
             for weight_name, shard_size, offset in attention_weight_specs:
                 if weight_name not in name:
@@ -339,6 +330,12 @@ class LlamaForCausalLM(nn.Module):
                 continue

             param = state_dict[name]
+            if "embed_tokens" in name or "lm_head" in name:
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tensor_model_parallel_rank)
+                continue
+
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
                                          self._row_parallel_weights,
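llama.py is the only model that also threads the new use_safetensor flag through its load_weights signature (defaulting to True, and note the singular spelling used throughout this commit). A hypothetical call site, assuming an already-constructed model; the checkpoint name is illustrative:

    model = LlamaForCausalLM(config)
    model.load_weights("meta-llama/Llama-2-7b-hf",   # illustrative checkpoint
                       cache_dir=None,
                       use_np_cache=False,
                       use_safetensor=True)  # falls back to *.bin with a warning
                                             # if no *.safetensors files exist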
vllm/model_executor/models/qwen.py

@@ -20,6 +20,7 @@ from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (
     hf_model_weights_iterator,
+    load_padded_tensor_parallel_vocab,
     load_tensor_parallel_weights,
 )
 from vllm.model_executor.parallel_utils.parallel_state import (
@@ -241,7 +242,7 @@ class QWenLMHeadModel(nn.Module):
                                    input_metadata)
         return next_tokens

-    _column_parallel_weights = ["wte.weight", "lm_head.weight"]
+    _column_parallel_weights = []
     _row_parallel_weights = ["c_proj.weight"]

     def load_weights(
@@ -259,16 +260,6 @@ class QWenLMHeadModel(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue
-            if "wte" in name or "lm_head" in name:
-                # Consider padding in the vocab size.
-                param = state_dict[name]
-                padded_vocab_size = param.shape[0] * tp_world_size
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
-
             if "c_attn" in name:
                 total_num_heads = self.config.num_attention_heads
                 hidden_size = self.config.hidden_size
@@ -306,6 +297,12 @@ class QWenLMHeadModel(nn.Module):
                 continue

             param = state_dict[name]
+            if "wte" in name or "lm_head" in name:
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tp_rank)
+                continue
+
             load_tensor_parallel_weights(param, loaded_weight,
vllm/model_executor/weight_utils.py

@@ -3,13 +3,19 @@ import filelock
 import glob
 import json
 import os
-from typing import Iterator, List, Optional, Tuple
+from collections import defaultdict
+from typing import Iterator, List, Optional, Tuple, Any

 from huggingface_hub import snapshot_download
+from safetensors.torch import load_file, save_file, safe_open
 import numpy as np
 import torch
 from tqdm.auto import tqdm

 from vllm.logger import init_logger

 logger = init_logger(__name__)


 class Disabledtqdm(tqdm):
@@ -17,43 +23,118 @@ class Disabledtqdm(tqdm):
         super().__init__(*args, **kwargs, disable=True)


-def hf_model_weights_iterator(
-    model_name_or_path: str,
-    cache_dir: Optional[str] = None,
-    use_np_cache: bool = False,
-) -> Iterator[Tuple[str, torch.Tensor]]:
-    # Prepare file lock directory to prevent multiple processes from
-    # downloading the same model weights at the same time.
+def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
     lock_dir = cache_dir if cache_dir is not None else "/tmp"
     lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
     lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
+    return lock
+
+
+def _shared_pointers(tensors):
+    ptrs = defaultdict(list)
+    for k, v in tensors.items():
+        ptrs[v.data_ptr()].append(k)
+    failing = []
+    for _, names in ptrs.items():
+        if len(names) > 1:
+            failing.append(names)
+    return failing
+
+
+def convert_bin_to_safetensor_file(
+    pt_filename: str,
+    sf_filename: str,
+):
+    loaded = torch.load(pt_filename, map_location="cpu")
+    if "state_dict" in loaded:
+        loaded = loaded["state_dict"]
+    shared = _shared_pointers(loaded)
+    for shared_weights in shared:
+        for name in shared_weights[1:]:
+            loaded.pop(name)
+
+    # For tensors to be contiguous
+    loaded = {k: v.contiguous() for k, v in loaded.items()}
+
+    dirname = os.path.dirname(sf_filename)
+    os.makedirs(dirname, exist_ok=True)
+    save_file(loaded, sf_filename, metadata={"format": "pt"})
+
+    # check file size
+    sf_size = os.stat(sf_filename).st_size
+    pt_size = os.stat(pt_filename).st_size
+    if (sf_size - pt_size) / pt_size > 0.01:
+        raise RuntimeError(f"""The file size different is more than 1%:
+         - {sf_filename}: {sf_size}
+         - {pt_filename}: {pt_size}
+         """)
+
+    # check if the tensors are the same
+    reloaded = load_file(sf_filename)
+    for k in loaded:
+        pt_tensor = loaded[k]
+        sf_tensor = reloaded[k]
+        if not torch.equal(pt_tensor, sf_tensor):
+            raise RuntimeError(f"The output tensors do not match for key {k}")
+
+
+def prepare_hf_model_weights(
+    model_name_or_path: str,
+    cache_dir: Optional[str] = None,
+    use_safetensor: bool = False,
+):
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
+    allow_patterns = "*.safetensors" if use_safetensor else "*.bin"
     if not is_local:
-        with lock:
+        # Use file lock to prevent multiple processes from
+        # downloading the same model weights at the same time.
+        with get_lock(model_name_or_path, cache_dir):
             hf_folder = snapshot_download(model_name_or_path,
-                                          allow_patterns="*.bin",
+                                          allow_patterns=allow_patterns,
                                           cache_dir=cache_dir,
                                           tqdm_class=Disabledtqdm)
     else:
         hf_folder = model_name_or_path
+    hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
+    if not use_safetensor:
+        hf_weights_files = [
+            x for x in hf_weights_files
+            if not x.endswith("training_args.bin")
+        ]
+
+    if len(hf_weights_files) == 0 and use_safetensor:
+        logger.warning("No *.safetensors files found, "
+                       "fall back to *.bin files")
+        return prepare_hf_model_weights(model_name_or_path,
+                                        cache_dir=cache_dir,
+                                        use_safetensor=False)
+
+    return hf_folder, hf_weights_files, use_safetensor

-    hf_bin_files = [
-        x for x in glob.glob(os.path.join(hf_folder, "*.bin"))
-        if not x.endswith("training_args.bin")
-    ]
+
+def hf_model_weights_iterator(
+    model_name_or_path: str,
+    cache_dir: Optional[str] = None,
+    use_np_cache: bool = False,
+    use_safetensor: bool = False,
+) -> Iterator[Tuple[str, torch.Tensor]]:
+    hf_folder, hf_weights_files, use_safetensor = prepare_hf_model_weights(
+        model_name_or_path, cache_dir=cache_dir,
+        use_safetensor=use_safetensor)

     if use_np_cache:
+        # Currently np_cache only support *.bin checkpoints
+        assert use_safetensor is False
+
         # Convert the model weights from torch tensors to numpy arrays for
         # faster loading.
         np_folder = os.path.join(hf_folder, "np")
         os.makedirs(np_folder, exist_ok=True)
         weight_names_file = os.path.join(np_folder, "weight_names.json")
-        with lock:
+        # Use file lock to prevent multiple processes from
+        # dumping the same model weights to numpy at the same time.
+        with get_lock(model_name_or_path, cache_dir):
             if not os.path.exists(weight_names_file):
                 weight_names = []
-                for bin_file in hf_bin_files:
+                for bin_file in hf_weights_files:
                     state = torch.load(bin_file, map_location="cpu")
                     for name, param in state.items():
                         param_path = os.path.join(np_folder, name)
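A minimal sketch of how the new helpers compose, under the assumption of a repo that publishes only *.bin shards: prepare_hf_model_weights warns and retries with use_safetensor=False, and convert_bin_to_safetensor_file can then be run once offline to produce the faster format (the repo id, cache path, and naming scheme are illustrative):

    hf_folder, weight_files, use_st = prepare_hf_model_weights(
        "some-org/some-model",          # illustrative repo id
        cache_dir="/data/hf-cache",
        use_safetensor=True)            # returns use_st == False after the fallback

    if not use_st:
        for pt_file in weight_files:    # one-time offline conversion (sketch)
            convert_bin_to_safetensor_file(
                pt_file, pt_file.replace(".bin", ".safetensors"))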
@@ -71,8 +152,14 @@ def hf_model_weights_iterator(
             with open(param_path, "rb") as f:
                 param = np.load(f)
             yield name, torch.from_numpy(param)
+    elif use_safetensor:
+        for st_file in hf_weights_files:
+            with safe_open(st_file, framework="pt") as f:
+                for name in f.keys():
+                    param = f.get_slice(name)
+                    yield name, param
     else:
-        for bin_file in hf_bin_files:
+        for bin_file in hf_weights_files:
             state = torch.load(bin_file, map_location="cpu")
             for name, param in state.items():
                 yield name, param
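The elif use_safetensor branch is where the speedup comes from: safe_open memory-maps the file and get_slice returns a lazy PySafeSlice, so bytes are read from disk only when the slice is indexed, whereas torch.load deserializes every tensor in the shard up front. A standalone sketch of the lazy access pattern (file and tensor names are illustrative):

    from safetensors.torch import safe_open

    with safe_open("model-00001-of-00002.safetensors", framework="pt") as f:
        sl = f.get_slice("model.embed_tokens.weight")  # no tensor data read yet
        shard = sl[0:16032]  # reads only these rows; returns a torch.Tensor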
@@ -80,9 +167,26 @@ def hf_model_weights_iterator(
         torch.cuda.empty_cache()


+def load_padded_tensor_parallel_vocab(
+    param: torch.Tensor,
+    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
+    tensor_model_parallel_rank: int,
+) -> None:
+    shard_size = param.shape[0]
+    start_idx = tensor_model_parallel_rank * shard_size
+    end_idx = (tensor_model_parallel_rank + 1) * shard_size
+    loaded_weight = loaded_weight[start_idx:end_idx]
+
+    # convert PySafeSlice object to torch.Tensor
+    if not isinstance(loaded_weight, torch.Tensor):
+        loaded_weight = loaded_weight[:]
+    param[:loaded_weight.shape[0]].copy_(loaded_weight)
+
+
 def load_tensor_parallel_weights(
     param: torch.Tensor,
-    loaded_weight: torch.Tensor,
+    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
     param_name: str,
     column_parallel_weight_names: List[str],
     row_parallel_weight_names: List[str],
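Worked through with illustrative numbers (not taken from the commit): with a padded vocab of 32064 rows and a tensor-parallel world size of 2, each rank's param has shard_size == 16032 rows, while the checkpoint tensor holds only vocab_size == 32000 rows. For rank 1 the slice [16032:32064] clips to 15968 real rows, so param[:15968] is overwritten and the final 64 padding rows keep their initialized values:

    param = torch.zeros(16032, 4096)           # rank-local shard of the padded embedding
    loaded_weight = torch.randn(32000, 4096)   # unpadded checkpoint tensor
    load_padded_tensor_parallel_vocab(param, loaded_weight, 1)  # rank 1
    assert param[15968:].abs().sum() == 0      # padding rows untouched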
@@ -102,6 +206,11 @@ def load_tensor_parallel_weights(
             end_idx = (tensor_model_parallel_rank + 1) * shard_size
             loaded_weight = loaded_weight[:, start_idx:end_idx]
             break
+
+    # convert PySafeSlice object to torch.Tensor
+    if not isinstance(loaded_weight, torch.Tensor):
+        loaded_weight = loaded_weight[:]
+
     assert param.shape == loaded_weight.shape, (
         f"{param_name} shape mismatch between model and checkpoint: "
         f"{param.shape} != {loaded_weight.shape}")
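load_tensor_parallel_weights gets the same treatment: the PySafeSlice is forced into a real tensor only after the row/column shard has been selected, immediately before the shape assertion, so the non-sharded bytes are never read. A hedged sketch of feeding it a lazy slice for a column-parallel weight; the shapes and file name are illustrative, and the trailing rank argument is assumed from the function body above, which reads tensor_model_parallel_rank:

    st_file = "model-00001-of-00002.safetensors"   # illustrative
    param = torch.empty(5504, 4096)                # rank-local gate_proj shard (tp=2)
    with safe_open(st_file, framework="pt") as f:
        load_tensor_parallel_weights(
            param,
            f.get_slice("model.layers.0.mlp.gate_proj.weight"),  # lazy slice
            "model.layers.0.mlp.gate_proj.weight",
            ["gate_proj.weight"],        # column-parallel names (sharded on dim 0)
            ["down_proj.weight"],        # row-parallel names (sharded on dim 1)
            0)                           # assumed tensor_model_parallel_rank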