Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a1825fe6
Commit
a1825fe6
authored
Sep 30, 2025
by
Roger Wang
Committed by
simon-mo
Sep 30, 2025
Browse files
[MM] Add text-only mode for Qwen3-VL (#26000)
Signed-off-by:
simon-mo
<
simon.mo@hey.com
>
parent
bab9231b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
26 deletions
+45
-26
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+25
-14
vllm/model_executor/models/qwen3_vl_moe.py
vllm/model_executor/models/qwen3_vl_moe.py
+20
-12
No files found.
vllm/model_executor/models/qwen3_vl.py
View file @
a1825fe6
...
...
@@ -1126,14 +1126,17 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
use_data_parallel
=
multimodal_config
.
mm_encoder_tp_mode
==
"data"
self
.
visual
=
Qwen3_VisionTransformer
(
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
use_data_parallel
=
self
.
use_data_parallel
,
)
if
not
multimodal_config
.
get_limit_per_prompt
(
"image"
)
and
\
not
multimodal_config
.
get_limit_per_prompt
(
"video"
):
self
.
visual
=
None
else
:
self
.
visual
=
Qwen3_VisionTransformer
(
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
use_data_parallel
=
self
.
use_data_parallel
,
)
self
.
language_model
=
Qwen3LLMForCausalLM
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
...
...
@@ -1149,11 +1152,15 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
config
.
vision_config
.
deepstack_visual_indexes
)
if
self
.
use_deepstack
else
0
# register buffer for deepstack
self
.
deepstack_input_embeds
=
[
torch
.
zeros
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
,
config
.
text_config
.
hidden_size
)
for
_
in
range
(
self
.
deepstack_num_level
)
]
if
self
.
use_deepstack
else
None
if
self
.
use_deepstack
and
self
.
visual
is
not
None
:
self
.
deepstack_input_embeds
=
[
torch
.
zeros
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
,
config
.
text_config
.
hidden_size
)
for
_
in
range
(
self
.
deepstack_num_level
)
]
else
:
self
.
deepstack_input_embeds
=
None
self
.
visual_dim
=
config
.
vision_config
.
out_hidden_size
self
.
multiscale_dim
=
self
.
visual_dim
*
self
.
deepstack_num_level
...
...
@@ -1588,7 +1595,11 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights into this model via ``AutoWeightsLoader``.

    When ``self.visual`` is ``None`` (no vision tower was instantiated),
    every weight whose name starts with ``"visual."`` is skipped, since
    there is no module to receive it.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs to load.

    Returns:
        Whatever ``AutoWeightsLoader.load_weights`` returns; annotated as a
        set of strings (presumably the loaded parameter names -- confirm
        against the loader's documentation).
    """
    # Build the loader exactly once; the previous version constructed a
    # throwaway ``AutoWeightsLoader(self)`` that was immediately replaced.
    skip_prefixes = ["visual."] if self.visual is None else []
    loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
    # Checkpoint names are translated through ``hf_to_vllm_mapper``.
    return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def
get_mm_mapping
(
self
)
->
MultiModelKeys
:
...
...
vllm/model_executor/models/qwen3_vl_moe.py
View file @
a1825fe6
...
...
@@ -319,13 +319,17 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
self
.
multimodal_config
=
multimodal_config
self
.
use_data_parallel
=
multimodal_config
.
mm_encoder_tp_mode
==
"data"
self
.
visual
=
Qwen3_VisionTransformer
(
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
use_data_parallel
=
self
.
use_data_parallel
,
)
if
not
multimodal_config
.
get_limit_per_prompt
(
"image"
)
and
\
not
multimodal_config
.
get_limit_per_prompt
(
"video"
):
self
.
visual
=
None
else
:
self
.
visual
=
Qwen3_VisionTransformer
(
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
use_data_parallel
=
self
.
use_data_parallel
,
)
self
.
language_model
=
Qwen3MoeLLMForCausalLM
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
...
...
@@ -341,10 +345,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
config
.
vision_config
.
deepstack_visual_indexes
)
if
self
.
use_deepstack
else
0
# register buffer for deepstack
self
.
deepstack_input_embeds
=
[
torch
.
zeros
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
,
config
.
text_config
.
hidden_size
)
for
_
in
range
(
self
.
deepstack_num_level
)
]
if
self
.
use_deepstack
else
None
if
self
.
use_deepstack
and
self
.
visual
is
not
None
:
self
.
deepstack_input_embeds
=
[
torch
.
zeros
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
,
config
.
text_config
.
hidden_size
)
for
_
in
range
(
self
.
deepstack_num_level
)
]
else
:
self
.
deepstack_input_embeds
=
None
self
.
visual_dim
=
config
.
vision_config
.
out_hidden_size
self
.
multiscale_dim
=
self
.
visual_dim
*
self
.
deepstack_num_level
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment