Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c50e6dee
Commit
c50e6dee
authored
Mar 05, 2026
by
weishb
Browse files
add qwen3-asr
parent
e962f483
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
1312 additions
and
0 deletions
+1312
-0
vllm/model_executor/models/qwen3_asr.py
vllm/model_executor/models/qwen3_asr.py
+989
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+4
-0
vllm/transformers_utils/configs/qwen3_asr.py
vllm/transformers_utils/configs/qwen3_asr.py
+164
-0
vllm/transformers_utils/processors/__init__.py
vllm/transformers_utils/processors/__init__.py
+2
-0
vllm/transformers_utils/processors/qwen3_asr.py
vllm/transformers_utils/processors/qwen3_asr.py
+153
-0
No files found.
vllm/model_executor/models/qwen3_asr.py
0 → 100644
View file @
c50e6dee
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/registry.py
View file @
c50e6dee
...
@@ -329,6 +329,10 @@ _MULTIMODAL_MODELS = {
...
@@ -329,6 +329,10 @@ _MULTIMODAL_MODELS = {
"granite_speech"
,
"granite_speech"
,
"GraniteSpeechForConditionalGeneration"
,
"GraniteSpeechForConditionalGeneration"
,
),
),
"Qwen3ASRForConditionalGeneration"
:
(
"qwen3_asr"
,
"Qwen3ASRForConditionalGeneration"
,
),
"H2OVLChatModel"
:
(
"h2ovl"
,
"H2OVLChatModel"
),
"H2OVLChatModel"
:
(
"h2ovl"
,
"H2OVLChatModel"
),
"HunYuanVLForConditionalGeneration"
:
(
"HunYuanVLForConditionalGeneration"
:
(
"hunyuan_vision"
,
"hunyuan_vision"
,
...
...
vllm/transformers_utils/configs/qwen3_asr.py
0 → 100644
View file @
c50e6dee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
from
transformers.configuration_utils
import
PretrainedConfig
class Qwen3ASRAudioEncoderConfig(PretrainedConfig):
    """Configuration container for the Qwen3-ASR audio encoder.

    Each constructor argument is stored on the instance under its own name.
    ``encoder_layers`` is additionally mirrored as ``num_hidden_layers``.
    Any extra keyword arguments are handed to ``PretrainedConfig.__init__``.
    """

    model_type = "qwen3_asr_audio_encoder"

    def __init__(
        self,
        num_mel_bins=128,
        encoder_layers=32,
        encoder_attention_heads=20,
        encoder_ffn_dim=5120,
        d_model=1280,
        dropout=0,
        attention_dropout=0,
        activation_function="gelu",
        activation_dropout=0,
        scale_embedding=False,
        initializer_range=0.02,
        max_source_positions=1500,
        n_window=100,
        output_dim=3584,
        n_window_infer=400,
        conv_chunksize=500,
        downsample_hidden_size=480,
        **kwargs,
    ):
        # The base class consumes the generic options first.
        super().__init__(**kwargs)

        # Attribute name -> value, listed in the order they are assigned.
        encoder_settings = {
            "num_mel_bins": num_mel_bins,
            "d_model": d_model,
            "encoder_layers": encoder_layers,
            "encoder_attention_heads": encoder_attention_heads,
            "encoder_ffn_dim": encoder_ffn_dim,
            "dropout": dropout,
            "attention_dropout": attention_dropout,
            "activation_function": activation_function,
            "activation_dropout": activation_dropout,
            # Alias of ``encoder_layers`` under the generic layer-count name.
            "num_hidden_layers": encoder_layers,
            "initializer_range": initializer_range,
            "scale_embedding": scale_embedding,
            "max_source_positions": max_source_positions,
            "n_window": n_window,
            "output_dim": output_dim,
            "n_window_infer": n_window_infer,
            "conv_chunksize": conv_chunksize,
            "downsample_hidden_size": downsample_hidden_size,
        }
        for attr_name, attr_value in encoder_settings.items():
            setattr(self, attr_name, attr_value)
class Qwen3ASRTextConfig(PretrainedConfig):
    """Configuration container for the Qwen3-ASR text model.

    Constructor arguments are stored as same-named attributes before the
    base-class initialiser runs. ``num_key_value_heads`` falls back to
    ``num_attention_heads`` when given as ``None``. If ``rope_scaling`` is a
    mapping with a ``"type"`` key, that value is copied in place to
    ``"rope_type"``.
    """

    model_type = "qwen3_asr_text"
    base_config_key = "text_config"

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=128000,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=5000000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        # No explicit KV head count: use one KV head per attention head.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        # Attribute name -> value, listed in the order they are assigned.
        decoder_settings = {
            "vocab_size": vocab_size,
            "max_position_embeddings": max_position_embeddings,
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "num_hidden_layers": num_hidden_layers,
            "num_attention_heads": num_attention_heads,
            "num_key_value_heads": num_key_value_heads,
            "head_dim": head_dim,
            "hidden_act": hidden_act,
            "initializer_range": initializer_range,
            "rms_norm_eps": rms_norm_eps,
            "use_cache": use_cache,
            "rope_theta": rope_theta,
            "rope_scaling": rope_scaling,
            "attention_bias": attention_bias,
            "attention_dropout": attention_dropout,
        }
        for attr_name, attr_value in decoder_settings.items():
            setattr(self, attr_name, attr_value)

        # Mirror the "type" key as "rope_type" (mutates the mapping in place).
        scaling = self.rope_scaling
        if scaling is not None and "type" in scaling:
            scaling["rope_type"] = scaling["type"]

        # The base initialiser runs last, after all attributes above are set.
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
class Qwen3ASRThinkerConfig(PretrainedConfig):
    """Configuration that bundles the audio-encoder and text sub-configs.

    ``audio_config`` and ``text_config`` may each be ``None`` (a default
    sub-config is built), a ``dict`` (expanded into the sub-config's
    constructor), or any other object, which is stored unchanged.
    """

    model_type = "qwen3_asr_thinker"
    attribute_map = {}
    sub_configs = {
        "audio_config": Qwen3ASRAudioEncoderConfig,
        "text_config": Qwen3ASRTextConfig,
    }

    def __init__(
        self,
        audio_config=None,
        text_config=None,
        audio_token_id=151646,
        audio_start_token_id=151647,
        user_token_id=872,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.user_token_id = user_token_id
        self.audio_start_token_id = audio_start_token_id
        self.initializer_range = initializer_range

        # Resolve and attach the sub-configs one after the other: audio
        # first, then text.
        pending = (
            ("audio_config", audio_config, Qwen3ASRAudioEncoderConfig),
            ("text_config", text_config, Qwen3ASRTextConfig),
        )
        for field_name, supplied, config_cls in pending:
            if supplied is None:
                resolved = config_cls()
            elif isinstance(supplied, dict):
                resolved = config_cls(**supplied)
            else:
                # Anything else is stored as given.
                resolved = supplied
            setattr(self, field_name, resolved)

        self.audio_token_id = audio_token_id
class Qwen3ASRConfig(PretrainedConfig):
    """Top-level Qwen3-ASR configuration wrapping a thinker sub-config.

    Args:
        thinker_config: ``None`` (a default ``Qwen3ASRThinkerConfig`` is
            built), a mapping of constructor arguments for
            ``Qwen3ASRThinkerConfig``, or an existing
            ``Qwen3ASRThinkerConfig`` instance, which is stored as-is.
        support_languages: Stored unchanged on the instance.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "qwen3_asr"
    sub_configs = {
        "thinker_config": Qwen3ASRThinkerConfig,
    }

    def __init__(
        self,
        thinker_config=None,
        support_languages=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if thinker_config is None:
            thinker_config = {}
        # An already-built sub-config cannot be expanded with ``**``; keep it
        # as given, mirroring how Qwen3ASRThinkerConfig treats its own
        # sub-configs. Mappings are expanded exactly as before.
        if not isinstance(thinker_config, Qwen3ASRThinkerConfig):
            thinker_config = Qwen3ASRThinkerConfig(**thinker_config)
        self.thinker_config = thinker_config
        self.support_languages = support_languages

    def get_text_config(self, decoder=False) -> "PretrainedConfig":
        """Return the text config resolved by the thinker sub-config.

        ``decoder`` is accepted for signature compatibility only; it is not
        forwarded to the thinker's ``get_text_config``.
        """
        return self.thinker_config.get_text_config()
# Names exported by a star-import of this module.
# NOTE(review): Qwen3ASRTextConfig is defined in this module but is not listed
# here — confirm whether that omission is intentional.
__all__ = [
    "Qwen3ASRConfig",
    "Qwen3ASRThinkerConfig",
    "Qwen3ASRAudioEncoderConfig",
]
vllm/transformers_utils/processors/__init__.py
View file @
c50e6dee
...
@@ -14,6 +14,7 @@ from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
...
@@ -14,6 +14,7 @@ from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
from
vllm.transformers_utils.processors.hunyuan_vl_image
import
HunYuanVLImageProcessor
from
vllm.transformers_utils.processors.hunyuan_vl_image
import
HunYuanVLImageProcessor
from
vllm.transformers_utils.processors.ovis
import
OvisProcessor
from
vllm.transformers_utils.processors.ovis
import
OvisProcessor
from
vllm.transformers_utils.processors.ovis2_5
import
Ovis2_5Processor
from
vllm.transformers_utils.processors.ovis2_5
import
Ovis2_5Processor
from
vllm.transformers_utils.processors.qwen3_asr
import
Qwen3ASRProcessor
__all__
=
[
__all__
=
[
"BagelProcessor"
,
"BagelProcessor"
,
...
@@ -22,4 +23,5 @@ __all__ = [
...
@@ -22,4 +23,5 @@ __all__ = [
"HunYuanVLImageProcessor"
,
"HunYuanVLImageProcessor"
,
"OvisProcessor"
,
"OvisProcessor"
,
"Ovis2_5Processor"
,
"Ovis2_5Processor"
,
"Qwen3ASRProcessor"
,
]
]
vllm/transformers_utils/processors/qwen3_asr.py
0 → 100644
View file @
c50e6dee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
import
re
from
typing
import
Iterable
import
numpy
as
np
from
transformers
import
AutoProcessor
from
transformers.audio_utils
import
AudioInput
from
transformers.feature_extraction_utils
import
BatchFeature
from
transformers.processing_utils
import
ProcessingKwargs
,
ProcessorMixin
from
transformers.tokenization_utils_base
import
TextInput
# Keyword-argument schema handed to ``_merge_kwargs`` by
# ``Qwen3ASRProcessor.__call__``; ``_defaults`` holds the per-modality
# default values.
class Qwen3ASRProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        # Defaults for the tokenizer call.
        "text_kwargs": {
            "padding": False,
            "padding_side": "left",
        },
        # Defaults for the feature-extractor call. ``__call__`` sets
        # "padding" to True again (and "truncation" to False) whenever audio
        # is supplied.
        "audio_kwargs": {
            "sampling_rate": 16000,
            "padding": True,
            "return_attention_mask": True,
        },
    }
def
_get_feat_extract_output_lengths
(
input_lengths
):
"""
Computes the output length of the convolutional layers and the output
length of the audio encoder.
"""
input_lengths_leave
=
input_lengths
%
100
feat_lengths
=
(
input_lengths_leave
-
1
)
//
2
+
1
output_lengths
=
((
feat_lengths
-
1
)
//
2
+
1
-
1
)
//
2
+
1
+
(
input_lengths
//
100
)
*
13
return
output_lengths
class Qwen3ASRProcessor(ProcessorMixin):
    """
    Composite processor for Qwen3-ASR:
    - WhisperFeatureExtractor for audio
    - Qwen2 tokenizer family for text
    """

    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "WhisperFeatureExtractor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, feature_extractor=None, tokenizer=None, chat_template=None):
        """Wire up the sub-processors and cache the tokenizer's audio tokens.

        The tokenizer must expose ``audio_token``, ``audio_bos_token`` and
        ``audio_eos_token`` attributes.
        """
        super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
        self.audio_token = self.tokenizer.audio_token
        self.audio_bos_token = self.tokenizer.audio_bos_token
        self.audio_eos_token = self.tokenizer.audio_eos_token

    def __call__(
        self,
        text: TextInput = None,
        audio: AudioInput = None,
        **kwargs,
    ) -> BatchFeature:
        """Tokenize ``text`` and, if given, extract features from ``audio``.

        Each occurrence of the audio token in ``text`` is expanded to as many
        audio tokens as the corresponding audio input produces after the
        encoder's downsampling (see ``_get_feat_extract_output_lengths``).

        Args:
            text: A single prompt or a list of prompts. Required.
            audio: Optional audio input(s) passed to the feature extractor.
            **kwargs: Merged with ``Qwen3ASRProcessorKwargs`` defaults;
                ``return_tensors`` also selects the output tensor type.

        Returns:
            A ``BatchFeature`` combining the tokenizer outputs with the
            feature-extractor outputs, where the extractor's
            ``attention_mask`` is renamed to ``feature_attention_mask``.

        Raises:
            ValueError: If ``text`` is ``None``, or if the text contains more
                audio tokens than audio inputs were supplied.
        """
        if text is None:
            raise ValueError("You need to specify either a `text` input to process.")

        output_kwargs = self._merge_kwargs(
            Qwen3ASRProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if audio is not None:
            # Padding is always on and truncation always off for audio,
            # regardless of what the caller passed.
            output_kwargs["audio_kwargs"]["padding"] = True
            output_kwargs["audio_kwargs"]["truncation"] = False
            audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
            # Rename so it does not collide with the tokenizer's mask.
            audio_inputs["feature_attention_mask"] = audio_inputs.pop("attention_mask")
            audio_lengths = iter(
                _get_feat_extract_output_lengths(
                    audio_inputs["feature_attention_mask"].sum(-1)
                )
            )
        else:
            audio_inputs = {}
            audio_lengths = iter([])

        if not isinstance(text, list):
            text = [text]

        text = self.replace_multimodal_special_tokens(text, audio_lengths)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

        return BatchFeature(
            data={**text_inputs, **audio_inputs},
            tensor_type=kwargs.get("return_tensors"),
        )

    def replace_multimodal_special_tokens(
        self,
        text: list[str],
        audio_lengths: Iterable[int],
    ) -> list[str]:
        """Expand every audio token into a run of audio tokens.

        Occurrences are processed left to right across all samples; the
        k-th occurrence overall is replaced by ``int(length_k)`` copies of
        the audio token, where ``length_k`` is the k-th item drawn from
        ``audio_lengths``.

        Args:
            text: Prompt strings to rewrite.
            audio_lengths: Token counts, one per audio-token occurrence. Any
                iterable is accepted; an iterator is consumed in place.

        Returns:
            The rewritten prompts, in the same order as ``text``.

        Raises:
            ValueError: If ``audio_lengths`` runs out before every audio
                token has been expanded.
        """
        # ``iter`` is a no-op on iterators, so callers that pass one still
        # see it advance; lists and other iterables now work as well.
        lengths = iter(audio_lengths)
        processed_text: list[str] = []
        for sample in text:
            # Splitting on the token yields the literal text around each
            # occurrence, so the rest of the prompt is never rewritten.
            pieces = sample.split(self.audio_token)
            expanded = [pieces[0]]
            for piece in pieces[1:]:
                try:
                    num_tokens = int(next(lengths))
                except StopIteration:
                    # Do not leak StopIteration: it is easily mistaken for
                    # normal iterator exhaustion further up the stack.
                    raise ValueError(
                        f"Found more {self.audio_token!r} tokens in the text "
                        "than audio inputs were provided."
                    ) from None
                expanded.append(self.audio_token * num_tokens)
                expanded.append(piece)
            processed_text.append("".join(expanded))
        return processed_text

    def get_chunked_index(
        self,
        token_indices: np.ndarray,
        tokens_per_chunk: int,
    ) -> list[tuple[int, int]]:
        """Split positions of ``token_indices`` into consecutive ranges.

        Walks ``token_indices`` once. Whenever the value at position ``i`` is
        at least ``current_chunk * tokens_per_chunk``, the range
        ``(start, i)`` is emitted, a new range starts at ``i`` and
        ``current_chunk`` advances by one (at most one advance per position).
        A final range up to ``len(token_indices)`` is always emitted, so the
        result is never empty.

        Returns:
            ``(start, end)`` pairs of half-open positions into
            ``token_indices``.
        """

        def _iter():
            start_idx = 0
            current_chunk = 1
            for i, token_index in enumerate(token_indices):
                if token_index >= current_chunk * tokens_per_chunk:
                    yield (start_idx, i)
                    start_idx = i
                    current_chunk += 1
            yield (start_idx, len(token_indices))

        return list(_iter())

    @property
    def model_input_names(self):
        """Input names of both sub-processors plus ``feature_attention_mask``.

        Duplicates are dropped while keeping first-seen order.
        """
        tokenizer_input_names = self.tokenizer.model_input_names
        feature_extractor_input_names = self.feature_extractor.model_input_names
        return list(
            dict.fromkeys(
                tokenizer_input_names
                + feature_extractor_input_names
                + ["feature_attention_mask"]
            )
        )
# Import-time side effect: register the processor class with transformers'
# AutoProcessor under the string key "Qwen3ASRProcessor".
AutoProcessor.register("Qwen3ASRProcessor", Qwen3ASRProcessor)

# Names exported by a star-import of this module.
__all__ = ["Qwen3ASRProcessor"]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment