Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e1a34c3a
Unverified
Commit
e1a34c3a
authored
Jan 20, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 20, 2026
Browse files
[2/N] Initialize MM components in context managers (E-H) (#32641)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
148117ea
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
163 additions
and
191 deletions
+163
-191
vllm/model_executor/models/aria.py
vllm/model_executor/models/aria.py
+0
-2
vllm/model_executor/models/aya_vision.py
vllm/model_executor/models/aya_vision.py
+0
-1
vllm/model_executor/models/cohere2_vision.py
vllm/model_executor/models/cohere2_vision.py
+0
-2
vllm/model_executor/models/ernie45_vl.py
vllm/model_executor/models/ernie45_vl.py
+21
-23
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+15
-18
vllm/model_executor/models/gemma3_mm.py
vllm/model_executor/models/gemma3_mm.py
+20
-23
vllm/model_executor/models/gemma3n_mm.py
vllm/model_executor/models/gemma3n_mm.py
+28
-30
vllm/model_executor/models/glm4_1v.py
vllm/model_executor/models/glm4_1v.py
+15
-16
vllm/model_executor/models/glmasr.py
vllm/model_executor/models/glmasr.py
+19
-21
vllm/model_executor/models/granite_speech.py
vllm/model_executor/models/granite_speech.py
+21
-22
vllm/model_executor/models/hunyuan_vision.py
vllm/model_executor/models/hunyuan_vision.py
+11
-15
vllm/model_executor/models/hyperclovax_vision.py
vllm/model_executor/models/hyperclovax_vision.py
+13
-18
No files found.
vllm/model_executor/models/aria.py
View file @
e1a34c3a
...
@@ -590,8 +590,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -590,8 +590,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
def
_process_image_input
(
def
_process_image_input
(
self
,
image_input
:
AriaImagePixelInputs
self
,
image_input
:
AriaImagePixelInputs
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
vision_tower
is
not
None
pixel_values
=
image_input
[
"pixel_values"
]
pixel_values
=
image_input
[
"pixel_values"
]
pixel_mask
=
image_input
[
"pixel_mask"
]
pixel_mask
=
image_input
[
"pixel_mask"
]
...
...
vllm/model_executor/models/aya_vision.py
View file @
e1a34c3a
...
@@ -382,7 +382,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
...
@@ -382,7 +382,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
def
_process_image_input
(
def
_process_image_input
(
self
,
image_input
:
AyaVisionImagePixelInputs
,
**
kwargs
self
,
image_input
:
AyaVisionImagePixelInputs
,
**
kwargs
)
->
list
[
torch
.
Tensor
]:
)
->
list
[
torch
.
Tensor
]:
assert
self
.
vision_tower
is
not
None
pixel_values
=
image_input
[
"pixel_values"
]
pixel_values
=
image_input
[
"pixel_values"
]
num_patches
=
image_input
[
"num_patches"
]
num_patches
=
image_input
[
"num_patches"
]
image_features
=
self
.
_image_pixels_to_features
(
image_features
=
self
.
_image_pixels_to_features
(
...
...
vllm/model_executor/models/cohere2_vision.py
View file @
e1a34c3a
...
@@ -391,8 +391,6 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo
...
@@ -391,8 +391,6 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo
Returns:
Returns:
List of flattened image embeddings, one per image
List of flattened image embeddings, one per image
"""
"""
assert
self
.
vision_tower
is
not
None
,
"Vision tower is required"
pixel_values
=
image_input
[
"pixel_values"
]
pixel_values
=
image_input
[
"pixel_values"
]
num_patches
=
image_input
[
"num_patches"
]
num_patches
=
image_input
[
"num_patches"
]
...
...
vllm/model_executor/models/ernie45_vl.py
View file @
e1a34c3a
...
@@ -1303,27 +1303,28 @@ class Ernie4_5_VLMoeForConditionalGeneration(
...
@@ -1303,27 +1303,28 @@ class Ernie4_5_VLMoeForConditionalGeneration(
self
.
config
=
config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
multimodal_config
=
multimodal_config
self
.
vision_model
=
Ernie4_5_VisionTransformer
(
with
self
.
_mark_tower_model
(
vllm_config
,
{
"image"
,
"video"
}):
config
.
vision_config
,
self
.
vision_model
=
Ernie4_5_VisionTransformer
(
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
config
.
vision_config
,
quant_config
=
quant_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
multimodal_config
=
multimodal_config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"vision_model"
),
multimodal_config
=
multimodal_config
,
)
prefix
=
maybe_prefix
(
prefix
,
"vision_model"
),
)
self
.
language_model
=
Ernie4_5_VLMoeForCausalLM
(
self
.
resampler_model
=
VariableResolutionResamplerModel
(
vllm_config
=
vllm_config
,
self
.
config
.
pixel_hidden_size
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
self
.
config
.
hidden_size
,
)
self
.
config
.
spatial_conv_size
,
self
.
config
.
temporal_conv_size
,
config
=
self
.
config
,
prefix
=
maybe_prefix
(
prefix
,
"resampler_model"
),
)
self
.
resampler_model
=
VariableResolutionResamplerModel
(
with
self
.
_mark_language_model
(
vllm_config
):
self
.
config
.
pixel_hidden_size
,
self
.
language_model
=
Ernie4_5_VLMoeForCausalLM
(
self
.
config
.
hidden_size
,
vllm_config
=
vllm_config
,
self
.
config
.
spatial_conv_size
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
self
.
config
.
temporal_conv_size
,
)
config
=
self
.
config
,
prefix
=
maybe_prefix
(
prefix
,
"resampler_model"
),
)
self
.
visual_token_mask
=
None
self
.
visual_token_mask
=
None
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
...
@@ -1522,9 +1523,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
...
@@ -1522,9 +1523,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
mrope_position_delta
=
(
llm_positions
.
max
()
+
1
-
len
(
input_tokens
)).
item
()
mrope_position_delta
=
(
llm_positions
.
max
()
+
1
-
len
(
input_tokens
)).
item
()
return
llm_positions
,
mrope_position_delta
return
llm_positions
,
mrope_position_delta
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
_parse_and_validate_image_input
(
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
self
,
**
kwargs
:
object
)
->
Ernie4_5_VLImageInputs
|
None
:
)
->
Ernie4_5_VLImageInputs
|
None
:
...
...
vllm/model_executor/models/fuyu.py
View file @
e1a34c3a
...
@@ -287,16 +287,20 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -287,16 +287,20 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self
.
image_token_id
=
_IMAGE_TOKEN_ID
self
.
image_token_id
=
_IMAGE_TOKEN_ID
self
.
image_feature_size
=
config
.
patch_size
**
2
*
config
.
num_channels
self
.
image_feature_size
=
config
.
patch_size
**
2
*
config
.
num_channels
self
.
vision_embed_tokens
=
ColumnParallelLinear
(
with
self
.
_mark_tower_model
(
vllm_config
,
"image"
):
self
.
image_feature_size
,
self
.
vision_embed_tokens
=
ColumnParallelLinear
(
config
.
hidden_size
,
self
.
image_feature_size
,
quant_config
=
quant_config
,
config
.
hidden_size
,
gather_output
=
True
,
quant_config
=
quant_config
,
)
gather_output
=
True
,
self
.
language_model
=
PersimmonForCausalLM
(
)
vllm_config
=
vllm_config
.
with_hf_config
(
config
.
text_config
),
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
with
self
.
_mark_language_model
(
vllm_config
):
)
self
.
language_model
=
PersimmonForCausalLM
(
vllm_config
=
vllm_config
.
with_hf_config
(
config
.
text_config
),
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
self
.
language_model
.
make_empty_intermediate_tensors
)
)
...
@@ -323,14 +327,10 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -323,14 +327,10 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_patches_flat
=
image_input
[
"image_patches_flat"
]
image_patches_flat
=
image_input
[
"image_patches_flat"
]
patches_per_image
=
image_input
[
"patches_per_image"
]
patches_per_image
=
image_input
[
"patches_per_image"
]
assert
self
.
vision_embed_tokens
is
not
None
vision_embeddings_flat
,
_
=
self
.
vision_embed_tokens
(
image_patches_flat
)
vision_embeddings_flat
,
_
=
self
.
vision_embed_tokens
(
image_patches_flat
)
return
vision_embeddings_flat
.
split
(
patches_per_image
.
tolist
(),
dim
=
0
)
return
vision_embeddings_flat
.
split
(
patches_per_image
.
tolist
(),
dim
=
0
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
...
@@ -361,10 +361,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -361,10 +361,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
None
:
)
->
torch
.
Tensor
|
None
:
logits
=
self
.
language_model
.
logits_processor
(
return
self
.
language_model
.
compute_logits
(
hidden_states
)
self
.
language_model
.
lm_head
,
hidden_states
)
return
logits
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
...
...
vllm/model_executor/models/gemma3_mm.py
View file @
e1a34c3a
...
@@ -522,25 +522,27 @@ class Gemma3ForConditionalGeneration(
...
@@ -522,25 +522,27 @@ class Gemma3ForConditionalGeneration(
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
multimodal_config
=
multimodal_config
self
.
multimodal_config
=
multimodal_config
self
.
vision_tower
=
SiglipVisionModel
(
with
self
.
_mark_tower_model
(
vllm_config
,
"image"
):
config
.
vision_config
,
self
.
vision_tower
=
SiglipVisionModel
(
quant_config
,
config
.
vision_config
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
quant_config
,
)
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
self
.
multi_modal_projector
=
Gemma3MultiModalProjector
(
config
)
)
self
.
multi_modal_projector
=
Gemma3MultiModalProjector
(
config
)
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
with
self
.
_mark_language_model
(
vllm_config
):
hf_config
=
config
.
text_config
,
self
.
language_model
=
init_vllm_registered_model
(
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
vllm_config
=
vllm_config
,
architectures
=
[
"Gemma3ForCausalLM"
],
hf_config
=
config
.
text_config
,
)
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
architectures
=
[
"Gemma3ForCausalLM"
],
)
if
hasattr
(
self
.
language_model
,
"logits_processor"
):
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
# The logits processor can be unset if we're using
if
hasattr
(
self
.
language_model
,
"logits_processor"
):
# automatic conversion to pooling model.
# The logits processor can be unset if we're using
self
.
language_model
.
logits_processor
.
scale
*=
logit_scale
# automatic conversion to pooling model.
self
.
language_model
.
logits_processor
.
scale
*=
logit_scale
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
self
.
language_model
.
make_empty_intermediate_tensors
...
@@ -579,8 +581,6 @@ class Gemma3ForConditionalGeneration(
...
@@ -579,8 +581,6 @@ class Gemma3ForConditionalGeneration(
self
,
self
,
image_input
:
Gemma3ImageInputs
,
image_input
:
Gemma3ImageInputs
,
)
->
list
[
torch
.
Tensor
]:
)
->
list
[
torch
.
Tensor
]:
assert
self
.
vision_tower
is
not
None
pixel_values
=
image_input
[
"pixel_values"
]
pixel_values
=
image_input
[
"pixel_values"
]
num_patches
=
image_input
[
"num_patches"
]
num_patches
=
image_input
[
"num_patches"
]
...
@@ -592,9 +592,6 @@ class Gemma3ForConditionalGeneration(
...
@@ -592,9 +592,6 @@ class Gemma3ForConditionalGeneration(
return
[
e
.
flatten
(
0
,
1
)
for
e
in
image_embeds
.
split
(
num_patches
.
tolist
())]
return
[
e
.
flatten
(
0
,
1
)
for
e
in
image_embeds
.
split
(
num_patches
.
tolist
())]
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
...
...
vllm/model_executor/models/gemma3n_mm.py
View file @
e1a34c3a
...
@@ -503,31 +503,35 @@ class Gemma3nForConditionalGeneration(
...
@@ -503,31 +503,35 @@ class Gemma3nForConditionalGeneration(
self
.
multimodal_config
=
multimodal_config
self
.
multimodal_config
=
multimodal_config
self
.
vocab_size
=
config
.
text_config
.
vocab_size
self
.
vocab_size
=
config
.
text_config
.
vocab_size
self
.
vision_tower
=
AutoModel
.
from_config
(
config
=
config
.
vision_config
)
with
self
.
_mark_tower_model
(
vllm_config
,
"image"
):
self
.
audio_tower
=
AutoModel
.
from_config
(
config
=
config
.
audio_config
)
self
.
vision_tower
=
AutoModel
.
from_config
(
config
=
config
.
vision_config
)
self
.
embed_vision
=
Gemma3nMultimodalEmbedder
(
self
.
embed_vision
=
Gemma3nMultimodalEmbedder
(
config
.
vision_config
,
config
.
text_config
config
.
vision_config
,
config
.
text_config
)
)
self
.
embed_audio
=
Gemma3nMultimodalEmbedder
(
config
.
audio_config
,
config
.
text_config
)
self
.
language_model
:
nn
.
Module
=
init_vllm_registered_model
(
with
self
.
_mark_tower_model
(
vllm_config
,
"audio"
):
vllm_config
=
vllm_config
,
self
.
audio_tower
=
AutoModel
.
from_config
(
config
=
config
.
audio_config
)
hf_config
=
config
.
text_config
,
self
.
embed_audio
=
Gemma3nMultimodalEmbedder
(
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
config
.
audio_config
,
config
.
text_config
architectures
=
[
"Gemma3nForCausalLM"
],
)
)
self
.
language_model
=
cast
(
Gemma3nForCausalLM
,
self
.
language_model
)
with
self
.
_mark_language_model
(
vllm_config
):
# NOTE (NickLucche) In order to be compatible with cudagraph, the
self
.
language_model
:
Gemma3nForCausalLM
=
init_vllm_registered_model
(
# buffer needs to be consistent, so we pre-allocate here.
vllm_config
=
vllm_config
,
self
.
per_layer_embeddings
=
torch
.
zeros
(
hf_config
=
config
.
text_config
,
vllm_config
.
scheduler_config
.
max_num_batched_tokens
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
self
.
config
.
text_config
.
num_hidden_layers
,
architectures
=
[
"Gemma3nForCausalLM"
],
self
.
config
.
text_config
.
hidden_size_per_layer_input
,
)
device
=
self
.
language_model
.
model
.
embed_tokens
.
weight
.
device
,
dtype
=
self
.
language_model
.
model
.
embed_tokens
.
weight
.
dtype
,
# NOTE (NickLucche) In order to be compatible with cudagraph, the
)
# buffer needs to be consistent, so we pre-allocate here.
self
.
per_layer_embeddings
=
torch
.
zeros
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
,
self
.
config
.
text_config
.
num_hidden_layers
,
self
.
config
.
text_config
.
hidden_size_per_layer_input
,
device
=
self
.
language_model
.
model
.
embed_tokens
.
weight
.
device
,
dtype
=
self
.
language_model
.
model
.
embed_tokens
.
weight
.
dtype
,
)
def
_parse_and_validate_image_input
(
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
self
,
**
kwargs
:
object
...
@@ -583,8 +587,6 @@ class Gemma3nForConditionalGeneration(
...
@@ -583,8 +587,6 @@ class Gemma3nForConditionalGeneration(
self
,
self
,
image_input
:
Gemma3nImageInputs
,
image_input
:
Gemma3nImageInputs
,
)
->
list
[
torch
.
Tensor
]:
)
->
list
[
torch
.
Tensor
]:
assert
self
.
vision_tower
is
not
None
pixel_values
=
image_input
[
"pixel_values"
]
pixel_values
=
image_input
[
"pixel_values"
]
vision_outputs
=
self
.
vision_tower
(
vision_outputs
=
self
.
vision_tower
(
pixel_values
=
pixel_values
,
do_pooling
=
False
,
return_dict
=
True
pixel_values
=
pixel_values
,
do_pooling
=
False
,
return_dict
=
True
...
@@ -609,7 +611,6 @@ class Gemma3nForConditionalGeneration(
...
@@ -609,7 +611,6 @@ class Gemma3nForConditionalGeneration(
self
,
self
,
audio_input
:
Gemma3nAudioInputs
,
audio_input
:
Gemma3nAudioInputs
,
)
->
list
[
torch
.
Tensor
]:
)
->
list
[
torch
.
Tensor
]:
assert
self
.
audio_tower
is
not
None
# Run on padded features to enable batching
# Run on padded features to enable batching
input_features
=
audio_input
[
"input_features_padded"
].
squeeze
(
1
)
input_features
=
audio_input
[
"input_features_padded"
].
squeeze
(
1
)
input_features_mask
=
audio_input
[
"input_features_mask"
].
squeeze
(
1
)
input_features_mask
=
audio_input
[
"input_features_mask"
].
squeeze
(
1
)
...
@@ -651,9 +652,6 @@ class Gemma3nForConditionalGeneration(
...
@@ -651,9 +652,6 @@ class Gemma3nForConditionalGeneration(
# Return a list of embeddings instead of a batched tensor
# Return a list of embeddings instead of a batched tensor
return
audio_features
.
unbind
(
0
)
return
audio_features
.
unbind
(
0
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
mm_input_by_modality
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
mm_input_by_modality
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
if
mm_input_by_modality
is
None
:
if
mm_input_by_modality
is
None
:
...
...
vllm/model_executor/models/glm4_1v.py
View file @
e1a34c3a
...
@@ -1434,13 +1434,14 @@ class Glm4vForConditionalGeneration(
...
@@ -1434,13 +1434,14 @@ class Glm4vForConditionalGeneration(
self
.
multimodal_config
=
multimodal_config
self
.
multimodal_config
=
multimodal_config
self
.
use_data_parallel
=
multimodal_config
.
mm_encoder_tp_mode
==
"data"
self
.
use_data_parallel
=
multimodal_config
.
mm_encoder_tp_mode
==
"data"
self
.
visual
=
Glm4vVisionTransformer
(
with
self
.
_mark_tower_model
(
vllm_config
,
{
"image"
,
"video"
}):
config
.
vision_config
,
self
.
visual
=
Glm4vVisionTransformer
(
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-5
),
config
.
vision_config
,
quant_config
=
quant_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-5
),
multimodal_config
=
multimodal_config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
multimodal_config
=
multimodal_config
,
)
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
)
if
config
.
model_type
==
"glm4v"
:
if
config
.
model_type
==
"glm4v"
:
architectures
=
[
"Glm4ForCausalLM"
]
architectures
=
[
"Glm4ForCausalLM"
]
...
@@ -1449,12 +1450,13 @@ class Glm4vForConditionalGeneration(
...
@@ -1449,12 +1450,13 @@ class Glm4vForConditionalGeneration(
else
:
else
:
architectures
=
None
architectures
=
None
self
.
language_model
=
init_vllm_registered_model
(
with
self
.
_mark_language_model
(
vllm_config
):
vllm_config
=
vllm_config
,
self
.
language_model
=
init_vllm_registered_model
(
hf_config
=
config
.
text_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
hf_config
=
config
.
text_config
,
architectures
=
architectures
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
architectures
=
architectures
,
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
self
.
language_model
.
make_empty_intermediate_tensors
...
@@ -1578,9 +1580,6 @@ class Glm4vForConditionalGeneration(
...
@@ -1578,9 +1580,6 @@ class Glm4vForConditionalGeneration(
)
)
return
mm_input_by_modality
return
mm_input_by_modality
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
|
None
:
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
|
None
:
mm_input_by_modality
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
mm_input_by_modality
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
if
not
mm_input_by_modality
:
if
not
mm_input_by_modality
:
...
...
vllm/model_executor/models/glmasr.py
View file @
e1a34c3a
...
@@ -944,26 +944,27 @@ class GlmAsrForConditionalGeneration(
...
@@ -944,26 +944,27 @@ class GlmAsrForConditionalGeneration(
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
config
=
config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
multimodal_config
=
multimodal_config
# Use optimized vLLM native encoder
self
.
audio_tower
=
GlmAsrEncoder
(
config
.
audio_config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"audio_tower"
),
)
self
.
multi_modal_projector
=
GlmAsrMultiModalProjector
(
config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"multi_modal_projector"
),
)
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
language_model
=
init_vllm_registered_model
(
with
self
.
_mark_tower_model
(
vllm_config
,
"audio"
):
vllm_config
=
vllm_config
,
self
.
audio_tower
=
GlmAsrEncoder
(
hf_config
=
config
.
text_config
,
config
.
audio_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
quant_config
=
quant_config
,
architectures
=
[
"LlamaForCausalLM"
],
prefix
=
maybe_prefix
(
prefix
,
"audio_tower"
),
)
)
self
.
multi_modal_projector
=
GlmAsrMultiModalProjector
(
config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"multi_modal_projector"
),
)
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
architectures
=
[
"LlamaForCausalLM"
],
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
self
.
language_model
.
make_empty_intermediate_tensors
...
@@ -1063,9 +1064,6 @@ class GlmAsrForConditionalGeneration(
...
@@ -1063,9 +1064,6 @@ class GlmAsrForConditionalGeneration(
)
)
return
_group_audio_embeddings
(
chunk_embeddings
,
chunk_counts
)
return
_group_audio_embeddings
(
chunk_embeddings
,
chunk_counts
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
audio_input
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
audio_input
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
if
audio_input
is
None
:
if
audio_input
is
None
:
...
...
vllm/model_executor/models/granite_speech.py
View file @
e1a34c3a
...
@@ -597,27 +597,29 @@ class GraniteSpeechForConditionalGeneration(
...
@@ -597,27 +597,29 @@ class GraniteSpeechForConditionalGeneration(
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
# The language model is typically a Granite LLM
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
# The language model is typically a Granite LLM
vllm_config
=
vllm_config
,
self
.
language_model
=
init_vllm_registered_model
(
hf_config
=
config
.
text_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
hf_config
=
config
.
text_config
,
)
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
# Conformer encoder
with
self
.
_mark_tower_model
(
vllm_config
,
"audio"
):
self
.
encoder
=
GraniteSpeechCTCEncoder
(
# Conformer encoder
config
=
config
.
encoder_config
,
self
.
encoder
=
GraniteSpeechCTCEncoder
(
quant_config
=
quant_config
,
config
=
config
.
encoder_config
,
prefix
=
f
"
{
prefix
}
.encoder"
,
quant_config
=
quant_config
,
)
prefix
=
f
"
{
prefix
}
.encoder"
,
)
# Blip2 QFormer
# Blip2 QFormer
self
.
projector
=
GraniteSpeechEncoderProjector
(
self
.
projector
=
GraniteSpeechEncoderProjector
(
config
=
config
,
config
=
config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
prefix
=
f
"
{
prefix
}
.projector"
,
prefix
=
f
"
{
prefix
}
.projector"
,
)
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
self
.
language_model
.
make_empty_intermediate_tensors
...
@@ -770,9 +772,6 @@ class GraniteSpeechForConditionalGeneration(
...
@@ -770,9 +772,6 @@ class GraniteSpeechForConditionalGeneration(
# Split variable length features into a tuple
# Split variable length features into a tuple
return
torch
.
split
(
masked_embeds
,
audio_input
[
"audio_embed_sizes"
])
return
torch
.
split
(
masked_embeds
,
audio_input
[
"audio_embed_sizes"
])
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
def
embed_multimodal
(
self
,
self
,
**
kwargs
:
object
,
**
kwargs
:
object
,
...
...
vllm/model_executor/models/hunyuan_vision.py
View file @
e1a34c3a
...
@@ -877,7 +877,7 @@ class HunYuanVLForConditionalGeneration(
...
@@ -877,7 +877,7 @@ class HunYuanVLForConditionalGeneration(
self
.
config
=
config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
multimodal_config
=
multimodal_config
if
multimodal_config
.
get_limit_per_prompt
(
"image"
):
with
self
.
_mark_tower_model
(
vllm_config
,
{
"image"
}
):
attn_backend_override
=
(
attn_backend_override
=
(
multimodal_config
.
mm_encoder_attn_backend
multimodal_config
.
mm_encoder_attn_backend
if
multimodal_config
is
not
None
if
multimodal_config
is
not
None
...
@@ -890,17 +890,16 @@ class HunYuanVLForConditionalGeneration(
...
@@ -890,17 +890,16 @@ class HunYuanVLForConditionalGeneration(
multimodal_config
=
multimodal_config
,
multimodal_config
=
multimodal_config
,
attn_backend_override
=
attn_backend_override
,
attn_backend_override
=
attn_backend_override
,
)
)
else
:
self
.
visual
=
None
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model.model"
),
prefix
=
maybe_prefix
(
prefix
,
"language_model.model"
),
architectures
=
[
architectures
=
[
"HunYuanDenseV1ForCausalLM"
,
"HunYuanDenseV1ForCausalLM"
,
"HunYuanMoEV1ForCausalLM"
,
"HunYuanMoEV1ForCausalLM"
,
],
],
)
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
self
.
language_model
.
make_empty_intermediate_tensors
...
@@ -970,9 +969,6 @@ class HunYuanVLForConditionalGeneration(
...
@@ -970,9 +969,6 @@ class HunYuanVLForConditionalGeneration(
)
)
return
mm_input_by_modality
return
mm_input_by_modality
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
mm_input_by_modality
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
mm_input_by_modality
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
if
not
mm_input_by_modality
:
if
not
mm_input_by_modality
:
...
...
vllm/model_executor/models/hyperclovax_vision.py
View file @
e1a34c3a
...
@@ -15,7 +15,6 @@ from einops import rearrange
...
@@ -15,7 +15,6 @@ from einops import rearrange
from
timm.layers
import
LayerNorm
,
LayerNorm2d
from
timm.layers
import
LayerNorm
,
LayerNorm2d
from
timm.models.regnet
import
RegStage
from
timm.models.regnet
import
RegStage
from
transformers
import
BatchFeature
,
CLIPVisionConfig
,
SiglipVisionConfig
from
transformers
import
BatchFeature
,
CLIPVisionConfig
,
SiglipVisionConfig
from
transformers.modeling_utils
import
no_init_weights
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
,
MultiModalConfig
from
vllm.config.multimodal
import
BaseDummyOptions
,
MultiModalConfig
...
@@ -625,8 +624,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -625,8 +624,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
config
,
vision_config
config
,
vision_config
)
)
# init models & parameters
with
self
.
_mark_tower_model
(
vllm_config
,
{
"image"
,
"video"
}):
with
no_init_weights
():
# weight will be loaded in from_pretrained
self
.
vision_model
=
init_vision_tower_for_hcxvision
(
self
.
vision_model
=
init_vision_tower_for_hcxvision
(
vision_config
,
vision_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
...
@@ -635,20 +633,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -635,20 +633,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
require_post_norm
=
False
,
require_post_norm
=
False
,
prefix
=
maybe_prefix
(
prefix
,
"vision_model"
),
prefix
=
maybe_prefix
(
prefix
,
"vision_model"
),
)
)
self
.
mm_projector
=
self
.
_init_mm_projector
(
config
,
text_config
,
vision_config
)
self
.
mm_projector
=
self
.
_init_mm_projector
(
config
,
text_config
,
vision_config
)
self
.
lm_head_vocab_size
=
getattr
(
if
config
.
anyres
:
text_config
,
"padded_vocab_size"
,
text_config
.
vocab_size
self
.
image_newline
=
nn
.
Parameter
(
)
torch
.
empty
(
text_config
.
hidden_size
,
dtype
=
self
.
dtype
)
self
.
language_model
=
init_vllm_registered_model
(
)
vllm_config
=
vllm_config
,
hf_config
=
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
if
config
.
anyres
:
with
self
.
_mark_language_model
(
vllm_config
):
self
.
image_newline
=
nn
.
Parameter
(
self
.
language_model
=
init_vllm_registered_model
(
torch
.
empty
(
text_config
.
hidden_size
,
dtype
=
self
.
dtype
)
vllm_config
=
vllm_config
,
hf_config
=
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
)
self
.
config
=
config
self
.
config
=
config
...
@@ -726,9 +724,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -726,9 +724,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return
modalities
return
modalities
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
def
embed_multimodal
(
self
,
self
,
**
kwargs
:
object
,
**
kwargs
:
object
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment