Unverified Commit 8d0f908b authored by Peter Nguyen's avatar Peter Nguyen Committed by GitHub
Browse files

[Model] Implement LoRA support for Qwen3ASRForConditionalGeneration (#37247)


Signed-off-by: default avatarPeter Nguyen <petern0408@gmail.com>
parent c9dddc14
...@@ -666,7 +666,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition. ...@@ -666,7 +666,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | | `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
| `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ | | `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | `ibm-granite/granite-4.0-1b-speech`, `ibm-granite/granite-speech-3.3-2b`, etc. | ✅︎ | ✅︎ | | `GraniteSpeechForConditionalGeneration` | Granite Speech | `ibm-granite/granite-4.0-1b-speech`, `ibm-granite/granite-speech-3.3-2b`, etc. | ✅︎ | ✅︎ |
| `Qwen3ASRForConditionalGeneration` | Qwen3-ASR | `Qwen/Qwen3-ASR-1.7B`, etc. | | ✅︎ | | `Qwen3ASRForConditionalGeneration` | Qwen3-ASR | `Qwen/Qwen3-ASR-1.7B`, etc. | ✅︎ | ✅︎ |
| `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, etc. | | ✅︎ | | `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, etc. | | ✅︎ |
| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ | | `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ |
| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | | `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | |
......
...@@ -37,6 +37,7 @@ from vllm.inputs import ModalityData, MultiModalDataDict, PromptType, TokensProm ...@@ -37,6 +37,7 @@ from vllm.inputs import ModalityData, MultiModalDataDict, PromptType, TokensProm
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import ( from vllm.model_executor.models.interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsLoRA,
SupportsMRoPE, SupportsMRoPE,
SupportsMultiModal, SupportsMultiModal,
SupportsPP, SupportsPP,
...@@ -266,7 +267,21 @@ class Qwen3ASRForConditionalGeneration( ...@@ -266,7 +267,21 @@ class Qwen3ASRForConditionalGeneration(
SupportsPP, SupportsPP,
SupportsMRoPE, SupportsMRoPE,
SupportsTranscription, SupportsTranscription,
SupportsLoRA,
): ):
# LoRA support
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
supported_languages = ISO639_1_SUPPORTED_LANGS supported_languages = ISO639_1_SUPPORTED_LANGS
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
...@@ -513,6 +528,17 @@ class Qwen3ASRForConditionalGeneration( ...@@ -513,6 +528,17 @@ class Qwen3ASRForConditionalGeneration(
tower_model=["audio_tower."], tower_model=["audio_tower."],
) )
def get_num_mm_encoder_tokens(self, num_audio_tokens: int) -> int:
"""Return the number of tokens processed by the audio tower encoder.
Required for LoRA support on the tower module.
"""
# For Qwen3-ASR, the audio tower produces one embedding per audio
# placeholder token inserted into the prompt (no additional
# merge/downsample step like vision towers). Therefore, the encoder
# token budget is identity.
return num_audio_tokens
@classmethod @classmethod
def get_speech_to_text_config( def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str cls, model_config: ModelConfig, task_type: str
......
...@@ -57,6 +57,7 @@ from vllm.model_executor.layers.conv import Conv3dLayer ...@@ -57,6 +57,7 @@ from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -357,7 +358,13 @@ class Qwen3OmniMoeAudioEncoder(nn.Module): ...@@ -357,7 +358,13 @@ class Qwen3OmniMoeAudioEncoder(nn.Module):
conv_out_dim = config.downsample_hidden_size * ( conv_out_dim = config.downsample_hidden_size * (
(((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2 (((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2
) )
self.conv_out = nn.Linear(conv_out_dim, config.d_model, bias=False) self.conv_out = ReplicatedLinear(
conv_out_dim,
config.d_model,
bias=False,
return_bias=False,
prefix=f"{prefix}.conv_out",
)
# Transformer encoder layers # Transformer encoder layers
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
...@@ -372,9 +379,21 @@ class Qwen3OmniMoeAudioEncoder(nn.Module): ...@@ -372,9 +379,21 @@ class Qwen3OmniMoeAudioEncoder(nn.Module):
# Output layers # Output layers
self.ln_post = nn.LayerNorm(config.d_model) self.ln_post = nn.LayerNorm(config.d_model)
self.proj1 = nn.Linear(config.d_model, config.d_model) self.proj1 = ReplicatedLinear(
config.d_model,
config.d_model,
bias=True,
return_bias=False,
prefix=f"{prefix}.proj1",
)
self.act = _ACTIVATION_REGISTRY[config.activation_function] self.act = _ACTIVATION_REGISTRY[config.activation_function]
self.proj2 = nn.Linear(config.d_model, config.output_dim) self.proj2 = ReplicatedLinear(
config.d_model,
config.output_dim,
bias=True,
return_bias=False,
prefix=f"{prefix}.proj2",
)
# Get attention backend # Get attention backend
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
......
...@@ -2783,7 +2783,20 @@ class GPUModelRunner( ...@@ -2783,7 +2783,20 @@ class GPUModelRunner(
) )
self.lora_manager.set_active_adapters(lora_requests, tower_mapping) self.lora_manager.set_active_adapters(lora_requests, tower_mapping)
if hasattr(self.model, "get_num_mm_connector_tokens"): # Only set connector mapping if the model actually has a connector.
# Some multimodal models inherit a stub `get_num_mm_connector_tokens`
# from `SupportsMultiModal`, which returns None and should not be
# treated as a signal that connector LoRA is supported.
mm_mapping = (
self.model.get_mm_mapping() # type: ignore[attr-defined]
if hasattr(self.model, "get_mm_mapping")
else None
)
if (
mm_mapping is not None
and mm_mapping.connector
and hasattr(self.model, "get_num_mm_connector_tokens")
):
post_op_counts = [ post_op_counts = [
self.model.get_num_mm_connector_tokens(num_tokens) # type: ignore[attr-defined] self.model.get_num_mm_connector_tokens(num_tokens) # type: ignore[attr-defined]
for num_tokens in encoder_token_counts for num_tokens in encoder_token_counts
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment