Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
193069d1
Unverified
Commit
193069d1
authored
Jan 21, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 20, 2026
Browse files
[5/N] Initialize MM components in context managers (Q-Z) (#32695)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
f0feb1cf
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
177 additions
and
167 deletions
+177
-167
vllm/model_executor/models/qwen2_audio.py
vllm/model_executor/models/qwen2_audio.py
+13
-15
vllm/model_executor/models/qwen3_omni_moe_thinker.py
vllm/model_executor/models/qwen3_omni_moe_thinker.py
+52
-45
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+14
-6
vllm/model_executor/models/radio.py
vllm/model_executor/models/radio.py
+0
-1
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+14
-15
vllm/model_executor/models/skyworkr1v.py
vllm/model_executor/models/skyworkr1v.py
+18
-21
vllm/model_executor/models/tarsier.py
vllm/model_executor/models/tarsier.py
+37
-36
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+14
-13
vllm/model_executor/models/voxtral.py
vllm/model_executor/models/voxtral.py
+15
-15
No files found.
vllm/model_executor/models/qwen2_audio.py
View file @
193069d1
...
...
@@ -334,20 +334,21 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
audio_tower
=
Qwen2AudioEncoder
(
config
.
audio_config
)
self
.
multi_modal_projector
=
Qwen2AudioMultiModalProjector
(
config
.
audio_config
.
d_model
,
config
.
text_config
.
hidden_size
)
self
.
quant_config
=
quant_config
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
architectures
=
[
"Qwen2ForCausalLM"
],
)
with
self
.
_mark_tower_model
(
vllm_config
,
"audio"
):
self
.
audio_tower
=
Qwen2AudioEncoder
(
config
.
audio_config
)
self
.
multi_modal_projector
=
Qwen2AudioMultiModalProjector
(
config
.
audio_config
.
d_model
,
config
.
text_config
.
hidden_size
)
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
architectures
=
[
"Qwen2ForCausalLM"
],
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
...
...
@@ -441,9 +442,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
masked_audio_features
,
audio_output_lengths
.
flatten
().
tolist
()
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
audio_input
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
if
audio_input
is
None
:
...
...
vllm/model_executor/models/qwen3_omni_moe_thinker.py
View file @
193069d1
...
...
@@ -1612,32 +1612,14 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
config
=
thinker_config
self
.
multimodal_config
=
multimodal_config
self
.
audio_tower
=
Qwen3OmniMoeAudioEncoder
(
thinker_config
.
audio_config
,
multimodal_config
=
multimodal_config
,
prefix
=
maybe_prefix
(
prefix
,
"audio_tower"
),
)
self
.
visual
=
Qwen3Omni_VisionTransformer
(
vision_config
=
thinker_config
.
vision_config
,
norm_eps
=
getattr
(
thinker_config
.
text_config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
multimodal_config
=
multimodal_config
,
)
self
.
quant_config
=
quant_config
self
.
language_model
=
Qwen3MoeLLMForCausalLM
(
vllm_config
=
vllm_config
.
with_hf_config
(
thinker_config
.
text_config
,
architectures
=
[
"Qwen3MoeForCausalLM"
]
),
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
with
self
.
_mark_tower_model
(
vllm_config
,
"audio"
):
self
.
audio_tower
=
Qwen3OmniMoeAudioEncoder
(
thinker_config
.
audio_config
,
multimodal_config
=
multimodal_config
,
prefix
=
maybe_prefix
(
prefix
,
"audio_tower"
),
)
self
.
use_deepstack
=
hasattr
(
thinker_config
.
vision_config
,
"deepstack_visual_indexes"
...
...
@@ -1647,22 +1629,48 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
if
self
.
use_deepstack
else
0
)
# register buffer for deepstack
self
.
deepstack_input_embeds
=
(
[
torch
.
zeros
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
,
thinker_config
.
text_config
.
hidden_size
,
)
for
_
in
range
(
self
.
deepstack_num_level
)
]
if
self
.
use_deepstack
else
None
)
self
.
visual_dim
=
thinker_config
.
vision_config
.
out_hidden_size
self
.
multiscale_dim
=
self
.
visual_dim
*
self
.
deepstack_num_level
def
_get_deepstack_input_embeds
(
self
,
num_tokens
:
int
)
->
IntermediateTensors
:
with
self
.
_mark_tower_model
(
vllm_config
,
{
"image"
,
"video"
}):
self
.
visual
=
Qwen3Omni_VisionTransformer
(
vision_config
=
thinker_config
.
vision_config
,
norm_eps
=
getattr
(
thinker_config
.
text_config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
multimodal_config
=
multimodal_config
,
)
# register buffer for deepstack
if
self
.
use_deepstack
:
self
.
deepstack_input_embeds
=
[
torch
.
zeros
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
,
thinker_config
.
text_config
.
hidden_size
,
)
for
_
in
range
(
self
.
deepstack_num_level
)
]
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
Qwen3MoeLLMForCausalLM
(
vllm_config
=
vllm_config
.
with_hf_config
(
thinker_config
.
text_config
,
architectures
=
[
"Qwen3MoeForCausalLM"
],
),
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
def
_get_deepstack_input_embeds
(
self
,
num_tokens
:
int
,
)
->
IntermediateTensors
|
None
:
if
not
getattr
(
self
,
"deepstack_input_embeds"
,
None
):
return
None
# If vision tower is skipped
# get deepstack_input_embeds from buffer, and clear the buffer
return
IntermediateTensors
(
{
...
...
@@ -1674,6 +1682,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
)
def
_set_deepstack_input_embeds
(
self
,
deepstack_input_embeds
:
torch
.
Tensor
)
->
None
:
if
not
getattr
(
self
,
"deepstack_input_embeds"
,
None
):
return
# set deepstack_input_embeds to buffer
num_tokens
=
deepstack_input_embeds
.
size
(
1
)
if
num_tokens
>
self
.
deepstack_input_embeds
[
0
].
size
(
0
):
...
...
@@ -1692,6 +1703,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
)
def
_clear_deepstack_input_embeds
(
self
,
num_tokens
:
int
)
->
None
:
if
not
getattr
(
self
,
"deepstack_input_embeds"
,
None
):
return
# clear deepstack_input_embeds in buffer
if
num_tokens
>
0
:
for
idx
in
range
(
self
.
deepstack_num_level
):
...
...
@@ -1726,9 +1740,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
)
return
mm_input_by_modality
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
|
None
:
mm_input_by_modality
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
if
not
mm_input_by_modality
:
...
...
@@ -1844,11 +1855,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
if
intermediate_tensors
is
not
None
:
inputs_embeds
=
None
if
(
self
.
use_deepstack
and
inputs_embeds
is
not
None
and
get_pp_group
().
is_first_rank
):
if
inputs_embeds
is
not
None
and
get_pp_group
().
is_first_rank
:
deepstack_input_embeds
=
self
.
_get_deepstack_input_embeds
(
inputs_embeds
.
size
(
0
)
)
...
...
vllm/model_executor/models/qwen3_vl.py
View file @
193069d1
...
...
@@ -1321,7 +1321,13 @@ class Qwen3VLForConditionalGeneration(
num_layers
=
len
(
self
.
language_model
.
model
.
layers
)
return
(
2
,
num_layers
//
2
,
num_layers
-
3
)
def
_get_deepstack_input_embeds
(
self
,
num_tokens
:
int
)
->
IntermediateTensors
:
def
_get_deepstack_input_embeds
(
self
,
num_tokens
:
int
,
)
->
IntermediateTensors
|
None
:
if
not
getattr
(
self
,
"deepstack_input_embeds"
,
None
):
return
None
# If vision tower is skipped
# get deepstack_input_embeds from buffer, and clear the buffer
return
IntermediateTensors
(
{
...
...
@@ -1333,6 +1339,9 @@ class Qwen3VLForConditionalGeneration(
)
def
_set_deepstack_input_embeds
(
self
,
deepstack_input_embeds
:
torch
.
Tensor
)
->
None
:
if
not
getattr
(
self
,
"deepstack_input_embeds"
,
None
):
return
# set deepstack_input_embeds to buffer
num_tokens
=
deepstack_input_embeds
.
size
(
1
)
if
num_tokens
>
self
.
deepstack_input_embeds
[
0
].
size
(
0
):
...
...
@@ -1351,6 +1360,9 @@ class Qwen3VLForConditionalGeneration(
)
def
_clear_deepstack_input_embeds
(
self
,
num_tokens
:
int
)
->
None
:
if
not
getattr
(
self
,
"deepstack_input_embeds"
,
None
):
return
# clear deepstack_input_embeds in buffer
if
num_tokens
>
0
:
for
idx
in
range
(
self
.
deepstack_num_level
):
...
...
@@ -2037,11 +2049,7 @@ class Qwen3VLForConditionalGeneration(
if
intermediate_tensors
is
not
None
:
inputs_embeds
=
None
if
(
self
.
use_deepstack
and
inputs_embeds
is
not
None
and
get_pp_group
().
is_first_rank
):
if
inputs_embeds
is
not
None
and
get_pp_group
().
is_first_rank
:
deepstack_input_embeds
=
self
.
_get_deepstack_input_embeds
(
inputs_embeds
.
size
(
0
)
)
...
...
vllm/model_executor/models/radio.py
View file @
193069d1
...
...
@@ -620,7 +620,6 @@ class RadioInternVisionModel(nn.Module):
x
:
torch
.
Tensor
,
imgs_sizes
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
FloatTensor
:
assert
self
.
patch_generator
is
not
None
hidden_states
=
self
.
patch_generator
(
x
,
imgs_sizes
=
imgs_sizes
)
attn_mask
=
None
if
imgs_sizes
is
not
None
and
len
(
imgs_sizes
)
>
1
:
...
...
vllm/model_executor/models/siglip.py
View file @
193069d1
...
...
@@ -1033,20 +1033,22 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
self
.
text_embed_dim
=
text_config
.
hidden_size
self
.
vision_embed_dim
=
vision_config
.
hidden_size
self
.
text_projection_size
=
text_config
.
projection_size
self
.
text_model
=
SiglipTextTransformer
(
text_config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"text_model"
),
)
self
.
vision_model
=
SiglipVisionTransformer
(
vision_config
,
quant_config
=
quant_config
,
multimodal_config
=
multimodal_config
,
prefix
=
maybe_prefix
(
prefix
,
"vision_model"
),
)
with
self
.
_mark_language_model
(
vllm_config
):
self
.
text_model
=
SiglipTextTransformer
(
text_config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"text_model"
),
)
self
.
text_projection_size
=
text_config
.
projection_size
with
self
.
_mark_tower_model
(
vllm_config
,
"image"
):
self
.
vision_model
=
SiglipVisionTransformer
(
vision_config
,
quant_config
=
quant_config
,
multimodal_config
=
multimodal_config
,
prefix
=
maybe_prefix
(
prefix
,
"vision_model"
),
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
...
...
@@ -1155,9 +1157,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
return
self
.
get_image_features
(
pixel_values
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
text_model
def
_embed_text_input_ids
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/skyworkr1v.py
View file @
193069d1
...
...
@@ -674,24 +674,26 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self
.
downsample_ratio
=
config
.
downsample_ratio
self
.
ps_version
=
config
.
ps_version
self
.
llm_arch_name
=
config
.
text_config
.
architectures
[
0
]
self
.
is_mono
=
self
.
llm_arch_name
==
"SkyworkLM2VEForCausalLM"
self
.
vision_model
=
self
.
_init_vision_model
(
config
,
quant_config
=
quant_config
,
is_mono
=
self
.
is_mono
,
prefix
=
maybe_prefix
(
prefix
,
"vision_model"
),
)
llm_arch_name
=
config
.
text_config
.
architectures
[
0
]
self
.
is_mono
=
llm_arch_name
==
"SkyworkLM2VEForCausalLM"
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
with
self
.
_mark_tower_model
(
vllm_config
,
"image"
):
self
.
vision_model
=
self
.
_init_vision_model
(
config
,
quant_config
=
quant_config
,
is_mono
=
self
.
is_mono
,
prefix
=
maybe_prefix
(
prefix
,
"vision_model"
),
)
self
.
mlp1
=
self
.
_init_mlp1
(
config
,
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"mlp1"
)
)
self
.
mlp1
=
self
.
_init_mlp1
(
config
,
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"mlp1"
)
)
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
img_context_token_id
=
None
self
.
visual_token_mask
=
None
...
...
@@ -838,8 +840,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
if
image_input
[
"type"
]
==
"image_embeds"
:
return
image_input
[
"data"
]
assert
self
.
vision_model
is
not
None
image_embeds
=
self
.
extract_feature
(
image_input
[
"pixel_values_flat"
])
num_patches
=
image_input
[
"num_patches"
]
...
...
@@ -867,9 +867,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else
:
self
.
visual_token_mask
=
None
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
...
...
vllm/model_executor/models/tarsier.py
View file @
193069d1
...
...
@@ -423,38 +423,43 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
config
=
config
# Storing the Tarsier-specific HF config
self
.
vision_tower
=
init_vision_tower_for_tarsier
(
config
,
quant_config
=
quant_config
,
multimodal_config
=
multimodal_config
,
require_post_norm
=
False
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
)
projector_bias
=
getattr
(
config
,
"multimodal_projector_bias"
,
True
)
self
.
multi_modal_projector
=
TarsierMultiModalProjector
(
vision_hidden_size
=
config
.
vision_config
.
hidden_size
,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
projector_hidden_act
=
config
.
projector_hidden_act
,
multimodal_projector_bias
=
projector_bias
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"multi_modal_projector"
),
)
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
# Use text_config from Tarsier's main config
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
register_buffer
(
"image_newline_idx_tensor"
,
torch
.
tensor
([
config
.
image_newline_idx
],
dtype
=
torch
.
long
),
persistent
=
False
,
)
self
.
register_buffer
(
"image_new_idx_tensor"
,
torch
.
tensor
([
config
.
image_new_idx
],
dtype
=
torch
.
long
),
persistent
=
False
,
)
with
self
.
_mark_tower_model
(
vllm_config
,
"image"
):
self
.
vision_tower
=
init_vision_tower_for_tarsier
(
config
,
quant_config
=
quant_config
,
multimodal_config
=
multimodal_config
,
require_post_norm
=
False
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
)
projector_bias
=
getattr
(
config
,
"multimodal_projector_bias"
,
True
)
self
.
multi_modal_projector
=
TarsierMultiModalProjector
(
vision_hidden_size
=
config
.
vision_config
.
hidden_size
,
text_hidden_size
=
config
.
text_config
.
hidden_size
,
projector_hidden_act
=
config
.
projector_hidden_act
,
multimodal_projector_bias
=
projector_bias
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"multi_modal_projector"
),
)
self
.
register_buffer
(
"image_newline_idx_tensor"
,
torch
.
tensor
([
config
.
image_newline_idx
],
dtype
=
torch
.
long
),
persistent
=
False
,
)
self
.
register_buffer
(
"image_new_idx_tensor"
,
torch
.
tensor
([
config
.
image_new_idx
],
dtype
=
torch
.
long
),
persistent
=
False
,
)
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
# Use text_config from Tarsier's main config
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
...
...
@@ -547,7 +552,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
self
,
inputs
:
TarsierImagePixelInputs
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
...]:
assert
self
.
vision_tower
is
not
None
pixel_values
=
inputs
[
"pixel_values"
]
image_features_selected
=
self
.
_image_pixels_to_features
(
self
.
vision_tower
,
pixel_values
...
...
@@ -575,11 +579,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
"Incorrect type of image_embeds. "
f
"Got type:
{
type
(
projected_features
)
}
. "
)
assert
self
.
vision_tower
is
not
None
return
self
.
_process_image_pixels
(
image_input
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
return
self
.
_process_image_pixels
(
image_input
)
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
...
vllm/model_executor/models/ultravox.py
View file @
193069d1
...
...
@@ -543,7 +543,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
assert
self
.
multi_modal_config
self
.
secondary_weights
=
[]
self
.
audio_tower
=
ModifiedWhisperEncoder
(
config
.
audio_config
)
if
config
.
audio_model_id
is
not
None
:
# this prefix is not for initialization, but for loading weights
# note the trailing dot
...
...
@@ -554,15 +553,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
prefix
=
"audio_tower."
,
)
)
if
config
.
num_projector_layers
>
0
:
self
.
multi_modal_projector
=
UltravoxTransformerProjector
(
config
)
else
:
self
.
multi_modal_projector
=
UltravoxFeedForwardProjector
(
config
)
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
wrapped_model_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
if
config
.
text_model_id
is
not
None
:
# this prefix is not for initialization, but for loading weights
# note the trailing dot
...
...
@@ -574,6 +564,20 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
)
)
with
self
.
_mark_tower_model
(
vllm_config
,
"audio"
):
self
.
audio_tower
=
ModifiedWhisperEncoder
(
config
.
audio_config
)
if
config
.
num_projector_layers
>
0
:
self
.
multi_modal_projector
=
UltravoxTransformerProjector
(
config
)
else
:
self
.
multi_modal_projector
=
UltravoxFeedForwardProjector
(
config
)
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
wrapped_model_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
...
...
@@ -681,9 +685,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
]
return
flattened_embeddings
.
split
(
embed_lens
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
audio_input
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
if
audio_input
is
None
:
...
...
vllm/model_executor/models/voxtral.py
View file @
193069d1
...
...
@@ -366,22 +366,22 @@ class VoxtralForConditionalGeneration(
self
.
config
=
config
self
.
downsample_factor
=
self
.
config
.
audio_config
.
downsample_factor
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
whisper_encoder
=
VoxtralEncoderModel
(
vllm_config
.
with_hf_config
(
config
.
audio_config
),
prefix
=
maybe_prefix
(
prefix
,
"whisper_encoder"
),
)
self
.
audio_language_adapter
=
AudioLanguageAdapter
(
hidden_size
=
config
.
audio_config
.
d_model
*
self
.
downsample_factor
,
dim
=
config
.
text_config
.
hidden_size
,
)
with
self
.
_mark_language_model
(
vllm_config
):
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
with
self
.
_mark_tower_model
(
vllm_config
,
"audio"
):
self
.
whisper_encoder
=
VoxtralEncoderModel
(
vllm_config
.
with_hf_config
(
config
.
audio_config
),
prefix
=
maybe_prefix
(
prefix
,
"whisper_encoder"
),
)
self
.
audio_language_adapter
=
AudioLanguageAdapter
(
hidden_size
=
config
.
audio_config
.
d_model
*
self
.
downsample_factor
,
dim
=
config
.
text_config
.
hidden_size
,
)
def
get_mm_mapping
(
self
)
->
MultiModelKeys
:
"""Get module prefix for multimodal models to filter LoRA modules."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment