Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b8199f60
Unverified
Commit
b8199f60
authored
Jan 13, 2026
by
Roger Wang
Committed by
GitHub
Jan 14, 2026
Browse files
[Model] Re-implement Qwen3Omni Audio Encoder (#32167)
Signed-off-by:
Roger Wang
<
hey@rogerw.io
>
parent
7e6f1238
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
428 additions
and
29 deletions
+428
-29
vllm/model_executor/models/qwen3_omni_moe_thinker.py
vllm/model_executor/models/qwen3_omni_moe_thinker.py
+428
-29
No files found.
vllm/model_executor/models/qwen3_omni_moe_thinker.py
View file @
b8199f60
...
...
@@ -31,29 +31,34 @@ import torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
packaging.version
import
Version
from
transformers
import
PretrainedConfig
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
from
transformers.feature_extraction_utils
import
BatchFeature
from
transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe
import
(
Qwen3OmniMoeAudioEncoderConfig
,
Qwen3OmniMoeConfig
,
Qwen3OmniMoeThinkerConfig
,
)
from
transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe
import
(
Qwen3OmniMoeAudioEncoder
,
)
from
transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe
import
(
Qwen3OmniMoeProcessor
,
)
from
transformers.models.whisper
import
WhisperFeatureExtractor
# isort: off
from
transformers
import
PretrainedConfig
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
# isort: on
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
MultiModalConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
_ACTIVATION_REGISTRY
from
vllm.model_executor.layers.attention.mm_encoder_attention
import
(
MMEncoderAttention
,
)
from
vllm.model_executor.layers.conv
import
Conv3dLayer
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
...
...
@@ -104,11 +109,6 @@ from .vision import (
get_vit_attn_backend
,
)
try
:
import
flash_attn
except
(
ImportError
,
ModuleNotFoundError
):
flash_attn
=
None
logger
=
init_logger
(
__name__
)
...
...
@@ -121,6 +121,415 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
return
output_lengths
# ============= Audio Encoder Components =============
class SinusoidsPositionEmbedding(nn.Module):
    """Sinusoidal position embedding for audio encoder.

    A fixed (non-learned) table of shape ``(length, channels)`` is built once
    at construction time. The first ``channels // 2`` columns hold the sines
    and the remaining columns hold the cosines of the scaled positions.

    Args:
        length: Number of positions (rows) in the table.
        channels: Embedding width; must be even.
        max_timescale: Largest timescale of the geometric frequency
            progression.

    Raises:
        ValueError: If ``channels`` is odd.
    """

    def __init__(self, length: int, channels: int, max_timescale: int = 10000):
        super().__init__()
        self.length = length
        self.channels = channels
        self.max_timescale = max_timescale

        if channels % 2 != 0:
            raise ValueError("SinusoidsPositionEmbedding needs even channels input")

        half_channels = channels // 2
        # The frequencies are spread over ``half_channels - 1`` steps. With a
        # single frequency (channels == 2) that denominator would be zero and
        # the whole table would silently become NaN, so clamp it to 1. The
        # result is unchanged for every other channel count.
        log_timescale_increment = np.log(max_timescale) / max(half_channels - 1, 1)
        inv_timescales = torch.exp(
            -log_timescale_increment * torch.arange(half_channels).float()
        )
        # Outer product: (length, 1) * (1, half_channels).
        scaled_time = torch.arange(length)[:, None] * inv_timescales[None, :]
        positional_embedding = torch.cat(
            [torch.sin(scaled_time), torch.cos(scaled_time)], dim=1
        )
        # Non-persistent: the table is recomputed on construction and is not
        # written to / expected in the state dict.
        self.register_buffer(
            "positional_embedding", positional_embedding, persistent=False
        )

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the first ``seqlen`` rows of the precomputed table."""
        return self.positional_embedding[:seqlen, :]
class Qwen3OmniMoeAudioAttention(nn.Module):
    """Self-attention block of the Qwen3-Omni audio encoder.

    Query, key and value come from one fused ``QKVParallelLinear``; the
    attention computation itself is delegated to ``MMEncoderAttention`` and
    the result is projected back with a ``RowParallelLinear``.
    """

    def __init__(
        self,
        config: Qwen3OmniMoeAudioEncoderConfig,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        d_model = config.d_model
        n_heads = config.encoder_attention_heads
        self.embed_dim = d_model
        self.num_heads = n_heads
        self.head_dim = d_model // n_heads
        # Number of heads handled by this tensor-parallel rank.
        self.num_local_heads = n_heads // get_tensor_model_parallel_world_size()

        if self.head_dim * n_heads != d_model:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: "
                f"{self.embed_dim} and `num_heads`: {self.num_heads})."
            )

        self.scaling = self.head_dim**-0.5

        # Fused Q/K/V projection; as many KV heads as query heads.
        self.qkv = QKVParallelLinear(
            hidden_size=d_model,
            head_size=self.head_dim,
            total_num_heads=n_heads,
            total_num_kv_heads=n_heads,
            bias=True,
            prefix=f"{prefix}.qkv",
        )
        self.out_proj = RowParallelLinear(
            input_size=d_model,
            output_size=d_model,
            bias=True,
            prefix=f"{prefix}.out_proj",
        )
        self.attn = MMEncoderAttention(
            num_heads=self.num_local_heads,
            head_size=self.head_dim,
            scale=self.scaling,
            multimodal_config=multimodal_config,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run attention over a 2-D ``hidden_states`` tensor.

        ``cu_seqlens`` and ``max_seqlen`` are forwarded untouched to the
        attention backend.
        """
        num_tokens, _ = hidden_states.size()

        fused_qkv, _ = self.qkv(hidden_states)
        # Split into Q, K, V and reshape each to (1, tokens, heads, head_dim).
        query, key, value = (
            part.view(1, num_tokens, -1, self.head_dim)
            for part in fused_qkv.chunk(3, dim=-1)
        )

        context = self.attn(
            query=query,
            key=key,
            value=value,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

        # Merge the heads back before the output projection.
        projected, _ = self.out_proj(context.view(num_tokens, -1))
        return projected
class Qwen3OmniMoeAudioEncoderLayer(nn.Module):
    """One pre-norm transformer layer of the Qwen3-Omni audio encoder.

    The layer applies self-attention followed by a two-layer feed-forward
    network, each preceded by a LayerNorm and wrapped in a residual add.
    """

    def __init__(
        self,
        config: Qwen3OmniMoeAudioEncoderConfig,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.d_model
        ffn_width = config.encoder_ffn_dim
        self.embed_dim = width

        # Attention sub-block.
        self.self_attn = Qwen3OmniMoeAudioAttention(
            config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = nn.LayerNorm(width)

        # Feed-forward sub-block.
        self.activation_fn = _ACTIVATION_REGISTRY[config.activation_function]
        self.fc1 = ColumnParallelLinear(
            width,
            ffn_width,
            bias=True,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            ffn_width,
            width,
            bias=True,
            prefix=f"{prefix}.fc2",
        )
        self.final_layer_norm = nn.LayerNorm(width)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: torch.Tensor | None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Input tensor of shape (seq_len, hidden_size)
            cu_seqlens: Cumulative sequence lengths
            max_seqlen: Maximum sequence length in the batch
        """
        # Self-attention on the normalized input, added back to the input.
        attn_out = self.self_attn(
            hidden_states=self.self_attn_layer_norm(hidden_states),
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward on the normalized stream, added back to the stream.
        ffn_out, _ = self.fc1(self.final_layer_norm(hidden_states))
        ffn_out, _ = self.fc2(self.activation_fn(ffn_out))
        hidden_states = hidden_states + ffn_out

        # Keep fp16 activations just inside the representable range.
        if hidden_states.dtype == torch.float16:
            limit = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-limit, max=limit)

        return hidden_states
class Qwen3OmniMoeAudioEncoder(nn.Module):
    """vLLM-native Qwen3-Omni Audio Encoder.

    Pipeline implemented by ``forward``:

    1. Split the input features into chunks of at most ``2 * n_window``
       frames and pad them into one batch.
    2. Downsample each chunk with three stride-2 ``Conv2d`` layers and
       project the flattened (channels * freq) result to ``d_model``.
    3. Add the sinusoidal position table and drop the padded positions.
    4. Run the transformer layers with windowed attention described by
       ``cu_seqlens``.
    5. Apply ``ln_post`` -> ``proj1`` -> activation -> ``proj2``.
    """

    def __init__(
        self,
        config: Qwen3OmniMoeAudioEncoderConfig,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ):
        """Build the encoder from ``config``.

        Args:
            config: Audio encoder configuration; the fields read here are
                ``d_model``, ``num_mel_bins``, ``max_source_positions``,
                ``n_window``, ``n_window_infer``, ``conv_chunksize``,
                ``downsample_hidden_size``, ``encoder_layers``,
                ``encoder_attention_heads``, ``activation_function`` and
                ``output_dim``.
            multimodal_config: Optional multimodal settings; only
                ``mm_encoder_attn_backend`` is read directly here, the object
                is otherwise passed down to the layers.
            prefix: Name prefix handed to the transformer layers.
        """
        super().__init__()

        embed_dim = config.d_model
        self.num_mel_bins = config.num_mel_bins
        self.max_source_positions = config.max_source_positions
        self.n_window = config.n_window
        self.n_window_infer = config.n_window_infer
        self.conv_chunksize = config.conv_chunksize

        # Position embedding
        self.positional_embedding = SinusoidsPositionEmbedding(
            self.max_source_positions, embed_dim
        )

        # Convolutional layers for mel-spectrogram processing
        # (kernel size 3, stride 2, padding 1 for all three layers)
        self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1)
        self.conv2d2 = nn.Conv2d(
            config.downsample_hidden_size,
            config.downsample_hidden_size,
            3,
            2,
            padding=1,
        )
        self.conv2d3 = nn.Conv2d(
            config.downsample_hidden_size,
            config.downsample_hidden_size,
            3,
            2,
            padding=1,
        )

        # Flattened (channels * freq) size after the three convolutions: each
        # ``(x + 1) // 2`` step is the output size of one stride-2 conv along
        # the mel axis.
        conv_out_dim = config.downsample_hidden_size * (
            (((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2
        )
        self.conv_out = nn.Linear(conv_out_dim, config.d_model, bias=False)

        # Transformer encoder layers
        self.layers = nn.ModuleList(
            [
                Qwen3OmniMoeAudioEncoderLayer(
                    config,
                    multimodal_config=multimodal_config,
                    prefix=f"{prefix}.layers.{i}",
                )
                for i in range(config.encoder_layers)
            ]
        )

        # Output layers
        self.ln_post = nn.LayerNorm(config.d_model)
        self.proj1 = nn.Linear(config.d_model, config.d_model)
        self.act = _ACTIVATION_REGISTRY[config.activation_function]
        self.proj2 = nn.Linear(config.d_model, config.output_dim)

        # Get attention backend
        # (only consulted by compute_attn_mask_seqlen in this class)
        attn_backend_override = (
            multimodal_config.mm_encoder_attn_backend
            if multimodal_config is not None
            else None
        )
        self.attn_backend = get_vit_attn_backend(
            head_size=config.d_model // config.encoder_attention_heads,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )

    def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
        """Compute max_seqlen only for flash attention backends.

        Returns the largest difference between consecutive entries of
        ``cu_seqlens`` (as a tensor) when the backend is ``FLASH_ATTN`` or
        ``ROCM_AITER_FA``; otherwise ``None``.
        """
        max_seqlen = None
        if self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        return max_seqlen

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first convolution's weight."""
        return self.conv2d1.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the first convolution's weight."""
        return self.conv2d1.weight.device

    def forward(
        self,
        input_features: torch.Tensor,
        feature_lens: torch.Tensor,
        aftercnn_lens: torch.Tensor,
    ):
        """Encode audio features into hidden states.

        Args:
            input_features: Feature tensor; it is transposed and split along
                the frame axis, so it is presumably
                (num_mel_bins, total_frames) with all audios concatenated
                (TODO confirm against the caller).
            feature_lens: Per-audio frame counts, used to derive the chunks.
            aftercnn_lens: Per-audio lengths after the CNN, used to build the
                attention windows.

        Returns:
            Output of ``proj2`` for the valid (non-padded) positions only.
        """
        # Compute chunk information
        # Each audio is cut into ceil(len / (2 * n_window)) chunks.
        chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()

        # Start with every chunk at full size, then fix up the tail chunks.
        chunk_lengths = torch.tensor(
            [self.n_window * 2] * chunk_num.sum(),
            dtype=torch.long,
            device=feature_lens.device,
        )
        # Prepending -1 before the cumulative sum yields, for each audio, the
        # index of its last chunk.
        tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
        chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
        # A remainder of 0 means the tail chunk is actually a full chunk.
        chunk_lengths[chunk_lengths == 0] = self.n_window * 2

        # Split input features into chunks and pad
        chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
        padded_feature = nn.utils.rnn.pad_sequence(
            chunk_list, batch_first=True
        ).transpose(1, 2)

        # Compute feature lengths after CNN
        feature_lens_after_cnn = self._get_cnn_output_lengths(chunk_lengths)

        # Vectorized mask creation: avoid creating many small tensors
        # mask[i, j] is True when position j is valid for chunk i.
        max_len_after_cnn = feature_lens_after_cnn.max().item()
        indices = torch.arange(max_len_after_cnn, device=padded_feature.device)
        padded_mask_after_cnn = indices.unsqueeze(0) < feature_lens_after_cnn.unsqueeze(
            1
        )

        # Add channel dimension for conv2d
        padded_feature = padded_feature.unsqueeze(1)

        # Apply convolutional layers (chunk if needed to avoid OOM)
        if padded_feature.size(0) <= self.conv_chunksize:
            # Fast path: no chunking needed
            padded_embed = F.gelu(self.conv2d1(padded_feature))
            padded_embed = F.gelu(self.conv2d2(padded_embed))
            padded_embed = F.gelu(self.conv2d3(padded_embed))
        else:
            # Chunked processing to avoid OOM
            # Same three conv + GELU steps, applied to conv_chunksize rows at a
            # time and concatenated back along the batch axis.
            padded_embeds = []
            for chunk in padded_feature.split(self.conv_chunksize, dim=0):
                padded_embed = F.gelu(self.conv2d1(chunk))
                padded_embed = F.gelu(self.conv2d2(padded_embed))
                padded_embed = F.gelu(self.conv2d3(padded_embed))
                padded_embeds.append(padded_embed)
            padded_embed = torch.cat(padded_embeds, dim=0)

        # (batch, channels, freq, time) -> (batch, time, channels*freq)
        b, c, f, t = padded_embed.size()
        padded_embed = self.conv_out(
            padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)
        )

        # Add positional embedding
        # The table is sliced to the per-chunk length, so positions restart at
        # 0 for every chunk.
        positional_embedding = (
            self.positional_embedding.positional_embedding[
                : padded_embed.shape[1], :
            ]
            .unsqueeze(0)
            .to(padded_embed.dtype)
        )
        padded_embed = padded_embed + positional_embedding

        # Extract valid hidden states and compute cu_seqlens
        # Boolean indexing flattens (batch, time) and drops padded positions.
        hidden_states = padded_embed[padded_mask_after_cnn]

        # Compute cumulative sequence lengths for chunked attention
        cu_chunk_lens = [0]
        # Attention window measured in post-CNN positions: the padded chunk
        # length times the number of chunks covered by n_window_infer frames.
        window_aftercnn = padded_mask_after_cnn.shape[-1] * (
            self.n_window_infer // (self.n_window * 2)
        )
        # Use tolist() for efficient batch conversion from tensor to Python
        for cnn_len in aftercnn_lens.tolist():
            num_full_chunks = cnn_len // window_aftercnn
            remainder = cnn_len % window_aftercnn
            cu_chunk_lens.extend([window_aftercnn] * num_full_chunks)
            if remainder:
                cu_chunk_lens.append(remainder)
        cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(
            -1, dtype=torch.int32
        )

        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)

        # Apply transformer layers
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                cu_seqlens,
                max_seqlen,
            )

        # Apply output layers
        hidden_states = self.ln_post(hidden_states)
        hidden_states = self.proj1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.proj2(hidden_states)

        return hidden_states

    def _get_cnn_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
        """Compute output lengths after the three conv2d layers.

        Applies ``(n - 1) // 2 + 1`` three times, once per stride-2
        convolution.
        """
        lengths = input_lengths
        for _ in range(3):
            lengths = (lengths - 1) // 2 + 1
        return lengths

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights with mapping from HuggingFace format.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` weights are routed into
        the fused ``qkv`` parameter through that parameter's
        ``weight_loader`` with the matching shard id. Every other weight is
        loaded by name, falling back to ``default_weight_loader`` when the
        parameter has no ``weight_loader`` attribute.

        Returns:
            The set of (remapped) weight names that were visited.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("self_attn.qkv.", "self_attn.q_proj.", "q"),
            ("self_attn.qkv.", "self_attn.k_proj.", "k"),
            ("self_attn.qkv.", "self_attn.v_proj.", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Rewrite the checkpoint name to the fused parameter's name.
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a q/k/v shard: load directly if the parameter exists.
                param = params_dict.get(name)
                if param is not None:
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            # NOTE(review): the name is recorded even when no parameter with
            # that name was found above - confirm this is intended.
            loaded_params.add(name)
        return loaded_params
class
Qwen3_VisionPatchEmbed
(
nn
.
Module
):
def
__init__
(
self
,
...
...
@@ -144,7 +553,7 @@ class Qwen3_VisionPatchEmbed(nn.Module):
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
L
,
C
=
x
.
shape
L
,
_
=
x
.
shape
x
=
x
.
view
(
L
,
-
1
,
self
.
temporal_patch_size
,
self
.
patch_size
,
self
.
patch_size
)
x
=
self
.
proj
(
x
).
view
(
L
,
self
.
hidden_size
)
return
x
...
...
@@ -224,7 +633,7 @@ class Qwen3_VisionBlock(nn.Module):
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb_cos
:
torch
.
Tensor
,
rotary_pos_emb_sin
:
torch
.
Tensor
,
max_seqlen
:
torch
.
Tensor
,
# Only used for Flash Attention
max_seqlen
:
torch
.
Tensor
|
None
,
# Only used for Flash Attention
)
->
torch
.
Tensor
:
x
=
x
+
self
.
attn
(
self
.
norm1
(
x
),
...
...
@@ -1142,12 +1551,11 @@ class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMix
audio_output_lengths
=
_get_feat_extract_output_lengths
(
audio_feature_lengths
)
audio_
output
s
=
self
.
audio_tower
(
audio_
feature
s
=
self
.
audio_tower
(
input_features
.
to
(
self
.
audio_tower
.
dtype
),
feature_lens
=
audio_feature_lengths
,
aftercnn_lens
=
audio_output_lengths
,
)
audio_features
=
audio_outputs
.
last_hidden_state
return
audio_features
.
split
(
audio_output_lengths
.
tolist
())
...
...
@@ -1205,21 +1613,12 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
self
.
config
=
thinker_config
self
.
multimodal_config
=
multimodal_config
# force "use_flash_attention_2=True" to audio tower to align
# the results.
if
flash_attn
is
not
None
:
audio_config
=
thinker_config
.
audio_config
audio_config
.
_attn_implementation_autoset
=
True
audio_config
.
_attn_implementation
=
"flash_attention_2"
else
:
logger
.
warning
(
"flash_attn is not available, the model may not yield the "
"exactly same result as the transformers implementation "
"in the audio tower part."
self
.
audio_tower
=
Qwen3OmniMoeAudioEncoder
(
thinker_config
.
audio_config
,
multimodal_config
=
multimodal_config
,
prefix
=
maybe_prefix
(
prefix
,
"audio_tower"
),
)
self
.
audio_tower
=
Qwen3OmniMoeAudioEncoder
(
thinker_config
.
audio_config
)
self
.
visual
=
Qwen3Omni_VisionTransformer
(
vision_config
=
thinker_config
.
vision_config
,
norm_eps
=
getattr
(
thinker_config
.
text_config
,
"rms_norm_eps"
,
1e-6
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment