Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
452a7c9f
Unverified
Commit
452a7c9f
authored
Nov 26, 2025
by
Roger Wang
Committed by
GitHub
Nov 26, 2025
Browse files
[Misc] Allow LM only loading for Pixtral (#29451)
Signed-off-by:
Roger Wang
<
hey@rogerw.io
>
parent
d9d342d2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
51 additions
and
22 deletions
+51
-22
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+51
-22
No files found.
vllm/model_executor/models/pixtral.py
View file @
452a7c9f
...
...
@@ -400,21 +400,30 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
vision_encoder
=
VisionTransformer
(
self
.
vision_args
)
if
self
.
vision_args
.
add_pre_mm_projector_layer_norm
:
self
.
pre_mm_projector_norm
=
RMSNorm
(
self
.
vision_args
.
hidden_size
,
eps
=
1e-5
)
if
self
.
vision_args
.
mm_projector_id
==
PATCH_MERGE
:
self
.
patch_merger
=
PatchMerger
(
vision_encoder_dim
=
self
.
vision_args
.
hidden_size
,
spatial_merge_size
=
self
.
vision_args
.
spatial_merge_size
,
use_mlp_bias
=
False
,
if
multimodal_config
.
get_limit_per_prompt
(
"image"
):
self
.
vision_encoder
=
VisionTransformer
(
self
.
vision_args
)
self
.
pre_mm_projector_norm
=
(
RMSNorm
(
self
.
vision_args
.
hidden_size
,
eps
=
1e-5
)
if
self
.
vision_args
.
add_pre_mm_projector_layer_norm
else
None
)
self
.
vision_language_adapter
=
VisionLanguageAdapter
(
self
.
vision_args
,
dim
=
config
.
text_config
.
hidden_size
)
self
.
patch_merger
=
(
PatchMerger
(
vision_encoder_dim
=
self
.
vision_args
.
hidden_size
,
spatial_merge_size
=
self
.
vision_args
.
spatial_merge_size
,
use_mlp_bias
=
False
,
)
if
self
.
vision_args
.
mm_projector_id
==
PATCH_MERGE
else
None
)
self
.
vision_language_adapter
=
VisionLanguageAdapter
(
self
.
vision_args
,
dim
=
config
.
text_config
.
hidden_size
)
else
:
self
.
vision_encoder
=
None
self
.
pre_mm_projector_norm
=
None
self
.
patch_merger
=
None
self
.
vision_language_adapter
=
None
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
...
...
@@ -436,13 +445,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
self
,
image_input
:
PixtralImagePixelInputs
,
)
->
tuple
[
torch
.
Tensor
,
...]:
assert
(
self
.
vision_encoder
is
not
None
and
self
.
vision_language_adapter
is
not
None
)
images
=
image_input
[
"images"
]
image_features
=
self
.
vision_encoder
(
images
)
feature_sizes
=
[
image_feature
.
shape
[
0
]
for
image_feature
in
image_features
]
image_features
=
torch
.
cat
(
image_features
)
if
self
.
vision_args
.
add_
pre_mm_projector_
layer_norm
:
if
self
.
pre_mm_projector_
norm
is
not
None
:
image_features
=
self
.
pre_mm_projector_norm
(
image_features
)
if
self
.
vision_args
.
mm_projector_id
==
PATCH_MERGE
:
if
self
.
patch_merger
is
not
None
:
patch_size
=
self
.
vision_args
.
patch_size
spatial_merge_size_square
=
self
.
vision_args
.
spatial_merge_size
**
2
img_patch_dims
=
[
...
...
@@ -508,41 +521,57 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
return
weight
[
0
].
startswith
(
"pre_mm_projector_norm"
)
# Get references to parameters for direct loading
vision_encoder_dict
=
dict
(
self
.
vision_encoder
.
named_parameters
())
vision_encoder_dict
=
(
dict
(
self
.
vision_encoder
.
named_parameters
())
if
self
.
vision_encoder
is
not
None
else
{}
)
patch_merger_dict
=
(
dict
(
self
.
patch_merger
.
named_parameters
())
if
self
.
vision_args
.
mm_projector_id
==
PATCH_MERGE
else
dict
()
if
self
.
patch_merger
is
not
None
else
{}
)
pre_mm_projector_norm_dict
=
(
dict
(
self
.
pre_mm_projector_norm
.
named_parameters
())
if
self
.
vision_args
.
add_pre_mm_projector_layer_norm
else
dict
()
if
self
.
pre_mm_projector_norm
is
not
None
else
{}
)
vision_lang_adapter_dict
=
(
dict
(
self
.
vision_language_adapter
.
named_parameters
())
if
self
.
vision_language_adapter
is
not
None
else
{}
)
vision_lang_adapter_dict
=
dict
(
self
.
vision_language_adapter
.
named_parameters
())
def
llm_weights_generator
():
# Single pass over weights
for
name
,
w
in
weights
:
if
is_vision_encoder_weights
((
name
,
w
)):
if
self
.
vision_encoder
is
None
:
continue
# Load vision encoder weights directly
trimmed_name
=
"."
.
join
(
name
.
split
(
"."
)[
1
:])
param
=
vision_encoder_dict
[
trimmed_name
]
with
torch
.
no_grad
():
default_weight_loader
(
param
,
w
)
elif
is_patch_merger
((
name
,
w
)):
if
self
.
patch_merger
is
None
:
continue
# Load vision patch merger weights directly
trimmed_name
=
"."
.
join
(
name
.
split
(
"."
)[
1
:])
param
=
patch_merger_dict
[
trimmed_name
]
with
torch
.
no_grad
():
default_weight_loader
(
param
,
w
)
elif
is_pre_mm_projector_norm
((
name
,
w
)):
if
self
.
pre_mm_projector_norm
is
None
:
continue
# Load vision pre_mm_projector_norm weights directly
trimmed_name
=
"."
.
join
(
name
.
split
(
"."
)[
1
:])
param
=
pre_mm_projector_norm_dict
[
trimmed_name
]
with
torch
.
no_grad
():
default_weight_loader
(
param
,
w
)
elif
is_vision_lang_adapter_weights
((
name
,
w
)):
if
self
.
vision_language_adapter
is
None
:
continue
# Load vision-language adapter weights directly
trimmed_name
=
"."
.
join
(
name
.
split
(
"."
)[
1
:])
param
=
vision_lang_adapter_dict
[
trimmed_name
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment