Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
29f49cd6
Unverified
Commit
29f49cd6
authored
Sep 07, 2024
by
Patrick von Platen
Committed by
GitHub
Sep 06, 2024
Browse files
[Model] Allow loading from original Mistral format (#8168)
Co-authored-by:
Michael Goin
<
michael@neuralmagic.com
>
parent
23f32229
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
291 additions
and
81 deletions
+291
-81
tests/models/test_mistral.py
tests/models/test_mistral.py
+40
-0
vllm/config.py
vllm/config.py
+33
-29
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+16
-5
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+9
-3
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+11
-10
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+51
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+131
-34
No files found.
tests/models/test_mistral.py
View file @
29f49cd6
...
@@ -41,3 +41,43 @@ def test_models(
...
@@ -41,3 +41,43 @@ def test_models(
name_0
=
"hf"
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
name_1
=
"vllm"
,
)
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
[
1
:])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_mistral_format
(
vllm_runner
,
example_prompts
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
)
->
None
:
with
vllm_runner
(
model
,
dtype
=
dtype
,
tokenizer_mode
=
"auto"
,
load_format
=
"safetensors"
,
config_format
=
"hf"
,
)
as
hf_format_model
:
hf_format_outputs
=
hf_format_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
with
vllm_runner
(
model
,
dtype
=
dtype
,
tokenizer_mode
=
"mistral"
,
load_format
=
"mistral"
,
config_format
=
"mistral"
,
)
as
mistral_format_model
:
mistral_format_outputs
=
mistral_format_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
check_logprobs_close
(
outputs_0_lst
=
hf_format_outputs
,
outputs_1_lst
=
mistral_format_outputs
,
name_0
=
"hf"
,
name_1
=
"mistral"
,
)
vllm/config.py
View file @
29f49cd6
...
@@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
...
@@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.tracing
import
is_otel_available
,
otel_import_error_traceback
from
vllm.tracing
import
is_otel_available
,
otel_import_error_traceback
from
vllm.transformers_utils.config
import
(
get_config
,
from
vllm.transformers_utils.config
import
(
ConfigFormat
,
get_config
,
get_hf_image_processor_config
,
get_hf_image_processor_config
,
get_hf_text_config
)
get_hf_text_config
)
from
vllm.utils
import
(
STR_NOT_IMPL_ENC_DEC_CUDAGRAPH
,
GiB_bytes
,
from
vllm.utils
import
(
STR_NOT_IMPL_ENC_DEC_CUDAGRAPH
,
GiB_bytes
,
...
@@ -121,35 +121,37 @@ class ModelConfig:
...
@@ -121,35 +121,37 @@ class ModelConfig:
override default neuron config that are specific to Neuron devices,
override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that
this argument will be used to configure the neuron config that
can not be gathered from the vllm arguments.
can not be gathered from the vllm arguments.
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
model
:
str
,
model
:
str
,
tokenizer
:
str
,
tokenizer
:
str
,
tokenizer
_mode
:
str
,
tokenizer_m
ode
:
str
,
trust_remote_c
ode
:
bool
,
trust_remote_code
:
bool
,
dtype
:
Union
[
str
,
torch
.
dtype
]
,
dtype
:
Union
[
str
,
torch
.
dtype
]
,
seed
:
int
,
seed
:
int
,
revision
:
Optional
[
str
]
=
None
,
revision
:
Optional
[
str
]
=
None
,
code_
revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
,
rope_scaling
:
Optional
[
dict
]
=
None
,
rope_scaling
:
Optional
[
dic
t
]
=
None
,
rope_theta
:
Optional
[
floa
t
]
=
None
,
rope_theta
:
Optional
[
float
]
=
None
,
tokenizer_revision
:
Optional
[
str
]
=
None
,
tokenizer_revisio
n
:
Optional
[
str
]
=
None
,
max_model_le
n
:
Optional
[
int
]
=
None
,
max_model_len
:
Optional
[
int
]
=
None
,
spec_target_
max_model_len
:
Optional
[
int
]
=
None
,
spec_target_max_model_le
n
:
Optional
[
int
]
=
None
,
quantizatio
n
:
Optional
[
str
]
=
None
,
quantization
:
Optional
[
str
]
=
None
,
quantization
_param_path
:
Optional
[
str
]
=
None
,
quantization_param_path
:
Optional
[
str
]
=
None
,
enforce_eager
:
Optional
[
bool
]
=
None
,
enforce_eager
:
Optional
[
bool
]
=
None
,
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_context
_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq
_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq_len_to_capture
:
Optional
[
int
]
=
None
,
max_logprobs
:
int
=
20
,
max_logprobs
:
int
=
20
,
disable_sliding_window
:
bool
=
False
,
disable_sliding_window
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
Fals
e
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
Non
e
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]
]]
=
None
,
limit_mm_per_prompt
:
Optional
[
Mapping
[
str
,
int
]]
=
None
,
limit_mm_per_prompt
:
Optional
[
Mapping
[
str
,
int
]]
=
Non
e
,
use_async_output_proc
:
bool
=
Tru
e
,
use_async_output_proc
:
bool
=
Tru
e
,
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
Non
e
,
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
N
on
e
)
->
None
:
config_format
:
ConfigFormat
=
C
on
figFormat
.
AUTO
)
->
None
:
self
.
model
=
model
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
tokenizer
=
tokenizer
self
.
tokenizer_mode
=
tokenizer_mode
self
.
tokenizer_mode
=
tokenizer_mode
...
@@ -176,7 +178,8 @@ class ModelConfig:
...
@@ -176,7 +178,8 @@ class ModelConfig:
self
.
skip_tokenizer_init
=
skip_tokenizer_init
self
.
skip_tokenizer_init
=
skip_tokenizer_init
self
.
hf_config
=
get_config
(
self
.
model
,
trust_remote_code
,
revision
,
self
.
hf_config
=
get_config
(
self
.
model
,
trust_remote_code
,
revision
,
code_revision
,
rope_scaling
,
rope_theta
)
code_revision
,
rope_scaling
,
rope_theta
,
config_format
)
self
.
hf_text_config
=
get_hf_text_config
(
self
.
hf_config
)
self
.
hf_text_config
=
get_hf_text_config
(
self
.
hf_config
)
self
.
hf_image_processor_config
=
get_hf_image_processor_config
(
self
.
hf_image_processor_config
=
get_hf_image_processor_config
(
self
.
model
,
revision
)
self
.
model
,
revision
)
...
@@ -746,6 +749,7 @@ class LoadFormat(str, enum.Enum):
...
@@ -746,6 +749,7 @@ class LoadFormat(str, enum.Enum):
SHARDED_STATE
=
"sharded_state"
SHARDED_STATE
=
"sharded_state"
GGUF
=
"gguf"
GGUF
=
"gguf"
BITSANDBYTES
=
"bitsandbytes"
BITSANDBYTES
=
"bitsandbytes"
MISTRAL
=
"mistral"
@
dataclass
@
dataclass
...
...
vllm/engine/arg_utils.py
View file @
29f49cd6
...
@@ -8,10 +8,10 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
...
@@ -8,10 +8,10 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
Device
Config
,
from
vllm.config
import
(
CacheConfig
,
ConfigFormat
,
Decoding
Config
,
EngineConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoadFormat
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
,
TokenizerPoolConfig
)
SpeculativeConfig
,
TokenizerPoolConfig
)
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -65,6 +65,7 @@ class EngineArgs:
...
@@ -65,6 +65,7 @@ class EngineArgs:
trust_remote_code
:
bool
=
False
trust_remote_code
:
bool
=
False
download_dir
:
Optional
[
str
]
=
None
download_dir
:
Optional
[
str
]
=
None
load_format
:
str
=
'auto'
load_format
:
str
=
'auto'
config_format
:
str
=
'auto'
dtype
:
str
=
'auto'
dtype
:
str
=
'auto'
kv_cache_dtype
:
str
=
'auto'
kv_cache_dtype
:
str
=
'auto'
quantization_param_path
:
Optional
[
str
]
=
None
quantization_param_path
:
Optional
[
str
]
=
None
...
@@ -234,6 +235,13 @@ class EngineArgs:
...
@@ -234,6 +235,13 @@ class EngineArgs:
'section for more information.
\n
'
'section for more information.
\n
'
'* "bitsandbytes" will load the weights using bitsandbytes '
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.
\n
'
)
'quantization.
\n
'
)
parser
.
add_argument
(
'--config-format'
,
default
=
EngineArgs
.
config_format
,
choices
=
[
f
.
value
for
f
in
ConfigFormat
],
help
=
'The format of the model config to load.
\n\n
'
'* "auto" will try to load the config in hf format '
'if available else it will try to load in mistral format '
)
parser
.
add_argument
(
parser
.
add_argument
(
'--dtype'
,
'--dtype'
,
type
=
str
,
type
=
str
,
...
@@ -813,7 +821,10 @@ class EngineArgs:
...
@@ -813,7 +821,10 @@ class EngineArgs:
served_model_name
=
self
.
served_model_name
,
served_model_name
=
self
.
served_model_name
,
limit_mm_per_prompt
=
self
.
limit_mm_per_prompt
,
limit_mm_per_prompt
=
self
.
limit_mm_per_prompt
,
use_async_output_proc
=
not
self
.
disable_async_output_proc
,
use_async_output_proc
=
not
self
.
disable_async_output_proc
,
override_neuron_config
=
self
.
override_neuron_config
)
override_neuron_config
=
self
.
override_neuron_config
,
config_format
=
self
.
config_format
,
)
cache_config
=
CacheConfig
(
cache_config
=
CacheConfig
(
block_size
=
self
.
block_size
if
self
.
device
!=
"neuron"
else
block_size
=
self
.
block_size
if
self
.
device
!=
"neuron"
else
self
.
max_model_len
,
# neuron needs block_size = max_model_len
self
.
max_model_len
,
# neuron needs block_size = max_model_len
...
...
vllm/model_executor/model_loader/loader.py
View file @
29f49cd6
...
@@ -17,6 +17,7 @@ import torch
...
@@ -17,6 +17,7 @@ import torch
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
torch
import
nn
from
torch
import
nn
from
transformers
import
AutoModelForCausalLM
,
PretrainedConfig
from
transformers
import
AutoModelForCausalLM
,
PretrainedConfig
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoadFormat
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
...
@@ -241,12 +242,17 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -241,12 +242,17 @@ class DefaultModelLoader(BaseModelLoader):
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
load_format
=
self
.
load_config
.
load_format
load_format
=
self
.
load_config
.
load_format
use_safetensors
=
False
use_safetensors
=
False
index_file
=
SAFE_WEIGHTS_INDEX_NAME
# Some quantized models use .pt files for storing the weights.
# Some quantized models use .pt files for storing the weights.
if
load_format
==
LoadFormat
.
AUTO
:
if
load_format
==
LoadFormat
.
AUTO
:
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
elif
load_format
==
LoadFormat
.
SAFETENSORS
:
elif
load_format
==
LoadFormat
.
SAFETENSORS
:
use_safetensors
=
True
use_safetensors
=
True
allow_patterns
=
[
"*.safetensors"
]
allow_patterns
=
[
"*.safetensors"
]
elif
load_format
==
LoadFormat
.
MISTRAL
:
use_safetensors
=
True
allow_patterns
=
[
"consolidated*.safetensors"
]
index_file
=
"consolidated.safetensors.index.json"
elif
load_format
==
LoadFormat
.
PT
:
elif
load_format
==
LoadFormat
.
PT
:
allow_patterns
=
[
"*.pt"
]
allow_patterns
=
[
"*.pt"
]
elif
load_format
==
LoadFormat
.
NPCACHE
:
elif
load_format
==
LoadFormat
.
NPCACHE
:
...
@@ -284,10 +290,10 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -284,10 +290,10 @@ class DefaultModelLoader(BaseModelLoader):
# any files not found in the index.
# any files not found in the index.
if
not
is_local
:
if
not
is_local
:
download_safetensors_index_file_from_hf
(
download_safetensors_index_file_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
model_name_or_path
,
index_file
,
revision
)
self
.
load_config
.
download_dir
,
revision
)
hf_weights_files
=
filter_duplicate_safetensors_files
(
hf_weights_files
=
filter_duplicate_safetensors_files
(
hf_weights_files
,
hf_folder
)
hf_weights_files
,
hf_folder
,
index_file
)
else
:
else
:
hf_weights_files
=
filter_files_not_needed_for_inference
(
hf_weights_files
=
filter_files_not_needed_for_inference
(
hf_weights_files
)
hf_weights_files
)
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
29f49cd6
...
@@ -16,7 +16,6 @@ import torch
...
@@ -16,7 +16,6 @@ import torch
from
huggingface_hub
import
HfFileSystem
,
hf_hub_download
,
snapshot_download
from
huggingface_hub
import
HfFileSystem
,
hf_hub_download
,
snapshot_download
from
safetensors.torch
import
load_file
,
safe_open
,
save_file
from
safetensors.torch
import
load_file
,
safe_open
,
save_file
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
LoadConfig
,
ModelConfig
from
vllm.config
import
LoadConfig
,
ModelConfig
from
vllm.distributed
import
get_tensor_model_parallel_rank
from
vllm.distributed
import
get_tensor_model_parallel_rank
...
@@ -251,6 +250,7 @@ def download_weights_from_hf(
...
@@ -251,6 +250,7 @@ def download_weights_from_hf(
def
download_safetensors_index_file_from_hf
(
def
download_safetensors_index_file_from_hf
(
model_name_or_path
:
str
,
model_name_or_path
:
str
,
index_file
:
str
,
cache_dir
:
Optional
[
str
],
cache_dir
:
Optional
[
str
],
revision
:
Optional
[
str
]
=
None
,
revision
:
Optional
[
str
]
=
None
,
)
->
None
:
)
->
None
:
...
@@ -269,36 +269,37 @@ def download_safetensors_index_file_from_hf(
...
@@ -269,36 +269,37 @@ def download_safetensors_index_file_from_hf(
# Download the safetensors index file.
# Download the safetensors index file.
hf_hub_download
(
hf_hub_download
(
repo_id
=
model_name_or_path
,
repo_id
=
model_name_or_path
,
filename
=
SAFE_WEIGHTS_INDEX_NAME
,
filename
=
index_file
,
cache_dir
=
cache_dir
,
cache_dir
=
cache_dir
,
revision
=
revision
,
revision
=
revision
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
)
)
# If file not found on remote or locally, we should not fail since
# If file not found on remote or locally, we should not fail since
# only some models will have
SAFE_WEIGHTS_INDEX_NAME
.
# only some models will have
index_file
.
except
huggingface_hub
.
utils
.
EntryNotFoundError
:
except
huggingface_hub
.
utils
.
EntryNotFoundError
:
logger
.
info
(
"No %s found in remote."
,
SAFE_WEIGHTS_INDEX_NAME
)
logger
.
info
(
"No %s found in remote."
,
index_file
)
except
huggingface_hub
.
utils
.
LocalEntryNotFoundError
:
except
huggingface_hub
.
utils
.
LocalEntryNotFoundError
:
logger
.
info
(
"No %s found in local cache."
,
SAFE_WEIGHTS_INDEX_NAME
)
logger
.
info
(
"No %s found in local cache."
,
index_file
)
# For models like Mistral-7B-v0.3, there are both sharded
# For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file.
# safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks.
# Passing both of these to the weight loader functionality breaks.
# So, we use the
SAFE_WEIGHTS_INDEX_NAME
to
# So, we use the
index_file
to
# look up which safetensors files should be used.
# look up which safetensors files should be used.
def
filter_duplicate_safetensors_files
(
hf_weights_files
:
List
[
str
],
def
filter_duplicate_safetensors_files
(
hf_weights_files
:
List
[
str
],
hf_folder
:
str
)
->
List
[
str
]:
hf_folder
:
str
,
index_file
:
str
)
->
List
[
str
]:
# model.safetensors.index.json is a mapping from keys in the
# model.safetensors.index.json is a mapping from keys in the
# torch state_dict to safetensors file holding that weight.
# torch state_dict to safetensors file holding that weight.
index_file_name
=
os
.
path
.
join
(
hf_folder
,
SAFE_WEIGHTS_INDEX_NAME
)
index_file_name
=
os
.
path
.
join
(
hf_folder
,
index_file
)
if
not
os
.
path
.
isfile
(
index_file_name
):
if
not
os
.
path
.
isfile
(
index_file_name
):
return
hf_weights_files
return
hf_weights_files
# Iterate through the weight_map (weight_name: safetensors files)
# Iterate through the weight_map (weight_name: safetensors files)
# to identify weights that we should use.
# to identify weights that we should use.
with
open
(
index_file_name
)
as
index_file
:
with
open
(
index_file_name
,
"r"
)
as
f
:
weight_map
=
json
.
load
(
index_file
)[
"weight_map"
]
weight_map
=
json
.
load
(
f
)[
"weight_map"
]
weight_files_in_index
=
set
()
weight_files_in_index
=
set
()
for
weight_name
in
weight_map
:
for
weight_name
in
weight_map
:
weight_files_in_index
.
add
(
weight_files_in_index
.
add
(
...
...
vllm/model_executor/models/llama.py
View file @
29f49cd6
...
@@ -375,6 +375,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -375,6 +375,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"gate_proj"
:
(
"gate_up_proj"
,
0
),
"gate_proj"
:
(
"gate_up_proj"
,
0
),
"up_proj"
:
(
"gate_up_proj"
,
1
),
"up_proj"
:
(
"gate_up_proj"
,
1
),
}
}
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping
=
{
"layers"
:
"model.layers"
,
"attention"
:
"self_attn"
,
"wq"
:
"q_proj"
,
"wk"
:
"k_proj"
,
"wv"
:
"v_proj"
,
"wo"
:
"o_proj"
,
"attention_norm"
:
"input_layernorm"
,
"feed_forward"
:
"mlp"
,
"w1"
:
"gate_proj"
,
"w2"
:
"down_proj"
,
"w3"
:
"up_proj"
,
"ffn_norm"
:
"post_attention_layernorm"
,
"tok_embeddings"
:
"model.embed_tokens"
,
"output"
:
"lm_head"
,
"norm"
:
"model.norm"
}
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -472,6 +491,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -472,6 +491,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
name
,
loaded_weight
=
self
.
maybe_remap_mistral
(
name
,
loaded_weight
)
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
if
(
"rotary_emb.cos_cached"
in
name
if
(
"rotary_emb.cos_cached"
in
name
...
@@ -549,3 +570,33 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -549,3 +570,33 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
else
:
else
:
raise
RuntimeError
(
"Self attention has no KV cache scaling "
raise
RuntimeError
(
"Self attention has no KV cache scaling "
"factor attribute!"
)
"factor attribute!"
)
# This function is used to remap the mistral format as
# used by Mistral and Llama <=2
def
maybe_remap_mistral
(
self
,
name
:
str
,
loaded_weight
:
torch
.
Tensor
)
->
Tuple
[
str
,
torch
.
Tensor
]:
def
permute
(
w
,
n_heads
):
attn_in
=
self
.
config
.
head_dim
*
n_heads
attn_out
=
self
.
config
.
hidden_size
return
w
.
view
(
n_heads
,
attn_in
//
n_heads
//
2
,
2
,
attn_out
).
transpose
(
1
,
2
).
reshape
(
attn_in
,
attn_out
)
mapping
=
self
.
mistral_mapping
modules
=
name
.
split
(
"."
)
# rotary embeds should be sliced
if
"wk"
in
modules
:
loaded_weight
=
permute
(
loaded_weight
,
self
.
config
.
num_key_value_heads
)
elif
"wq"
in
modules
:
loaded_weight
=
permute
(
loaded_weight
,
self
.
config
.
num_attention_heads
)
for
item
in
modules
:
if
item
in
mapping
and
mapping
[
item
]
not
in
name
:
name
=
name
.
replace
(
item
,
mapping
[
item
])
return
name
,
loaded_weight
vllm/transformers_utils/config.py
View file @
29f49cd6
import
contextlib
import
contextlib
import
enum
import
json
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
Optional
,
Type
,
Union
from
typing
import
Any
,
Dict
,
Optional
,
Type
,
Union
from
huggingface_hub
import
file_exists
,
hf_hub_download
from
transformers
import
GenerationConfig
,
PretrainedConfig
from
transformers
import
GenerationConfig
,
PretrainedConfig
from
transformers.models.auto.image_processing_auto
import
(
from
transformers.models.auto.image_processing_auto
import
(
get_image_processor_config
)
get_image_processor_config
)
from
transformers.models.auto.modeling_auto
import
(
from
transformers.models.auto.modeling_auto
import
(
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
)
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
)
from
transformers.utils
import
CONFIG_NAME
as
HF_CONFIG_NAME
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -27,6 +31,8 @@ if VLLM_USE_MODELSCOPE:
...
@@ -27,6 +31,8 @@ if VLLM_USE_MODELSCOPE:
else
:
else
:
from
transformers
import
AutoConfig
from
transformers
import
AutoConfig
MISTRAL_CONFIG_NAME
=
"params.json"
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_CONFIG_REGISTRY
:
Dict
[
str
,
Type
[
PretrainedConfig
]]
=
{
_CONFIG_REGISTRY
:
Dict
[
str
,
Type
[
PretrainedConfig
]]
=
{
...
@@ -53,6 +59,20 @@ for name, cls in _CONFIG_REGISTRY.items():
...
@@ -53,6 +59,20 @@ for name, cls in _CONFIG_REGISTRY.items():
AutoConfig
.
register
(
name
,
cls
)
AutoConfig
.
register
(
name
,
cls
)
class
ConfigFormat
(
str
,
enum
.
Enum
):
AUTO
=
"auto"
HF
=
"hf"
MISTRAL
=
"mistral"
def
file_or_path_exists
(
model
:
Union
[
str
,
Path
],
config_name
,
revision
,
token
)
->
bool
:
if
Path
(
model
).
exists
():
return
(
Path
(
model
)
/
config_name
).
is_file
()
return
file_exists
(
model
,
HF_CONFIG_NAME
,
revision
=
revision
,
token
=
token
)
def
get_config
(
def
get_config
(
model
:
Union
[
str
,
Path
],
model
:
Union
[
str
,
Path
],
trust_remote_code
:
bool
,
trust_remote_code
:
bool
,
...
@@ -60,45 +80,68 @@ def get_config(
...
@@ -60,45 +80,68 @@ def get_config(
code_revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
,
rope_scaling
:
Optional
[
dict
]
=
None
,
rope_scaling
:
Optional
[
dict
]
=
None
,
rope_theta
:
Optional
[
float
]
=
None
,
rope_theta
:
Optional
[
float
]
=
None
,
config_format
:
ConfigFormat
=
ConfigFormat
.
AUTO
,
**
kwargs
,
**
kwargs
,
)
->
PretrainedConfig
:
)
->
PretrainedConfig
:
# Separate model folder from file path for GGUF models
# Separate model folder from file path for GGUF models
is_gguf
=
check_gguf_file
(
model
)
is_gguf
=
check_gguf_file
(
model
)
if
is_gguf
:
if
is_gguf
:
kwargs
[
"gguf_file"
]
=
Path
(
model
).
name
kwargs
[
"gguf_file"
]
=
Path
(
model
).
name
model
=
Path
(
model
).
parent
model
=
Path
(
model
).
parent
config_dict
,
_
=
PretrainedConfig
.
get_config_dict
(
if
config_format
==
ConfigFormat
.
AUTO
:
model
,
revision
=
revision
,
code_revision
=
code_revision
,
**
kwargs
)
if
is_gguf
or
file_or_path_exists
(
model
,
HF_CONFIG_NAME
,
revision
=
revision
,
token
=
kwargs
.
get
(
"token"
)):
config_format
=
ConfigFormat
.
HF
elif
file_or_path_exists
(
model
,
MISTRAL_CONFIG_NAME
,
revision
=
revision
,
token
=
kwargs
.
get
(
"token"
)):
config_format
=
ConfigFormat
.
MISTRAL
else
:
raise
ValueError
(
f
"No supported config format found in
{
model
}
"
)
if
config_format
==
ConfigFormat
.
HF
:
config_dict
,
_
=
PretrainedConfig
.
get_config_dict
(
model
,
revision
=
revision
,
code_revision
=
code_revision
,
**
kwargs
)
# Use custom model class if it's in our registry
model_type
=
config_dict
.
get
(
"model_type"
)
if
model_type
in
_CONFIG_REGISTRY
:
config_class
=
_CONFIG_REGISTRY
[
model_type
]
config
=
config_class
.
from_pretrained
(
model
,
revision
=
revision
,
code_revision
=
code_revision
)
else
:
try
:
config
=
AutoConfig
.
from_pretrained
(
model
,
trust_remote_code
=
trust_remote_code
,
revision
=
revision
,
code_revision
=
code_revision
,
**
kwargs
,
)
except
ValueError
as
e
:
if
(
not
trust_remote_code
and
"requires you to execute the configuration file"
in
str
(
e
)):
err_msg
=
(
"Failed to load the model config. If the model "
"is a custom model not yet available in the "
"HuggingFace transformers library, consider setting "
"`trust_remote_code=True` in LLM or using the "
"`--trust-remote-code` flag in the CLI."
)
raise
RuntimeError
(
err_msg
)
from
e
else
:
raise
e
# Use custom model class if it's in our registry
elif
config_format
==
ConfigFormat
.
MISTRAL
:
model_type
=
config_dict
.
get
(
"model_type"
)
config
=
load_params_config
(
model
,
revision
)
if
model_type
in
_CONFIG_REGISTRY
:
config_class
=
_CONFIG_REGISTRY
[
model_type
]
config
=
config_class
.
from_pretrained
(
model
,
revision
=
revision
,
code_revision
=
code_revision
)
else
:
else
:
try
:
raise
ValueError
(
f
"Unsupported config format:
{
config_format
}
"
)
config
=
AutoConfig
.
from_pretrained
(
model
,
trust_remote_code
=
trust_remote_code
,
revision
=
revision
,
code_revision
=
code_revision
,
**
kwargs
)
except
ValueError
as
e
:
if
(
not
trust_remote_code
and
"requires you to execute the configuration file"
in
str
(
e
)):
err_msg
=
(
"Failed to load the model config. If the model is a custom "
"model not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI."
)
raise
RuntimeError
(
err_msg
)
from
e
else
:
raise
e
# Special architecture mapping check for GGUF models
# Special architecture mapping check for GGUF models
if
is_gguf
:
if
is_gguf
:
...
@@ -108,16 +151,70 @@ def get_config(
...
@@ -108,16 +151,70 @@ def get_config(
model_type
=
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
[
config
.
model_type
]
model_type
=
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
[
config
.
model_type
]
config
.
update
({
"architectures"
:
[
model_type
]})
config
.
update
({
"architectures"
:
[
model_type
]})
for
key
,
value
in
[(
"rope_scaling"
,
rope_scaling
),
for
key
,
value
in
[
(
"rope_theta"
,
rope_theta
)]:
(
"rope_scaling"
,
rope_scaling
),
(
"rope_theta"
,
rope_theta
),
]:
if
value
is
not
None
:
if
value
is
not
None
:
logger
.
info
(
"Updating %s from %r to %r"
,
key
,
logger
.
info
(
getattr
(
config
,
key
,
None
),
value
)
"Updating %s from %r to %r"
,
key
,
getattr
(
config
,
key
,
None
),
value
,
)
config
.
update
({
key
:
value
})
config
.
update
({
key
:
value
})
return
config
return
config
def
load_params_config
(
model
,
revision
)
->
PretrainedConfig
:
# This function loads a params.json config which
# should be used when loading models in mistral format
config_file_name
=
"params.json"
config_path
=
Path
(
model
)
/
config_file_name
if
not
config_path
.
is_file
():
config_path
=
Path
(
hf_hub_download
(
model
,
config_file_name
,
revision
=
revision
))
with
open
(
config_path
,
"r"
)
as
file
:
config_dict
=
json
.
load
(
file
)
config_mapping
=
{
"dim"
:
"hidden_size"
,
"norm_eps"
:
"rms_norm_eps"
,
"n_kv_heads"
:
"num_key_value_heads"
,
"n_layers"
:
"num_hidden_layers"
,
"n_heads"
:
"num_attention_heads"
,
"hidden_dim"
:
"intermediate_size"
,
}
def
recurse_elems
(
elem
:
Any
):
if
isinstance
(
elem
,
dict
):
config_dict
=
{}
for
key
,
value
in
elem
.
items
():
key
=
config_mapping
.
get
(
key
,
key
)
config_dict
[
key
]
=
recurse_elems
(
value
)
return
PretrainedConfig
(
**
config_dict
)
else
:
return
elem
config_dict
[
"model_type"
]
=
config_dict
.
get
(
"model_type"
,
"transformer"
)
config_dict
[
"hidden_act"
]
=
config_dict
.
get
(
"activation"
,
"silu"
)
config_dict
[
"tie_word_embeddings"
]
=
config_dict
.
get
(
"tie_embeddings"
,
False
)
if
config_dict
[
"model_type"
]
==
"transformer"
:
if
"moe"
in
config_dict
:
config_dict
[
"architectures"
]
=
[
"MixtralForCausalLM"
]
else
:
config_dict
[
"architectures"
]
=
[
"MistralForCausalLM"
]
return
recurse_elems
(
config_dict
)
def
get_hf_image_processor_config
(
def
get_hf_image_processor_config
(
model
:
Union
[
str
,
Path
],
model
:
Union
[
str
,
Path
],
revision
:
Optional
[
str
]
=
None
,
revision
:
Optional
[
str
]
=
None
,
...
@@ -134,7 +231,7 @@ def get_hf_image_processor_config(
...
@@ -134,7 +231,7 @@ def get_hf_image_processor_config(
def
get_hf_text_config
(
config
:
PretrainedConfig
):
def
get_hf_text_config
(
config
:
PretrainedConfig
):
"""Get the "sub" config relevant to llm for multi modal models.
"""Get the "sub" config relevant to llm for multi modal models.
No op for pure text models.
No op for pure text models.
"""
"""
if
hasattr
(
config
,
"text_config"
):
if
hasattr
(
config
,
"text_config"
):
# The code operates under the assumption that text_config should have
# The code operates under the assumption that text_config should have
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment