Unverified Commit 8b3f0a99 authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[Models] Qwen3-ASR (#33312)


Signed-off-by: default avatarRoger Wang <hey@rogerw.io>
parent 8311f083
...@@ -720,6 +720,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen ...@@ -720,6 +720,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-4B-Instruct`, etc. | ✅︎ | ✅︎ | | `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-4B-Instruct`, etc. | ✅︎ | ✅︎ |
| `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | | `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. | ✅︎ | ✅︎ |
| `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, `Qwen/Qwen3-Omni-30B-A3B-Thinking` | ✅︎ | ✅︎ | | `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, `Qwen/Qwen3-Omni-30B-A3B-Thinking` | ✅︎ | ✅︎ |
| `Qwen3ASRForConditionalGeneration` | Qwen3-ASR | T + A<sup>+</sup> | `Qwen/Qwen3-ASR-1.7B` | ✅︎ | ✅︎ |
| `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | | `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ |
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ |
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | |
...@@ -769,6 +770,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition. ...@@ -769,6 +770,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | | `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
| `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ | | `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | `ibm-granite/granite-speech-3.3-2b`, `ibm-granite/granite-speech-3.3-8b`, etc. | ✅︎ | ✅︎ | | `GraniteSpeechForConditionalGeneration` | Granite Speech | `ibm-granite/granite-speech-3.3-2b`, `ibm-granite/granite-speech-3.3-8b`, etc. | ✅︎ | ✅︎ |
| `Qwen3ASRForConditionalGeneration` | Qwen3-ASR | `Qwen/Qwen3-ASR-1.7B`, etc. | | ✅︎ |
| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ | | `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ |
| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | | `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | |
......
...@@ -330,6 +330,25 @@ def run_qwen2_5_omni(question: str, audio_count: int): ...@@ -330,6 +330,25 @@ def run_qwen2_5_omni(question: str, audio_count: int):
) )
def run_qwen3_asr(question: str, audio_count: int) -> ModelRequestData:
model_name = "Qwen/Qwen3-Asr-1.7B"
audio_in_prompt = "<|audio_start|><|audio_pad|><|audio_end|>\n" * audio_count
prompt = f"<|im_start|>user\n{audio_in_prompt}<|im_end|>\n<|im_start|>assistant\n"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Ultravox 0.5-1B # Ultravox 0.5-1B
def run_ultravox(question: str, audio_count: int) -> ModelRequestData: def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b" model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
...@@ -442,6 +461,7 @@ model_example_map = { ...@@ -442,6 +461,7 @@ model_example_map = {
"phi4_mm": run_phi4mm, "phi4_mm": run_phi4mm,
"qwen2_audio": run_qwen2_audio, "qwen2_audio": run_qwen2_audio,
"qwen2_5_omni": run_qwen2_5_omni, "qwen2_5_omni": run_qwen2_5_omni,
"qwen3_asr": run_qwen3_asr,
"ultravox": run_ultravox, "ultravox": run_ultravox,
"voxtral": run_voxtral, "voxtral": run_voxtral,
"whisper": run_whisper, "whisper": run_whisper,
......
...@@ -944,6 +944,12 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -944,6 +944,12 @@ _MULTIMODAL_EXAMPLE_MODELS = {
max_model_len=4096, max_model_len=4096,
min_transformers_version="4.57", min_transformers_version="4.57",
), ),
"Qwen3ASRForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen3-ASR-1.7B",
max_model_len=4096,
min_transformers_version="4.57",
is_available_online=False,
),
"RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", trust_remote_code=True), "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", trust_remote_code=True),
"SkyworkR1VChatModel": _HfExamplesInfo( "SkyworkR1VChatModel": _HfExamplesInfo(
"Skywork/Skywork-R1V-38B", trust_remote_code=True "Skywork/Skywork-R1V-38B", trust_remote_code=True
......
This diff is collapsed.
...@@ -436,6 +436,10 @@ _MULTIMODAL_MODELS = { ...@@ -436,6 +436,10 @@ _MULTIMODAL_MODELS = {
"qwen3_omni_moe_thinker", "qwen3_omni_moe_thinker",
"Qwen3OmniMoeThinkerForConditionalGeneration", "Qwen3OmniMoeThinkerForConditionalGeneration",
), ),
"Qwen3ASRForConditionalGeneration": (
"qwen3_asr",
"Qwen3ASRForConditionalGeneration",
),
"Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"), # noqa: E501 "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"), # noqa: E501
"Qwen3VLMoeForConditionalGeneration": ( "Qwen3VLMoeForConditionalGeneration": (
"qwen3_vl_moe", "qwen3_vl_moe",
......
...@@ -97,6 +97,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( ...@@ -97,6 +97,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
ultravox="UltravoxConfig", ultravox="UltravoxConfig",
step3_vl="Step3VLConfig", step3_vl="Step3VLConfig",
step3_text="Step3TextConfig", step3_text="Step3TextConfig",
qwen3_asr="Qwen3ASRConfig",
qwen3_next="Qwen3NextConfig", qwen3_next="Qwen3NextConfig",
lfm2_moe="Lfm2MoeConfig", lfm2_moe="Lfm2MoeConfig",
tarsier2="Tarsier2Config", tarsier2="Tarsier2Config",
......
...@@ -52,6 +52,7 @@ _CLASS_TO_MODULE: dict[str, str] = { ...@@ -52,6 +52,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
"Step3VLConfig": "vllm.transformers_utils.configs.step3_vl", "Step3VLConfig": "vllm.transformers_utils.configs.step3_vl",
"Step3VisionEncoderConfig": "vllm.transformers_utils.configs.step3_vl", "Step3VisionEncoderConfig": "vllm.transformers_utils.configs.step3_vl",
"Step3TextConfig": "vllm.transformers_utils.configs.step3_vl", "Step3TextConfig": "vllm.transformers_utils.configs.step3_vl",
"Qwen3ASRConfig": "vllm.transformers_utils.configs.qwen3_asr",
"Qwen3NextConfig": "vllm.transformers_utils.configs.qwen3_next", "Qwen3NextConfig": "vllm.transformers_utils.configs.qwen3_next",
"Tarsier2Config": "vllm.transformers_utils.configs.tarsier2", "Tarsier2Config": "vllm.transformers_utils.configs.tarsier2",
# Special case: DeepseekV3Config is from HuggingFace Transformers # Special case: DeepseekV3Config is from HuggingFace Transformers
...@@ -94,6 +95,7 @@ __all__ = [ ...@@ -94,6 +95,7 @@ __all__ = [
"Step3VLConfig", "Step3VLConfig",
"Step3VisionEncoderConfig", "Step3VisionEncoderConfig",
"Step3TextConfig", "Step3TextConfig",
"Qwen3ASRConfig",
"Qwen3NextConfig", "Qwen3NextConfig",
"Tarsier2Config", "Tarsier2Config",
] ]
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
# mypy: ignore-errors
# coding=utf-8
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import regex as re
import numpy as np
from transformers import AutoProcessor
from transformers.audio_utils import AudioInput
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
from transformers.tokenization_utils_base import TextInput
class Qwen3ASRProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"padding_side": "left",
},
"audio_kwargs": {
"sampling_rate": 16000,
"padding": True,
"return_attention_mask": True,
},
}
def _get_feat_extract_output_lengths(input_lengths):
"""
Computes the output length of the convolutional layers and the output length of the audio encoder
"""
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = (
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
)
return output_lengths
class Qwen3ASRProcessor(ProcessorMixin):
r"""
Constructs a Qwen3ASR processor.
[`Qwen3ASRProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`], and [`Qwen2TokenizerFast`]. See the
[`~Qwen3ASRProcessor.__call__`] and [`~Qwen3ASRProcessor.decode`] for more information.
Args:
feature_extractor ([`WhisperFeatureExtractor`], *optional*):
The audio feature extractor.
tokenizer ([`Qwen2TokenizerFast`], *optional*):
The text tokenizer.
chat_template (`Optional[str]`, *optional*):
The Jinja template to use for formatting the conversation. If not provided, the default chat template is used.
"""
attributes = ["feature_extractor", "tokenizer"]
feature_extractor_class = "WhisperFeatureExtractor"
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
def __init__(self, feature_extractor=None, tokenizer=None, chat_template=None):
super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
self.audio_token = self.tokenizer.audio_token
self.audio_bos_token = self.tokenizer.audio_bos_token
self.audio_eos_token = self.tokenizer.audio_eos_token
def __call__(
self,
text: TextInput = None,
audio: AudioInput = None,
**kwargs,
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text`
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the audio(s), this method forwards the `audio` and `kwargs` arguments to
WhisperFeatureExtractor's [`~WhisperFeatureExtractor.__call__`] if `audio` is not `None`. Please refer to the doctsring
of the above two methods for more information.
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
audio (`np.ndarray`, `List[np.ndarray]`):
The audio or batch of audio to be prepared. Each audio can be a NumPy array.
"""
if text is None:
raise ValueError("You need to specify either a `text` input to process.")
output_kwargs = self._merge_kwargs(
Qwen3ASRProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if audio is not None:
output_kwargs["audio_kwargs"]["padding"] = True
output_kwargs["audio_kwargs"]["truncation"] = False
audio_inputs = self.feature_extractor(
audio, **output_kwargs["audio_kwargs"]
)
audio_inputs["feature_attention_mask"] = audio_inputs.pop(
"attention_mask"
) # rename feature_attention_mask to prevent conflicts later on
audio_inputs["input_features"] = audio_inputs.pop(
"input_features"
) # rename input_features to prevent conflicts later on
audio_lengths = iter(
_get_feat_extract_output_lengths(
audio_inputs["feature_attention_mask"].sum(-1)
)
)
else:
audio_inputs = {}
audio_lengths = iter([])
if not isinstance(text, list):
text = [text]
text = self.replace_multimodal_special_tokens(
text,
audio_lengths,
)
texts_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
return BatchFeature(
data={**texts_inputs, **audio_inputs},
tensor_type=kwargs.get("return_tensors"),
)
def replace_multimodal_special_tokens(
self,
text,
audio_lengths,
):
processed_text = []
for sample in text:
positions = []
special_tokens = [re.escape(tok) for tok in [self.audio_token]]
pattern = "|".join(special_tokens)
positions = sorted(
[
(match.start(), match.group())
for match in re.finditer(pattern, sample)
]
)
positions.sort(key=lambda x: x[0])
for _, special_token in positions:
if special_token == self.audio_token:
sample = sample.replace(
self.audio_token,
"<|audio_placeholder|>" * next(audio_lengths),
1,
)
sample = sample.replace("<|audio_placeholder|>", self.audio_token)
processed_text.append(sample)
return processed_text
def get_chunked_index(
self, token_indices: np.ndarray, tokens_per_chunk: int
) -> list[tuple[int, int]]:
"""
Splits token index list into chunks based on token value ranges.
Given a list of token indices, returns a list of (start, end) index tuples representing
slices of the list where the token values fall within successive ranges of `tokens_per_chunk`.
For example, if `tokens_per_chunk` is 1000, the function will create chunks such that:
- the first chunk contains token values < 1000,
- the second chunk contains values >= 1000 and < 2000, and so on.
Parameters:
token_indices (`np.ndarray`): A monotonically increasing list of token index values.
tokens_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
Returns:
`list[tuple[int, int]]`: A list of tuples, each representing the start (inclusive)
and end (exclusive) indices of a chunk in `token_indices`.
"""
def _iter():
i, start_idx = 0, 0 # skip bos token
current_chunk = 1
while i < len(token_indices): # skip eos token
if token_indices[i] >= current_chunk * tokens_per_chunk:
yield (start_idx, i)
start_idx = i
current_chunk += 1
i += 1
yield (start_idx, len(token_indices))
return list(_iter())
def apply_chat_template(self, conversations, chat_template=None, **kwargs):
return super().apply_chat_template(conversations, chat_template, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
feature_extractor_input_names = self.feature_extractor.model_input_names
return list(
dict.fromkeys(
tokenizer_input_names
+ feature_extractor_input_names
+ ["feature_attention_mask"]
)
)
AutoProcessor.register("Qwen3ASRProcessor", Qwen3ASRProcessor)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment