norm / vllm · Commits

Commit c957c741 (unverified)
Authored Sep 07, 2023 by Zhuohan Li; committed by GitHub on Sep 07, 2023
Parent commit: c07ece5c

Enable safetensors loading for all models (#974)

Showing 18 changed files with 143 additions and 83 deletions (+143 −83).
vllm/config.py                             +23 −7
vllm/engine/arg_utils.py                   +18 −13
vllm/engine/llm_engine.py                  +1 −2
vllm/model_executor/model_loader.py        +2 −2
vllm/model_executor/models/aquila.py       +2 −2
vllm/model_executor/models/baichuan.py     +6 −4
vllm/model_executor/models/bloom.py        +2 −2
vllm/model_executor/models/falcon.py       +5 −3
vllm/model_executor/models/gpt2.py         +6 −4
vllm/model_executor/models/gpt_bigcode.py  +5 −4
vllm/model_executor/models/gpt_j.py        +2 −2
vllm/model_executor/models/gpt_neox.py     +2 −2
vllm/model_executor/models/internlm.py     +2 −2
vllm/model_executor/models/llama.py        +2 −3
vllm/model_executor/models/mpt.py          +5 −4
vllm/model_executor/models/opt.py          +2 −2
vllm/model_executor/models/qwen.py         +5 −2
vllm/model_executor/weight_utils.py        +53 −23
vllm/config.py

@@ -24,9 +24,16 @@ class ModelConfig:
             downloading the model and tokenizer.
         download_dir: Directory to download and load the weights, default to the
             default cache directory of huggingface.
-        use_np_weights: Save a numpy copy of model weights for faster loading.
-            This can increase the disk usage by up to 2x.
-        use_dummy_weights: Use dummy values for model weights (for profiling).
+        load_format: The format of the model weights to load:
+            "auto" will try to load the weights in the safetensors format and
+                fall back to the pytorch bin format if safetensors format is
+                not available.
+            "pt" will load the weights in the pytorch bin format.
+            "safetensors" will load the weights in the safetensors format.
+            "npcache" will load the weights in pytorch format and store
+                a numpy cache to speed up the loading.
+            "dummy" will initialize the weights with random values, which is
+                mainly for profiling.
         dtype: Data type for model weights and activations. The "auto" option
             will use FP16 precision for FP32 and FP16 models, and BF16 precision
             for BF16 models.

@@ -40,8 +47,7 @@ class ModelConfig:
         tokenizer_mode: str,
         trust_remote_code: bool,
         download_dir: Optional[str],
-        use_np_weights: bool,
-        use_dummy_weights: bool,
+        load_format: str,
         dtype: str,
         seed: int,
     ) -> None:

@@ -50,14 +56,24 @@ class ModelConfig:
         self.tokenizer_mode = tokenizer_mode
         self.trust_remote_code = trust_remote_code
         self.download_dir = download_dir
-        self.use_np_weights = use_np_weights
-        self.use_dummy_weights = use_dummy_weights
+        self.load_format = load_format
         self.seed = seed

         self.hf_config = get_config(model, trust_remote_code)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
+        self._verify_load_format()
         self._verify_tokenizer_mode()

+    def _verify_load_format(self) -> None:
+        load_format = self.load_format.lower()
+        if load_format not in ["auto", "pt", "safetensors", "npcache", "dummy"]:
+            raise ValueError(
+                f"Unknown load format: {self.load_format}. Must be one of "
+                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
+        self.load_format = load_format
+
     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = self.tokenizer_mode.lower()
         if tokenizer_mode not in ["auto", "slow"]:
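Note: the validation added here is small enough to sketch standalone. The following is an illustrative function mirroring the logic of _verify_load_format above, not the ModelConfig method itself:

    # Standalone sketch of the new load-format validation (illustrative only).
    def verify_load_format(load_format: str) -> str:
        load_format = load_format.lower()
        if load_format not in ["auto", "pt", "safetensors", "npcache", "dummy"]:
            raise ValueError(
                f"Unknown load format: {load_format}. Must be one of "
                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
        return load_format

    assert verify_load_format("SafeTensors") == "safetensors"  # case-normalized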
vllm/engine/arg_utils.py

@@ -15,8 +15,7 @@ class EngineArgs:
     tokenizer_mode: str = 'auto'
     trust_remote_code: bool = False
     download_dir: Optional[str] = None
-    use_np_weights: bool = False
-    use_dummy_weights: bool = False
+    load_format: str = 'auto'
     dtype: str = 'auto'
     seed: int = 0
     worker_use_ray: bool = False

@@ -65,14 +64,21 @@ class EngineArgs:
                             help='directory to download and load the weights, '
                             'default to the default cache dir of '
                             'huggingface')
-        parser.add_argument('--use-np-weights',
-                            action='store_true',
-                            help='save a numpy copy of model weights for '
-                            'faster loading. This can increase the disk '
-                            'usage by up to 2x.')
-        parser.add_argument('--use-dummy-weights',
-                            action='store_true',
-                            help='use dummy values for model weights')
+        parser.add_argument(
+            '--load-format',
+            type=str,
+            default=EngineArgs.load_format,
+            choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
+            help='The format of the model weights to load. '
+            '"auto" will try to load the weights in the safetensors format '
+            'and fall back to the pytorch bin format if safetensors format '
+            'is not available. '
+            '"pt" will load the weights in the pytorch bin format. '
+            '"safetensors" will load the weights in the safetensors format. '
+            '"npcache" will load the weights in pytorch format and store '
+            'a numpy cache to speed up the loading. '
+            '"dummy" will initialize the weights with random values, '
+            'which is mainly for profiling.')
         # TODO(woosuk): Support FP32.
         parser.add_argument(
             '--dtype',

@@ -146,9 +152,8 @@ class EngineArgs:
         # Initialize the configs.
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
-                                   self.download_dir, self.use_np_weights,
-                                   self.use_dummy_weights, self.dtype,
-                                   self.seed)
+                                   self.download_dir, self.load_format,
+                                   self.dtype, self.seed)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space)
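Note: with this change a single string option replaces the two booleans end to end. A hedged usage sketch of how the option flows from EngineArgs into ModelConfig (the model name is a placeholder, and it assumes the create_engine_configs() helper as in vLLM at this commit):

    # Hedged sketch: load_format flows from EngineArgs into ModelConfig.
    from vllm.engine.arg_utils import EngineArgs

    args = EngineArgs(model="facebook/opt-125m", load_format="safetensors")
    model_config = args.create_engine_configs()[0]  # first config is ModelConfig
    print(model_config.load_format)  # -> "safetensors"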
vllm/engine/llm_engine.py

@@ -76,9 +76,8 @@ class LLMEngine:
             f"tokenizer_mode={model_config.tokenizer_mode}, "
             f"trust_remote_code={model_config.trust_remote_code}, "
             f"dtype={model_config.dtype}, "
-            f"use_dummy_weights={model_config.use_dummy_weights}, "
             f"download_dir={model_config.download_dir!r}, "
-            f"use_np_weights={model_config.use_np_weights}, "
+            f"load_format={model_config.load_format}, "
             f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
             f"seed={model_config.seed})")
         # TODO(woosuk): Print more configs in debug mode.
vllm/model_executor/model_loader.py

@@ -56,7 +56,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
     # Create a model instance.
     # The weights will be initialized as empty tensors.
     model = model_class(model_config.hf_config)
-    if model_config.use_dummy_weights:
+    if model_config.load_format == "dummy":
         model = model.cuda()
         # NOTE(woosuk): For accurate performance evaluation, we assign
         # random values to the weights.

@@ -64,6 +64,6 @@ def get_model(model_config: ModelConfig) -> nn.Module:
     else:
         # Load the weights from the cached or downloaded files.
         model.load_weights(model_config.model, model_config.download_dir,
-                           model_config.use_np_weights)
+                           model_config.load_format)
         model = model.cuda()
     return model.eval()
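Note: the loader now branches on one string instead of two booleans. A condensed sketch of the updated control flow (simplified; it assumes the initialize_dummy_weights helper from vllm.model_executor.weight_utils, which the elided context uses for random init):

    # Condensed, simplified sketch of get_model's control flow after this change.
    import torch.nn as nn
    from vllm.config import ModelConfig
    from vllm.model_executor.weight_utils import initialize_dummy_weights

    def get_model_sketch(model_class, model_config: ModelConfig) -> nn.Module:
        model = model_class(model_config.hf_config)
        if model_config.load_format == "dummy":
            model = model.cuda()
            initialize_dummy_weights(model)  # random values, profiling only
        else:
            # Every model's load_weights now takes the load_format string.
            model.load_weights(model_config.model, model_config.download_dir,
                               model_config.load_format)
            model = model.cuda()
        return model.eval()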
vllm/model_executor/models/aquila.py

@@ -288,7 +288,7 @@ class AquilaForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)

@@ -305,7 +305,7 @@ class AquilaForCausalLM(nn.Module):
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "rotary_emb.inv_freq" in name:
                 continue
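Note: the same mechanical signature change repeats in every model file below; only the load_weights parameter and the hf_model_weights_iterator call change. A minimal sketch of the shared consumer-side pattern (simplified, with tensor-parallel sharding elided):

    # Minimal sketch of the per-model loading loop shared by the files below.
    from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
                                                  hf_model_weights_iterator)

    def load_weights_sketch(module, model_name_or_path, cache_dir=None,
                            load_format="auto"):
        state_dict = module.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format):
            if "rotary_emb.inv_freq" in name:
                continue  # recomputed at runtime, not a checkpoint weight
            # Materialize lazy safetensors slices before copying.
            loaded_weight = convert_pyslice_to_tensor(loaded_weight)
            state_dict[name].copy_(loaded_weight)  # simplified: no TP sharding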
vllm/model_executor/models/baichuan.py

@@ -35,8 +35,8 @@ from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
                                                   PagedAttentionWithALiBi)
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (
-    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
-    load_tensor_parallel_weights)
+    convert_pyslice_to_tensor, hf_model_weights_iterator,
+    load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (

@@ -303,16 +303,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "rotary_emb.inv_freq" in name:
                 continue

+            loaded_weight = convert_pyslice_to_tensor(loaded_weight)
             if "W_pack" in name:
                 total_num_heads = self.config.num_attention_heads
                 hidden_size = self.config.hidden_size
vllm/model_executor/models/bloom.py

@@ -279,11 +279,11 @@ class BloomForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if name == "lm_head.weight":
                 # Since hidden_states are parallelized, we need to
                 # load lm_head.weight in parallel.
vllm/model_executor/models/falcon.py

@@ -31,7 +31,8 @@ from vllm.model_executor.layers.attention import (PagedAttention,
                                                   PagedAttentionWithALiBi,
                                                   PagedAttentionWithRoPE)
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
+                                              hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)

@@ -419,7 +420,7 @@ class FalconForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tp_size = (get_tensor_model_parallel_world_size())
         tp_rank = get_tensor_model_parallel_rank()

@@ -451,8 +452,9 @@ class FalconForCausalLM(nn.Module):
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "query_key_value" in name:
+                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                 loaded_weight_size = loaded_weight.size()
                 loaded_weight = loaded_weight.view(
                     total_num_kv_heads, num_query_heads_per_kv_head + 2,
vllm/model_executor/models/gpt2.py

@@ -32,8 +32,8 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (
-    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
-    load_tensor_parallel_weights)
+    convert_pyslice_to_tensor, hf_model_weights_iterator,
+    load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (

@@ -231,14 +231,14 @@ class GPT2LMHeadModel(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.

@@ -251,6 +251,8 @@ class GPT2LMHeadModel(nn.Module):
             if not name.startswith("transformer."):
                 name = "transformer." + name

+            loaded_weight = convert_pyslice_to_tensor(loaded_weight)
+
             # The HF's GPT-2 implementation uses Conv1D instead of Linear.
             # Because of this, we need to transpose the weights.
             for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
vllm/model_executor/models/gpt_bigcode.py

@@ -33,8 +33,8 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (
-    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
-    load_tensor_parallel_weights)
+    convert_pyslice_to_tensor, hf_model_weights_iterator,
+    load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (

@@ -259,14 +259,14 @@ class GPTBigCodeForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.

@@ -295,6 +295,7 @@ class GPTBigCodeForCausalLM(nn.Module):
                 head_start = tensor_model_parallel_rank * num_heads
                 head_end = (tensor_model_parallel_rank + 1) * num_heads
+                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                 wq, wk, wv = torch.split(
                     loaded_weight, [hidden_size, total_kv_size, total_kv_size],
                     dim=0)
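Note: the hunk above converts the lazy slice to a real tensor before torch.split, since splitting needs a materialized tensor. A toy example of the same split of a fused QKV matrix along dim 0 (shapes invented for illustration):

    # Toy illustration of splitting a fused QKV weight as in the hunk above.
    import torch

    hidden_size, total_kv_size = 8, 2
    fused = torch.randn(hidden_size + 2 * total_kv_size, hidden_size)
    wq, wk, wv = torch.split(fused, [hidden_size, total_kv_size, total_kv_size],
                             dim=0)
    print(wq.shape, wk.shape, wv.shape)
    # torch.Size([8, 8]) torch.Size([2, 8]) torch.Size([2, 8])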
vllm/model_executor/models/gpt_j.py

@@ -222,11 +222,11 @@ class GPTJForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "attn.bias" in name or "attn.masked_bias" in name:
                 continue
vllm/model_executor/models/gpt_neox.py

@@ -231,11 +231,11 @@ class GPTNeoXForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if ("attention.bias" in name or "attention.masked_bias" in name
                     or "rotary_emb.inv_freq" in name):
                 continue
vllm/model_executor/models/internlm.py

@@ -233,12 +233,12 @@ class InternLMForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "rotary_emb.inv_freq" in name:
                 continue
vllm/model_executor/models/llama.py

@@ -271,8 +271,7 @@ class LlamaForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False,
-                     use_safetensor: bool = True):
+                     load_format: str = "auto"):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)

@@ -289,7 +288,7 @@ class LlamaForCausalLM(nn.Module):
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache, use_safetensor):
+                model_name_or_path, cache_dir, load_format):
             if "rotary_emb.inv_freq" in name:
                 continue
vllm/model_executor/models/mpt.py

@@ -10,7 +10,8 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
+                                              hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)

@@ -243,12 +244,12 @@ class MPTForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "Wqkv" in name:
                 # NOTE(woosuk): MPT's fused QKV has the shape of
                 # [3 * num_heads * head_size, hidden_size].

@@ -260,7 +261,7 @@ class MPTForCausalLM(nn.Module):
                 num_heads = total_num_heads // tp_world_size
                 head_start = tp_rank * num_heads
                 head_end = (tp_rank + 1) * num_heads
+                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                 if name.endswith(".weight"):
                     loaded_weight = loaded_weight.view(3, total_num_heads,
                                                        head_size, hidden_size)
vllm/model_executor/models/opt.py

@@ -297,12 +297,12 @@ class OPTForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     load_format: str = "auto"):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "lm_head.weight" in name:
                 continue
vllm/model_executor/models/qwen.py

@@ -19,6 +19,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (
+    convert_pyslice_to_tensor,
     hf_model_weights_iterator,
     load_padded_tensor_parallel_vocab,
     load_tensor_parallel_weights,

@@ -249,17 +250,19 @@ class QWenLMHeadModel(nn.Module):
         self,
         model_name_or_path: str,
         cache_dir: Optional[str] = None,
-        use_np_cache: bool = False,
+        load_format: str = "auto",
     ):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, load_format):
             if "rotary_emb.inv_freq" in name:
                 continue

+            loaded_weight = convert_pyslice_to_tensor(loaded_weight)
             if "c_attn" in name:
                 total_num_heads = self.config.num_attention_heads
                 hidden_size = self.config.hidden_size
vllm/model_executor/weight_utils.py

@@ -81,11 +81,12 @@ def convert_bin_to_safetensor_file(
 def prepare_hf_model_weights(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
-    use_safetensor: bool = False,
+    use_safetensors: bool = False,
+    fall_back_to_pt: bool = True,
 ):
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
-    allow_patterns = "*.safetensors" if use_safetensor else "*.bin"
+    allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
     if not is_local:
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.

@@ -97,32 +98,53 @@ def prepare_hf_model_weights(
     else:
         hf_folder = model_name_or_path
     hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
-    if not use_safetensor:
+    if not use_safetensors:
         hf_weights_files = [
             x for x in hf_weights_files if not x.endswith("training_args.bin")
         ]
-    if len(hf_weights_files) == 0 and use_safetensor:
+    if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
+        logger.warning("No *.safetensors files found, "
+                       "fall back to *.bin files")
         return prepare_hf_model_weights(model_name_or_path,
                                         cache_dir=cache_dir,
-                                        use_safetensor=False)
-    return hf_folder, hf_weights_files, use_safetensor
+                                        use_safetensors=False,
+                                        fall_back_to_pt=False)
+
+    if len(hf_weights_files) == 0:
+        raise RuntimeError(
+            f"Cannot find any model weights with `{model_name_or_path}`")
+
+    return hf_folder, hf_weights_files, use_safetensors


 def hf_model_weights_iterator(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
-    use_np_cache: bool = False,
-    use_safetensor: bool = False,
+    load_format: str = "auto",
 ) -> Iterator[Tuple[str, torch.Tensor]]:
-    hf_folder, hf_weights_files, use_safetensor = prepare_hf_model_weights(
-        model_name_or_path, cache_dir=cache_dir, use_safetensor=use_safetensor)
+    use_safetensors = False
+    use_np_cache = False
+    fall_back_to_pt = False
+    if load_format == "auto":
+        use_safetensors = True
+        fall_back_to_pt = True
+    elif load_format == "safetensors":
+        use_safetensors = True
+    elif load_format == "pt":
+        pass
+    elif load_format == "npcache":
+        use_np_cache = True
+    else:
+        raise ValueError(f"Unknown load_format: {load_format}")
+
+    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
+        model_name_or_path,
+        cache_dir=cache_dir,
+        use_safetensors=use_safetensors,
+        fall_back_to_pt=fall_back_to_pt)

     if use_np_cache:
         # Currently np_cache only support *.bin checkpoints
-        assert use_safetensor is False
+        assert use_safetensors is False
         # Convert the model weights from torch tensors to numpy arrays for
         # faster loading.

@@ -152,7 +174,7 @@ def hf_model_weights_iterator(
             with open(param_path, "rb") as f:
                 param = np.load(f)
             yield name, torch.from_numpy(param)
-    elif use_safetensor:
+    elif use_safetensors:
         for st_file in hf_weights_files:
             with safe_open(st_file, framework="pt") as f:
                 for name in f.keys():

@@ -167,6 +189,21 @@ def hf_model_weights_iterator(
     torch.cuda.empty_cache()


+def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
+    """convert PySafeSlice object from safetensors to torch.Tensor
+
+    PySafeSlice object supports indexing, which is done before loading the
+    actual tensor and can reduce the amount of memory being read into the
+    memory. However, it does not support more advanced functionalities
+    like `.view()` or `.t()`. Therefore, if we need to modify the loaded
+    tensor with these more complicated operators, we need to convert to
+    tensor first.
+    """
+    if not isinstance(x, torch.Tensor):
+        x = x[:]
+    return x
+
+
 def load_padded_tensor_parallel_vocab(
     param: torch.Tensor,
     loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`

@@ -176,11 +213,7 @@ def load_padded_tensor_parallel_vocab(
     start_idx = tensor_model_parallel_rank * shard_size
     end_idx = (tensor_model_parallel_rank + 1) * shard_size
     loaded_weight = loaded_weight[start_idx:end_idx]
-
-    # convert PySafeSlice object to torch.Tensor
-    if not isinstance(loaded_weight, torch.Tensor):
-        loaded_weight = loaded_weight[:]
-
+    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
     param[:loaded_weight.shape[0]].copy_(loaded_weight)

@@ -207,10 +240,7 @@ def load_tensor_parallel_weights(
             loaded_weight = loaded_weight[:, start_idx:end_idx]
             break

-    # convert PySafeSlice object to torch.Tensor
-    if not isinstance(loaded_weight, torch.Tensor):
-        loaded_weight = loaded_weight[:]
-
+    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
     assert param.shape == loaded_weight.shape, (
         f"{param_name} shape mismatch between model and checkpoint: "
         f"{param.shape} != {loaded_weight.shape}")
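Note: the reason convert_pyslice_to_tensor exists is that safetensors' get_slice returns a lazy PySafeSlice, which reads from disk only the rows actually indexed; that keeps memory low when loading tensor-parallel shards, but the object does not support tensor operations like .view() or .t(). A small sketch of this lazy-slicing behavior (the file path is a placeholder):

    # Sketch of the lazy safetensors slicing this file relies on.
    from safetensors import safe_open

    with safe_open("model.safetensors", framework="pt") as f:  # placeholder path
        for name in f.keys():
            sl = f.get_slice(name)   # lazy PySafeSlice: nothing read from disk yet
            shard = sl[0:128]        # materializes only the first 128 rows
            full = sl[:]             # [:] materializes the whole tensor,
                                     # exactly what convert_pyslice_to_tensor does
            # full.view(...) works; sl.view(...) would fail, since a
            # PySafeSlice supports only indexing.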