Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bf0e382e
Unverified
Commit
bf0e382e
authored
Dec 07, 2024
by
Cyrus Leung
Committed by
GitHub
Dec 07, 2024
Browse files
[Model] Composite weight loading for multimodal Qwen2 (#10944)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
b26b4cd0
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
147 additions
and
205 deletions
+147
-205
vllm/config.py
vllm/config.py
+9
-1
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+1
-3
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+3
-7
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+10
-7
vllm/model_executor/models/qwen2_audio.py
vllm/model_executor/models/qwen2_audio.py
+32
-85
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+83
-96
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+9
-6
No files found.
vllm/config.py
View file @
bf0e382e
...
...
@@ -2472,7 +2472,15 @@ class VllmConfig:
return
quant_config
return
None
def
with_hf_config
(
self
,
hf_config
:
PretrainedConfig
)
->
"VllmConfig"
:
def
with_hf_config
(
self
,
hf_config
:
PretrainedConfig
,
architectures
:
Optional
[
list
[
str
]]
=
None
,
)
->
"VllmConfig"
:
if
architectures
is
not
None
:
hf_config
=
copy
.
deepcopy
(
hf_config
)
hf_config
.
architectures
=
architectures
model_config
=
copy
.
deepcopy
(
self
.
model_config
)
model_config
.
hf_config
=
hf_config
...
...
vllm/model_executor/model_loader/loader.py
View file @
bf0e382e
...
...
@@ -101,12 +101,10 @@ def _initialize_model(
vllm_config
:
VllmConfig
,
*
,
prefix
:
str
=
""
,
architectures
:
Optional
[
list
[
str
]]
=
None
,
)
->
nn
.
Module
:
"""Initialize a model with the given configurations."""
model_config
=
vllm_config
.
model_config
model_class
,
_
=
get_model_architecture
(
model_config
,
architectures
=
architectures
)
model_class
,
_
=
get_model_architecture
(
model_config
)
signatures
=
inspect
.
signature
(
model_class
.
__init__
)
all_params
=
[
param
.
name
for
param
in
signatures
.
parameters
.
values
()]
...
...
vllm/model_executor/model_loader/utils.py
View file @
bf0e382e
"""Utilities for selecting and loading models."""
import
contextlib
from
typing
import
Optional
,
Tuple
,
Type
from
typing
import
Tuple
,
Type
import
torch
from
torch
import
nn
...
...
@@ -20,11 +20,7 @@ def set_default_torch_dtype(dtype: torch.dtype):
def
get_model_architecture
(
model_config
:
ModelConfig
,
*
,
architectures
:
Optional
[
list
[
str
]]
=
None
,
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
if
architectures
is
None
:
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
# Special handling for quantized Mixtral.
...
...
vllm/model_executor/models/qwen2.py
View file @
bf0e382e
...
...
@@ -444,6 +444,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
model
=
Qwen2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
if
get_pp_group
().
is_last_rank
:
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
else
:
...
...
@@ -452,6 +453,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
))
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
...
...
vllm/model_executor/models/qwen2_audio.py
View file @
bf0e382e
...
...
@@ -19,7 +19,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from
functools
import
lru_cache
from
functools
import
cached_property
,
lru_cache
from
typing
import
(
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
...
...
@@ -34,12 +34,7 @@ from vllm.config import VllmConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
,
token_inputs
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.models.qwen2
import
Qwen2Model
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
NestedTensors
...
...
@@ -47,15 +42,11 @@ from vllm.multimodal.utils import consecutive_placeholder_ranges
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
merge_multimodal_embeddings
from
.utils
import
(
AutoWeightsLoader
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
logger
=
init_logger
(
__name__
)
_KEYS_TO_MODIFY_MAPPING
=
{
"language_model.lm_head"
:
"lm_head"
,
"language_model.model"
:
"language_model"
,
}
# # === Audio Inputs === #
class
Qwen2AudioInputs
(
TypedDict
):
...
...
@@ -281,25 +272,23 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
quant_config
=
quant_config
self
.
language_model
=
Qwen2Model
(
vllm_config
=
vllm_config
.
with_hf_config
(
config
.
text_config
),
prefix
=
prefix
)
self
.
unpadded_vocab_size
=
config
.
text_config
.
vocab_size
if
config
.
text_config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
language_model
.
embed_tokens
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
text_config
.
vocab_size
,
config
.
text_config
.
hidden_size
,
quant_config
=
quant_config
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
text_config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
get_sampler
()
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
architectures
=
[
"Qwen2ForCausalLM"
],
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
@
cached_property
def
sampler
(
self
):
if
hasattr
(
self
.
language_model
,
"sampler"
):
return
self
.
language_model
.
sampler
return
get_sampler
()
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]],
...
...
@@ -414,7 +403,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings
)
input_ids
=
None
hidden_states
=
self
.
language_model
(
input_ids
,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
...
...
@@ -422,64 +411,22 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds
=
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
self
.
config
.
text_config
.
tie_word_embeddings
and
"lm_head.weight"
in
name
):
continue
for
key_to_modify
,
new_key
in
_KEYS_TO_MODIFY_MAPPING
.
items
():
if
key_to_modify
in
name
:
name
=
name
.
replace
(
key_to_modify
,
new_key
)
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
or
'audio'
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/qwen2_vl.py
View file @
bf0e382e
...
...
@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from
functools
import
partial
from
functools
import
cached_property
,
partial
from
typing
import
(
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
TypedDict
,
Union
)
...
...
@@ -40,7 +40,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_pp_group
,
parallel_state
from
vllm.distributed
import
parallel_state
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
,
token_inputs
)
...
...
@@ -49,15 +49,12 @@ from vllm.model_executor import SamplingMetadata
from
vllm.model_executor.layers.activation
import
QuickGELU
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.qwen2
import
Qwen2Model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
cached_get_image_processor
from
vllm.multimodal.inputs
import
(
MultiModalData
,
MultiModalDataDict
,
...
...
@@ -69,9 +66,8 @@ from vllm.transformers_utils.config import uses_mrope
from
vllm.transformers_utils.processor
import
cached_get_processor
from
.interfaces
import
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
PPMissingLayer
,
get_vit_attn_backend
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
maybe_prefix
)
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
get_vit_attn_backend
,
init_vllm_registered_model
,
maybe_prefix
)
logger
=
init_logger
(
__name__
)
...
...
@@ -506,6 +502,8 @@ class Qwen2VisionTransformer(nn.Module):
mlp_ratio
:
float
=
vision_config
.
mlp_ratio
self
.
spatial_merge_size
=
spatial_merge_size
self
.
num_heads
=
num_heads
self
.
embed_dim
=
embed_dim
self
.
patch_embed
=
Qwen2VisionPatchEmbed
(
patch_size
=
patch_size
,
...
...
@@ -595,6 +593,53 @@ class Qwen2VisionTransformer(nn.Module):
x
=
self
.
merger
(
x
)
return
x
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
if
name
.
endswith
(
"qkv.weight"
):
visual_num_heads
=
self
.
num_heads
visual_embed_dim
=
self
.
embed_dim
head_size
=
visual_embed_dim
//
visual_num_heads
loaded_weight
=
loaded_weight
.
view
(
3
,
visual_num_heads
,
head_size
,
visual_embed_dim
)
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
reshape
(
-
1
,
visual_embed_dim
)
elif
name
.
endswith
(
"qkv.bias"
):
visual_num_heads
=
self
.
num_heads
visual_embed_dim
=
self
.
embed_dim
head_size
=
visual_embed_dim
//
visual_num_heads
loaded_weight
=
loaded_weight
.
view
(
3
,
visual_num_heads
,
head_size
)
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
reshape
(
-
1
)
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
# === Vision input helpers === #
...
...
@@ -1082,27 +1127,21 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
)
self
.
model
=
Qwen2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
architectures
=
[
"Qwen2ForCausalLM"
],
)
if
get_pp_group
().
is_last_rank
:
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
))
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
@
cached_property
def
sampler
(
self
):
if
hasattr
(
self
.
language_model
,
"sampler"
):
return
self
.
language_model
.
sampler
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
return
get_sampler
()
def
_maybe_ignore_quant_config
(
self
,
quant_config
:
QuantizationConfig
):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
...
...
@@ -1261,7 +1300,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings
:
Optional
[
List
[
Tuple
[
NestedTensors
,
str
]]]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
language_
model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
for
embeddings
,
modality
in
multimodal_embeddings
:
if
modality
==
"image"
:
...
...
@@ -1330,7 +1369,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings
)
input_ids
=
None
hidden_states
=
self
.
model
(
hidden_states
=
self
.
language_model
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
kv_caches
=
kv_caches
,
...
...
@@ -1340,80 +1379,28 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
self
.
config
.
tie_word_embeddings
and
"lm_head.weight"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
if
"visual"
in
name
and
name
.
endswith
(
"qkv.weight"
):
visual_num_heads
=
self
.
config
.
vision_config
.
num_heads
visual_embed_dim
=
self
.
config
.
vision_config
.
embed_dim
head_size
=
visual_embed_dim
//
visual_num_heads
loaded_weight
=
loaded_weight
.
view
(
3
,
visual_num_heads
,
head_size
,
visual_embed_dim
)
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
reshape
(
-
1
,
visual_embed_dim
)
elif
"visual"
in
name
and
name
.
endswith
(
"qkv.bias"
):
visual_num_heads
=
self
.
config
.
vision_config
.
num_heads
visual_embed_dim
=
self
.
config
.
vision_config
.
embed_dim
head_size
=
visual_embed_dim
//
visual_num_heads
loaded_weight
=
loaded_weight
.
view
(
3
,
visual_num_heads
,
head_size
)
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
reshape
(
-
1
)
try
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
except
KeyError
:
raise
ValueError
(
f
"Unexpected weight:
{
name
}
"
)
from
None
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"lm_head."
:
"language_model.lm_head."
,
"model."
:
"language_model.model."
,
})
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
,
mapper
=
hf_to_vllm_mapper
)
vllm/model_executor/models/utils.py
View file @
bf0e382e
...
...
@@ -17,7 +17,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.multimodal
import
MultiModalPlaceholderMap
,
NestedTensors
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
,
print_warning_once
logger
=
init_logger
(
__name__
)
...
...
@@ -251,13 +251,16 @@ def init_vllm_registered_model(
"""
from
vllm.model_executor.model_loader.loader
import
_initialize_model
if
hf_config
is
not
None
:
vllm_config
=
vllm_config
.
with_hf_config
(
hf_config
)
if
hf_config
is
None
and
architectures
is
not
None
:
# So that the architectures field is overridden
hf_config
=
vllm_config
.
model_config
.
hf_config
return
_initialize_model
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
if
hf_config
is
not
None
:
vllm_config
=
vllm_config
.
with_hf_config
(
hf_config
,
architectures
=
architectures
)
return
_initialize_model
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
@
overload
def
flatten_bn
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -592,7 +595,7 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
if
is_flash_attn_2_available
():
selected_backend
=
_Backend
.
FLASH_ATTN
else
:
logger
.
warning
(
print_
warning
_once
(
"Current `vllm-flash-attn` has a bug inside vision module, "
"so we use xformers backend instead. You can run "
"`pip install flash-attn` to use flash-attention backend."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment