Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
29f49cd6
Unverified
Commit
29f49cd6
authored
Sep 07, 2024
by
Patrick von Platen
Committed by
GitHub
Sep 06, 2024
Browse files
[Model] Allow loading from original Mistral format (#8168)
Co-authored-by:
Michael Goin
<
michael@neuralmagic.com
>
parent
23f32229
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
291 additions
and
81 deletions
+291
-81
tests/models/test_mistral.py
tests/models/test_mistral.py
+40
-0
vllm/config.py
vllm/config.py
+33
-29
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+16
-5
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+9
-3
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+11
-10
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+51
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+131
-34
No files found.
tests/models/test_mistral.py
View file @
29f49cd6
...
...
@@ -41,3 +41,43 @@ def test_models(
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
@pytest.mark.parametrize("model", MODELS[1:])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_mistral_format(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Compare greedy logprobs of one model loaded in HF vs. Mistral format.

    The same checkpoint is run twice through ``vllm_runner``: once with the
    HF tokenizer/weights/config settings and once with the Mistral ones.
    The two sets of outputs are then handed to ``check_logprobs_close``.
    """
    # Loader settings for each of the two runs; everything else is shared.
    hf_settings = dict(
        tokenizer_mode="auto",
        load_format="safetensors",
        config_format="hf",
    )
    mistral_settings = dict(
        tokenizer_mode="mistral",
        load_format="mistral",
        config_format="mistral",
    )

    # The runs are sequential: the first runner's context is exited before
    # the second one is created.
    with vllm_runner(model, dtype=dtype, **hf_settings) as hf_llm:
        hf_results = hf_llm.generate_greedy_logprobs(example_prompts,
                                                     max_tokens,
                                                     num_logprobs)

    with vllm_runner(model, dtype=dtype, **mistral_settings) as mistral_llm:
        mistral_results = mistral_llm.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=hf_results,
        outputs_1_lst=mistral_results,
        name_0="hf",
        name_1="mistral",
    )
vllm/config.py
View file @
29f49cd6
...
...
@@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.platforms
import
current_platform
from
vllm.tracing
import
is_otel_available
,
otel_import_error_traceback
from
vllm.transformers_utils.config
import
(
get_config
,
from
vllm.transformers_utils.config
import
(
ConfigFormat
,
get_config
,
get_hf_image_processor_config
,
get_hf_text_config
)
from
vllm.utils
import
(
STR_NOT_IMPL_ENC_DEC_CUDAGRAPH
,
GiB_bytes
,
...
...
@@ -121,35 +121,37 @@ class ModelConfig:
override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that
can not be gathered from the vllm arguments.
config_format: The config format which shall be loaded.
Defaults to 'auto', which loads the 'hf' format if available and otherwise falls back to the 'mistral' format.
"""
def
__init__
(
self
,
model
:
str
,
tokenizer
:
str
,
tokenizer_m
ode
:
str
,
trust_remote_code
:
bool
,
dtype
:
Union
[
str
,
torch
.
dtype
]
,
seed
:
int
,
revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
,
rope_scaling
:
Optional
[
dic
t
]
=
None
,
rope_theta
:
Optional
[
float
]
=
None
,
tokenizer_revisio
n
:
Optional
[
str
]
=
None
,
max_model_len
:
Optional
[
int
]
=
None
,
spec_target_max_model_le
n
:
Optional
[
int
]
=
None
,
quantization
:
Optional
[
str
]
=
None
,
quantization_param_path
:
Optional
[
str
]
=
None
,
enforce_eager
:
Optional
[
bool
]
=
None
,
max_context
_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq_len_to_capture
:
Optional
[
int
]
=
None
,
max_logprobs
:
int
=
20
,
disable_sliding_window
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
Fals
e
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]
]]
=
None
,
limit_mm_per_prompt
:
Optional
[
Mapping
[
str
,
int
]]
=
Non
e
,
use_async_output_proc
:
bool
=
Tru
e
,
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
N
on
e
)
->
None
:
def
__init__
(
self
,
model
:
str
,
tokenizer
:
str
,
tokenizer
_mode
:
str
,
trust_remote_c
ode
:
bool
,
dtype
:
Union
[
str
,
torch
.
dtype
]
,
seed
:
int
,
revision
:
Optional
[
str
]
=
None
,
code_
revision
:
Optional
[
str
]
=
None
,
rope_scaling
:
Optional
[
dict
]
=
None
,
rope_theta
:
Optional
[
floa
t
]
=
None
,
tokenizer_revision
:
Optional
[
str
]
=
None
,
max_model_le
n
:
Optional
[
int
]
=
None
,
spec_target_
max_model_len
:
Optional
[
int
]
=
None
,
quantizatio
n
:
Optional
[
str
]
=
None
,
quantization
_param_path
:
Optional
[
str
]
=
None
,
enforce_eager
:
Optional
[
bool
]
=
None
,
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq
_len_to_capture
:
Optional
[
int
]
=
None
,
max_logprobs
:
int
=
20
,
disable_sliding_window
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
Non
e
,
limit_mm_per_prompt
:
Optional
[
Mapping
[
str
,
int
]]
=
None
,
use_async_output_proc
:
bool
=
Tru
e
,
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
Non
e
,
config_format
:
ConfigFormat
=
C
on
figFormat
.
AUTO
)
->
None
:
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
tokenizer_mode
=
tokenizer_mode
...
...
@@ -176,7 +178,8 @@ class ModelConfig:
self
.
skip_tokenizer_init
=
skip_tokenizer_init
self
.
hf_config
=
get_config
(
self
.
model
,
trust_remote_code
,
revision
,
code_revision
,
rope_scaling
,
rope_theta
)
code_revision
,
rope_scaling
,
rope_theta
,
config_format
)
self
.
hf_text_config
=
get_hf_text_config
(
self
.
hf_config
)
self
.
hf_image_processor_config
=
get_hf_image_processor_config
(
self
.
model
,
revision
)
...
...
@@ -746,6 +749,7 @@ class LoadFormat(str, enum.Enum):
SHARDED_STATE
=
"sharded_state"
GGUF
=
"gguf"
BITSANDBYTES
=
"bitsandbytes"
MISTRAL
=
"mistral"
@
dataclass
...
...
vllm/engine/arg_utils.py
View file @
29f49cd6
...
...
@@ -8,10 +8,10 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
import
torch
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
Device
Config
,
EngineConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
from
vllm.config
import
(
CacheConfig
,
ConfigFormat
,
Decoding
Config
,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
,
TokenizerPoolConfig
)
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.logger
import
init_logger
...
...
@@ -65,6 +65,7 @@ class EngineArgs:
trust_remote_code
:
bool
=
False
download_dir
:
Optional
[
str
]
=
None
load_format
:
str
=
'auto'
config_format
:
str
=
'auto'
dtype
:
str
=
'auto'
kv_cache_dtype
:
str
=
'auto'
quantization_param_path
:
Optional
[
str
]
=
None
...
...
@@ -234,6 +235,13 @@ class EngineArgs:
'section for more information.
\n
'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.
\n
'
)
parser
.
add_argument
(
'--config-format'
,
default
=
EngineArgs
.
config_format
,
choices
=
[
f
.
value
for
f
in
ConfigFormat
],
help
=
'The format of the model config to load.
\n\n
'
'* "auto" will try to load the config in hf format '
'if available else it will try to load in mistral format '
)
parser
.
add_argument
(
'--dtype'
,
type
=
str
,
...
...
@@ -813,7 +821,10 @@ class EngineArgs:
served_model_name
=
self
.
served_model_name
,
limit_mm_per_prompt
=
self
.
limit_mm_per_prompt
,
use_async_output_proc
=
not
self
.
disable_async_output_proc
,
override_neuron_config
=
self
.
override_neuron_config
)
override_neuron_config
=
self
.
override_neuron_config
,
config_format
=
self
.
config_format
,
)
cache_config
=
CacheConfig
(
block_size
=
self
.
block_size
if
self
.
device
!=
"neuron"
else
self
.
max_model_len
,
# neuron needs block_size = max_model_len
...
...
vllm/model_executor/model_loader/loader.py
View file @
29f49cd6
...
...
@@ -17,6 +17,7 @@ import torch
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
torch
import
nn
from
transformers
import
AutoModelForCausalLM
,
PretrainedConfig
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
...
...
@@ -241,12 +242,17 @@ class DefaultModelLoader(BaseModelLoader):
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
load_format
=
self
.
load_config
.
load_format
use_safetensors
=
False
index_file
=
SAFE_WEIGHTS_INDEX_NAME
# Some quantized models use .pt files for storing the weights.
if
load_format
==
LoadFormat
.
AUTO
:
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
elif
load_format
==
LoadFormat
.
SAFETENSORS
:
use_safetensors
=
True
allow_patterns
=
[
"*.safetensors"
]
elif
load_format
==
LoadFormat
.
MISTRAL
:
use_safetensors
=
True
allow_patterns
=
[
"consolidated*.safetensors"
]
index_file
=
"consolidated.safetensors.index.json"
elif
load_format
==
LoadFormat
.
PT
:
allow_patterns
=
[
"*.pt"
]
elif
load_format
==
LoadFormat
.
NPCACHE
:
...
...
@@ -284,10 +290,10 @@ class DefaultModelLoader(BaseModelLoader):
# any files not found in the index.
if
not
is_local
:
download_safetensors_index_file_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
revision
)
model_name_or_path
,
index_file
,
self
.
load_config
.
download_dir
,
revision
)
hf_weights_files
=
filter_duplicate_safetensors_files
(
hf_weights_files
,
hf_folder
)
hf_weights_files
,
hf_folder
,
index_file
)
else
:
hf_weights_files
=
filter_files_not_needed_for_inference
(
hf_weights_files
)
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
29f49cd6
...
...
@@ -16,7 +16,6 @@ import torch
from
huggingface_hub
import
HfFileSystem
,
hf_hub_download
,
snapshot_download
from
safetensors.torch
import
load_file
,
safe_open
,
save_file
from
tqdm.auto
import
tqdm
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
LoadConfig
,
ModelConfig
from
vllm.distributed
import
get_tensor_model_parallel_rank
...
...
@@ -251,6 +250,7 @@ def download_weights_from_hf(
def
download_safetensors_index_file_from_hf
(
model_name_or_path
:
str
,
index_file
:
str
,
cache_dir
:
Optional
[
str
],
revision
:
Optional
[
str
]
=
None
,
)
->
None
:
...
...
@@ -269,36 +269,37 @@ def download_safetensors_index_file_from_hf(
# Download the safetensors index file.
hf_hub_download
(
repo_id
=
model_name_or_path
,
filename
=
SAFE_WEIGHTS_INDEX_NAME
,
filename
=
index_file
,
cache_dir
=
cache_dir
,
revision
=
revision
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
)
# If file not found on remote or locally, we should not fail since
# only some models will have
SAFE_WEIGHTS_INDEX_NAME
.
# only some models will have
index_file
.
except
huggingface_hub
.
utils
.
EntryNotFoundError
:
logger
.
info
(
"No %s found in remote."
,
SAFE_WEIGHTS_INDEX_NAME
)
logger
.
info
(
"No %s found in remote."
,
index_file
)
except
huggingface_hub
.
utils
.
LocalEntryNotFoundError
:
logger
.
info
(
"No %s found in local cache."
,
SAFE_WEIGHTS_INDEX_NAME
)
logger
.
info
(
"No %s found in local cache."
,
index_file
)
# For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks.
# So, we use the
SAFE_WEIGHTS_INDEX_NAME
to
# So, we use the
index_file
to
# look up which safetensors files should be used.
def
filter_duplicate_safetensors_files
(
hf_weights_files
:
List
[
str
],
hf_folder
:
str
)
->
List
[
str
]:
hf_folder
:
str
,
index_file
:
str
)
->
List
[
str
]:
# model.safetensors.index.json is a mapping from keys in the
# torch state_dict to safetensors file holding that weight.
index_file_name
=
os
.
path
.
join
(
hf_folder
,
SAFE_WEIGHTS_INDEX_NAME
)
index_file_name
=
os
.
path
.
join
(
hf_folder
,
index_file
)
if
not
os
.
path
.
isfile
(
index_file_name
):
return
hf_weights_files
# Iterate through the weight_map (weight_name: safetensors files)
# to identify weights that we should use.
with
open
(
index_file_name
)
as
index_file
:
weight_map
=
json
.
load
(
index_file
)[
"weight_map"
]
with
open
(
index_file_name
,
"r"
)
as
f
:
weight_map
=
json
.
load
(
f
)[
"weight_map"
]
weight_files_in_index
=
set
()
for
weight_name
in
weight_map
:
weight_files_in_index
.
add
(
...
...
vllm/model_executor/models/llama.py
View file @
29f49cd6
...
...
@@ -375,6 +375,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"gate_proj"
:
(
"gate_up_proj"
,
0
),
"up_proj"
:
(
"gate_up_proj"
,
1
),
}
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping
=
{
"layers"
:
"model.layers"
,
"attention"
:
"self_attn"
,
"wq"
:
"q_proj"
,
"wk"
:
"k_proj"
,
"wv"
:
"v_proj"
,
"wo"
:
"o_proj"
,
"attention_norm"
:
"input_layernorm"
,
"feed_forward"
:
"mlp"
,
"w1"
:
"gate_proj"
,
"w2"
:
"down_proj"
,
"w3"
:
"up_proj"
,
"ffn_norm"
:
"post_attention_layernorm"
,
"tok_embeddings"
:
"model.embed_tokens"
,
"output"
:
"lm_head"
,
"norm"
:
"model.norm"
}
def
__init__
(
self
,
...
...
@@ -472,6 +491,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
]
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
name
,
loaded_weight
=
self
.
maybe_remap_mistral
(
name
,
loaded_weight
)
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
"rotary_emb.cos_cached"
in
name
...
...
@@ -549,3 +570,33 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
else
:
raise
RuntimeError
(
"Self attention has no KV cache scaling "
"factor attribute!"
)
# Remaps weights stored in the original Mistral checkpoint format (as used
# by Mistral and Llama <=2) onto the parameter names this model expects.
def maybe_remap_mistral(
        self, name: str,
        loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]:
    """Translate one Mistral-format weight entry, if it is one.

    Args:
        name: Dotted parameter name as found in the checkpoint.
        loaded_weight: Tensor stored under ``name``.

    Returns:
        A ``(name, tensor)`` pair. Name components found in
        ``self.mistral_mapping`` are substituted by their mapped value;
        tensors of ``wq``/``wk`` entries are additionally row-permuted.
        Entries that match nothing are returned unchanged.
    """

    def _reorder_rows(weight, head_count):
        # View the rows as (head, head_dim // 2, 2), swap the last two of
        # those axes, then flatten back to the 2-D (rows, hidden) shape.
        # NOTE(review): presumably converts between the rotary-embedding
        # layouts of the Mistral and HF formats -- confirm.
        row_count = self.config.head_dim * head_count
        col_count = self.config.hidden_size
        grouped = weight.view(head_count, row_count // head_count // 2, 2,
                              col_count)
        return grouped.transpose(1, 2).reshape(row_count, col_count)

    components = name.split(".")

    # Only the key and query projections need their rows permuted; "wk" is
    # checked first, exactly one of the two branches can apply.
    if "wk" in components:
        loaded_weight = _reorder_rows(loaded_weight,
                                      self.config.num_key_value_heads)
    elif "wq" in components:
        loaded_weight = _reorder_rows(loaded_weight,
                                      self.config.num_attention_heads)

    translation = self.mistral_mapping
    for component in components:
        if component not in translation:
            continue
        replacement = translation[component]
        # Skip when the target spelling is already part of the (possibly
        # already rewritten) name, so such names are not rewritten again.
        if replacement in name:
            continue
        name = name.replace(component, replacement)

    return name, loaded_weight
vllm/transformers_utils/config.py
View file @
29f49cd6
import
contextlib
import
enum
import
json
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
Optional
,
Type
,
Union
from
huggingface_hub
import
file_exists
,
hf_hub_download
from
transformers
import
GenerationConfig
,
PretrainedConfig
from
transformers.models.auto.image_processing_auto
import
(
get_image_processor_config
)
from
transformers.models.auto.modeling_auto
import
(
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
)
from
transformers.utils
import
CONFIG_NAME
as
HF_CONFIG_NAME
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
...
...
@@ -27,6 +31,8 @@ if VLLM_USE_MODELSCOPE:
else
:
from
transformers
import
AutoConfig
MISTRAL_CONFIG_NAME
=
"params.json"
logger
=
init_logger
(
__name__
)
_CONFIG_REGISTRY
:
Dict
[
str
,
Type
[
PretrainedConfig
]]
=
{
...
...
@@ -53,6 +59,20 @@ for name, cls in _CONFIG_REGISTRY.items():
AutoConfig
.
register
(
name
,
cls
)
class ConfigFormat(str, enum.Enum):
    """Names of the model-config formats that can be requested.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    AUTO = "auto"
    HF = "hf"
    MISTRAL = "mistral"
def file_or_path_exists(model: Union[str, Path], config_name: str,
                        revision: Optional[str],
                        token: Union[str, bool, None]) -> bool:
    """Report whether ``config_name`` is available for ``model``.

    Args:
        model: Either a path on the local filesystem or a Hub repo id.
        config_name: File name to look for (e.g. the HF or Mistral config
            file name).
        revision: Revision forwarded to the Hub lookup; unused for local
            paths.
        token: Auth token forwarded to the Hub lookup; unused for local
            paths.

    Returns:
        ``True`` if the file exists, ``False`` otherwise.
    """
    local_root = Path(model)
    if local_root.exists():
        # Local checkout: the answer is decided entirely on disk.
        return (local_root / config_name).is_file()

    # Remote repo: probe for the file that was actually asked about.
    # (Previously this always probed for HF_CONFIG_NAME, so a lookup of any
    # other config file on the Hub reported on the wrong file.)
    return file_exists(model, config_name, revision=revision, token=token)
def
get_config
(
model
:
Union
[
str
,
Path
],
trust_remote_code
:
bool
,
...
...
@@ -60,45 +80,68 @@ def get_config(
code_revision
:
Optional
[
str
]
=
None
,
rope_scaling
:
Optional
[
dict
]
=
None
,
rope_theta
:
Optional
[
float
]
=
None
,
config_format
:
ConfigFormat
=
ConfigFormat
.
AUTO
,
**
kwargs
,
)
->
PretrainedConfig
:
# Separate model folder from file path for GGUF models
is_gguf
=
check_gguf_file
(
model
)
if
is_gguf
:
kwargs
[
"gguf_file"
]
=
Path
(
model
).
name
model
=
Path
(
model
).
parent
config_dict
,
_
=
PretrainedConfig
.
get_config_dict
(
model
,
revision
=
revision
,
code_revision
=
code_revision
,
**
kwargs
)
if
config_format
==
ConfigFormat
.
AUTO
:
if
is_gguf
or
file_or_path_exists
(
model
,
HF_CONFIG_NAME
,
revision
=
revision
,
token
=
kwargs
.
get
(
"token"
)):
config_format
=
ConfigFormat
.
HF
elif
file_or_path_exists
(
model
,
MISTRAL_CONFIG_NAME
,
revision
=
revision
,
token
=
kwargs
.
get
(
"token"
)):
config_format
=
ConfigFormat
.
MISTRAL
else
:
raise
ValueError
(
f
"No supported config format found in
{
model
}
"
)
if
config_format
==
ConfigFormat
.
HF
:
config_dict
,
_
=
PretrainedConfig
.
get_config_dict
(
model
,
revision
=
revision
,
code_revision
=
code_revision
,
**
kwargs
)
# Use custom model class if it's in our registry
model_type
=
config_dict
.
get
(
"model_type"
)
if
model_type
in
_CONFIG_REGISTRY
:
config_class
=
_CONFIG_REGISTRY
[
model_type
]
config
=
config_class
.
from_pretrained
(
model
,
revision
=
revision
,
code_revision
=
code_revision
)
else
:
try
:
config
=
AutoConfig
.
from_pretrained
(
model
,
trust_remote_code
=
trust_remote_code
,
revision
=
revision
,
code_revision
=
code_revision
,
**
kwargs
,
)
except
ValueError
as
e
:
if
(
not
trust_remote_code
and
"requires you to execute the configuration file"
in
str
(
e
)):
err_msg
=
(
"Failed to load the model config. If the model "
"is a custom model not yet available in the "
"HuggingFace transformers library, consider setting "
"`trust_remote_code=True` in LLM or using the "
"`--trust-remote-code` flag in the CLI."
)
raise
RuntimeError
(
err_msg
)
from
e
else
:
raise
e
# Use custom model class if it's in our registry
model_type
=
config_dict
.
get
(
"model_type"
)
if
model_type
in
_CONFIG_REGISTRY
:
config_class
=
_CONFIG_REGISTRY
[
model_type
]
config
=
config_class
.
from_pretrained
(
model
,
revision
=
revision
,
code_revision
=
code_revision
)
elif
config_format
==
ConfigFormat
.
MISTRAL
:
config
=
load_params_config
(
model
,
revision
)
else
:
try
:
config
=
AutoConfig
.
from_pretrained
(
model
,
trust_remote_code
=
trust_remote_code
,
revision
=
revision
,
code_revision
=
code_revision
,
**
kwargs
)
except
ValueError
as
e
:
if
(
not
trust_remote_code
and
"requires you to execute the configuration file"
in
str
(
e
)):
err_msg
=
(
"Failed to load the model config. If the model is a custom "
"model not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI."
)
raise
RuntimeError
(
err_msg
)
from
e
else
:
raise
e
raise
ValueError
(
f
"Unsupported config format:
{
config_format
}
"
)
# Special architecture mapping check for GGUF models
if
is_gguf
:
...
...
@@ -108,16 +151,70 @@ def get_config(
model_type
=
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
[
config
.
model_type
]
config
.
update
({
"architectures"
:
[
model_type
]})
for
key
,
value
in
[(
"rope_scaling"
,
rope_scaling
),
(
"rope_theta"
,
rope_theta
)]:
for
key
,
value
in
[
(
"rope_scaling"
,
rope_scaling
),
(
"rope_theta"
,
rope_theta
),
]:
if
value
is
not
None
:
logger
.
info
(
"Updating %s from %r to %r"
,
key
,
getattr
(
config
,
key
,
None
),
value
)
logger
.
info
(
"Updating %s from %r to %r"
,
key
,
getattr
(
config
,
key
,
None
),
value
,
)
config
.
update
({
key
:
value
})
return
config
def load_params_config(model, revision) -> PretrainedConfig:
    """Build a ``PretrainedConfig`` from a Mistral-format ``params.json``.

    Args:
        model: Local directory containing ``params.json``; if the file is
            not found there, ``model`` is passed to ``hf_hub_download`` as
            the repo to fetch it from.
        revision: Revision forwarded to ``hf_hub_download``.

    Returns:
        A ``PretrainedConfig`` whose keys have been renamed according to the
        table below; nested dicts become nested ``PretrainedConfig`` objects.
    """
    params_file = "params.json"

    params_path = Path(model) / params_file
    if not params_path.is_file():
        # Not available locally -> download it.
        params_path = Path(
            hf_hub_download(model, params_file, revision=revision))

    with open(params_path, "r") as handle:
        params = json.load(handle)

    # params.json key -> PretrainedConfig attribute name.
    renames = {
        "dim": "hidden_size",
        "norm_eps": "rms_norm_eps",
        "n_kv_heads": "num_key_value_heads",
        "n_layers": "num_hidden_layers",
        "n_heads": "num_attention_heads",
        "hidden_dim": "intermediate_size",
    }

    def _to_config(node: Any):
        # Anything that is not a dict is kept as-is.
        if not isinstance(node, dict):
            return node
        # Rename keys at every nesting level and convert values recursively.
        converted = {
            renames.get(key, key): _to_config(value)
            for key, value in node.items()
        }
        return PretrainedConfig(**converted)

    # Fill in / derive the fields that params.json spells differently.
    params.setdefault("model_type", "transformer")
    params["hidden_act"] = params.get("activation", "silu")
    params["tie_word_embeddings"] = params.get("tie_embeddings", False)

    if params["model_type"] == "transformer":
        # A "moe" section selects the Mixtral architecture.
        arch = "MixtralForCausalLM" if "moe" in params else "MistralForCausalLM"
        params["architectures"] = [arch]

    return _to_config(params)
def
get_hf_image_processor_config
(
model
:
Union
[
str
,
Path
],
revision
:
Optional
[
str
]
=
None
,
...
...
@@ -134,7 +231,7 @@ def get_hf_image_processor_config(
def
get_hf_text_config
(
config
:
PretrainedConfig
):
"""Get the "sub" config relevant to llm for multi modal models.
No op for pure text models.
No op for pure text models.
"""
if
hasattr
(
config
,
"text_config"
):
# The code operates under the assumption that text_config should have
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment