Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bc192a2b
Unverified
Commit
bc192a2b
authored
Dec 10, 2024
by
Patrick von Platen
Committed by
GitHub
Dec 10, 2024
Browse files
[Pixtral] Improve loading (#11040)
parent
980ad394
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
31 deletions
+25
-31
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+25
-31
No files found.
vllm/model_executor/models/pixtral.py
View file @
bc192a2b
from
dataclasses
import
dataclass
,
fields
from
dataclasses
import
dataclass
,
fields
from
functools
import
cached_property
from
functools
import
cached_property
from
itertools
import
tee
from
typing
import
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Union
import
numpy
import
numpy
...
@@ -359,38 +358,33 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -359,38 +358,33 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
def
is_vision_lang_adapter_weights
(
weight
:
Tuple
[
str
,
torch
.
Tensor
]):
def
is_vision_lang_adapter_weights
(
weight
:
Tuple
[
str
,
torch
.
Tensor
]):
return
weight
[
0
].
startswith
(
"vision_language_adapter"
)
return
weight
[
0
].
startswith
(
"vision_language_adapter"
)
def
is_vision_weights
(
weight
:
Tuple
[
str
,
torch
.
Tensor
]):
# Get references to parameters for direct loading
return
is_vision_encoder_weights
(
weight
)
or
is_vision_lang_adapter_weights
(
weight
)
llm_weights
,
vision_encoder_weights
,
vision_lang_adapter_weights
=
tee
(
weights
,
3
)
# llm
llm_weights
=
filter
(
lambda
x
:
not
is_vision_weights
(
x
),
llm_weights
)
self
.
language_model
.
load_weights
(
llm_weights
)
# vision encoder
vision_encoder_weights
=
filter
(
is_vision_encoder_weights
,
vision_encoder_weights
)
vision_encoder_dict
=
dict
(
self
.
vision_encoder
.
named_parameters
())
vision_encoder_dict
=
dict
(
self
.
vision_encoder
.
named_parameters
())
for
name
,
loaded_weight
in
vision_encoder_weights
:
vision_lang_adapter_dict
=
dict
(
# cut 'vision_encoder.'
self
.
vision_language_adapter
.
named_parameters
())
name
=
'.'
.
join
(
name
.
split
(
"."
)[
1
:])
param
=
vision_encoder_dict
[
name
]
default_weight_loader
(
param
,
loaded_weight
)
def
llm_weights_generator
():
# Single pass over weights
for
name
,
w
in
weights
:
if
is_vision_encoder_weights
((
name
,
w
)):
# Load vision encoder weights directly
trimmed_name
=
'.'
.
join
(
name
.
split
(
"."
)[
1
:])
param
=
vision_encoder_dict
[
trimmed_name
]
with
torch
.
no_grad
():
default_weight_loader
(
param
,
w
)
elif
is_vision_lang_adapter_weights
((
name
,
w
)):
# Load vision-language adapter weights directly
trimmed_name
=
'.'
.
join
(
name
.
split
(
"."
)[
1
:])
param
=
vision_lang_adapter_dict
[
trimmed_name
]
with
torch
.
no_grad
():
default_weight_loader
(
param
,
w
)
else
:
# LLM weights: yield them to be loaded
# by language_model.load_weights
yield
(
name
,
w
)
# adapter
# Now we call the language model load with the generator
vision_lang_adapter_weights
=
filter
(
is_vision_lang_adapter_weights
,
self
.
language_model
.
load_weights
(
llm_weights_generator
())
vision_lang_adapter_weights
)
vision_lang_adpter_dict
=
dict
(
self
.
vision_language_adapter
.
named_parameters
())
for
name
,
loaded_weight
in
vision_lang_adapter_weights
:
# cut 'vision_language_adapter.'
name
=
'.'
.
join
(
name
.
split
(
"."
)[
1
:])
param
=
vision_lang_adpter_dict
[
name
]
default_weight_loader
(
param
,
loaded_weight
)
# Vision encoder
# Vision encoder
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment