Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
36bf8150
Unverified
Commit
36bf8150
authored
Sep 08, 2024
by
Isotr0py
Committed by
GitHub
Sep 07, 2024
Browse files
[Model][VLM] Decouple weight loading logic for `Paligemma` (#8269)
parent
e8071259
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
81 deletions
+54
-81
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+35
-77
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+19
-4
No files found.
vllm/model_executor/models/paligemma.py
View file @
36bf8150
import
itertools
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
...
...
@@ -13,7 +14,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.gemma
import
Gemma
Model
from
vllm.model_executor.models.gemma
import
Gemma
ForCausalLM
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.utils
import
cached_get_tokenizer
...
...
@@ -22,14 +23,10 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsMultiModal
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
)
from
.utils
import
merge_multimodal_embeddings
from
.utils
import
filter_weights
,
merge_multimodal_embeddings
logger
=
init_logger
(
__name__
)
_KEYS_TO_MODIFY_MAPPING
=
{
"language_model.model"
:
"language_model"
,
}
class
PaliGemmaImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
...
...
@@ -151,8 +148,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
projection_dim
=
config
.
vision_config
.
projection_dim
)
self
.
quant_config
=
quant_config
self
.
language_model
=
Gemma
Model
(
config
.
text_config
,
cache_config
,
quant_config
)
self
.
language_model
=
Gemma
ForCausalLM
(
config
.
text_config
,
cache_config
,
quant_config
)
self
.
unpadded_vocab_size
=
config
.
text_config
.
vocab_size
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
...
...
@@ -252,7 +249,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
vision_embeddings
=
vision_embeddings
*
(
self
.
config
.
hidden_size
**
-
0.5
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
...
...
@@ -262,87 +260,47 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
else
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
None
,
inputs_embeds
=
inputs_embeds
)
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
None
,
inputs_embeds
=
inputs_embeds
)
return
hidden_states
# Copied from vllm/model_executor/models/gemma.py
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
language_model
.
embed_tokens
,
hidden_states
,
sampling_metadata
)
return
logits
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
# Copied from vllm/model_executor/models/gemma.py
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
# Adapted from vllm/model_executor/models/gemma.py
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
=
set
()
for
name
,
loaded_weight
in
weights
:
for
key_to_modify
,
new_key
in
_KEYS_TO_MODIFY_MAPPING
.
items
():
if
key_to_modify
in
name
:
name
=
name
.
replace
(
key_to_modify
,
new_key
)
use_default_weight_loading
=
False
if
"vision"
not
in
name
or
self
.
vision_tower
.
shard_weight
:
for
(
param_name
,
shard_name
,
shard_id
)
in
stacked_params_mapping
:
if
shard_name
not
in
name
:
continue
name
=
name
.
replace
(
shard_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# lm_head is not used in vllm as it is tied with
# embed_token. To prevent errors, skip loading
# lm_head.weight.
if
"lm_head.weight"
in
name
:
continue
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
use_default_weight_loading
=
True
else
:
use_default_weight_loading
=
True
if
use_default_weight_loading
:
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
unloaded_params
=
params_dict
.
keys
()
-
loaded_params
if
unloaded_params
:
logger
.
warning
(
"Some weights are not initialized from checkpoints: %s"
,
unloaded_params
)
# prepare weight iterators for components
vit_weights
,
mlp_weights
,
llm_weights
=
itertools
.
tee
(
weights
,
3
)
# load vision tower
vit_weights
=
filter_weights
(
vit_weights
,
"vision_tower"
)
self
.
vision_tower
.
load_weights
(
vit_weights
)
# load mlp projector
mlp_weights
=
filter_weights
(
mlp_weights
,
"multi_modal_projector"
)
mlp_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
for
name
,
loaded_weight
in
mlp_weights
:
param
=
mlp_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
llm_weights
=
filter_weights
(
llm_weights
,
"language_model"
)
self
.
language_model
.
load_weights
(
llm_weights
)
vllm/model_executor/models/siglip.py
View file @
36bf8150
...
...
@@ -529,6 +529,12 @@ class SiglipVisionModel(nn.Module):
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
if
self
.
shard_weight
else
[]
params_dict
=
dict
(
self
.
named_parameters
())
layer_count
=
len
(
self
.
vision_model
.
encoder
.
layers
)
...
...
@@ -544,7 +550,16 @@ class SiglipVisionModel(nn.Module):
if
layer_idx
>=
layer_count
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment