Unverified Commit 3b23d57c authored by daje0601's avatar daje0601 Committed by GitHub
Browse files

[Model] Add LoRA support for Whisper models (#29856)


Signed-off-by: default avatardaje0601 <englishmt4118@gmail.com>
Co-authored-by: default avatarClaude Opus 4.5 <noreply@anthropic.com>
parent 2f4226fe
...@@ -289,6 +289,11 @@ def llama32_lora_files(llama32_lora_huggingface_id): ...@@ -289,6 +289,11 @@ def llama32_lora_files(llama32_lora_huggingface_id):
return snapshot_download(repo_id=llama32_lora_huggingface_id) return snapshot_download(repo_id=llama32_lora_huggingface_id)
@pytest.fixture(scope="session")
def whisper_lora_files():
return snapshot_download(repo_id="chengyili2005/whisper-small-mandarin-lora")
@pytest.fixture @pytest.fixture
def reset_default_device(): def reset_default_device():
""" """
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Integration tests for Whisper models with LoRA adapters.
These tests verify that Whisper models can correctly load and use LoRA adapters
for speech-to-text transcription tasks.
"""
import pytest
import vllm
from vllm.assets.audio import AudioAsset
from vllm.lora.request import LoRARequest
from ..utils import create_new_process_for_each_test
# Model configuration
WHISPER_MODEL = "openai/whisper-small"
# Test prompts for Whisper transcription
WHISPER_PROMPT = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
# Note: whisper_lora_files fixture is defined in conftest.py
@pytest.fixture(autouse=True)
def use_spawn_for_whisper(monkeypatch):
"""Whisper has issues with forked workers, use spawn instead."""
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
def create_whisper_llm(enable_lora: bool = True, max_loras: int = 2):
"""Create a Whisper LLM instance with optional LoRA support."""
return vllm.LLM(
model=WHISPER_MODEL,
enable_lora=enable_lora,
max_loras=max_loras if enable_lora else 1,
max_lora_rank=64,
max_model_len=448,
dtype="half",
enforce_eager=True, # For stability in tests
)
def run_whisper_inference(
llm: vllm.LLM,
lora_path: str | None = None,
lora_id: int = 1,
) -> list[str]:
"""Run Whisper inference with optional LoRA adapter."""
# Load test audio
audio_asset = AudioAsset("mary_had_lamb")
audio_data = audio_asset.audio_and_sample_rate
inputs = [
{
"prompt": WHISPER_PROMPT,
"multi_modal_data": {"audio": audio_data},
}
]
sampling_params = vllm.SamplingParams(
temperature=0,
max_tokens=200,
)
# Prepare LoRA request if adapter path is provided
lora_request = None
if lora_path:
lora_request = LoRARequest(
lora_name=f"whisper_lora_{lora_id}",
lora_int_id=lora_id,
lora_path=lora_path,
)
outputs = llm.generate(inputs, sampling_params, lora_request=lora_request)
return [output.outputs[0].text for output in outputs]
@create_new_process_for_each_test()
def test_whisper_lora_inference(whisper_lora_files):
"""Test basic Whisper inference with a LoRA adapter.
This test verifies that:
1. Whisper model can be loaded with LoRA support enabled
2. A LoRA adapter can be applied during inference
3. The model produces valid transcription output
"""
llm = create_whisper_llm(enable_lora=True)
# Run inference with LoRA
outputs = run_whisper_inference(llm, lora_path=whisper_lora_files, lora_id=1)
# Verify we got a non-empty transcription
assert len(outputs) == 1
assert len(outputs[0]) > 0, "Expected non-empty transcription output"
# The output should contain some recognizable words from the audio
# (Mary had a little lamb)
print(f"Transcription output: {outputs[0]}")
@create_new_process_for_each_test()
def test_whisper_multi_lora(whisper_lora_files):
"""Test Whisper with multiple LoRA adapter IDs.
This test verifies that the same LoRA adapter can be loaded with
different IDs and produce consistent results.
"""
llm = create_whisper_llm(enable_lora=True, max_loras=4)
# Test with different LoRA IDs using the same adapter
outputs_lora1 = run_whisper_inference(llm, lora_path=whisper_lora_files, lora_id=1)
outputs_lora2 = run_whisper_inference(llm, lora_path=whisper_lora_files, lora_id=2)
# Both should produce valid outputs
assert len(outputs_lora1[0]) > 0
assert len(outputs_lora2[0]) > 0
# Same adapter with different IDs should produce same output
assert outputs_lora1 == outputs_lora2, (
f"Expected same outputs for same adapter with different IDs. "
f"Got: {outputs_lora1} vs {outputs_lora2}"
)
@create_new_process_for_each_test()
def test_whisper_with_and_without_lora(whisper_lora_files):
"""Test that Whisper produces different outputs with and without LoRA.
This test verifies that the LoRA adapter actually affects the model output.
"""
llm = create_whisper_llm(enable_lora=True)
# Run with LoRA
outputs_with_lora = run_whisper_inference(
llm, lora_path=whisper_lora_files, lora_id=1
)
# Run without LoRA (base model only)
outputs_without_lora = run_whisper_inference(llm, lora_path=None)
# Both should produce valid outputs
assert len(outputs_with_lora[0]) > 0
assert len(outputs_without_lora[0]) > 0
print(f"Output with LoRA: {outputs_with_lora[0]}")
print(f"Output without LoRA: {outputs_without_lora[0]}")
# Note: Outputs may or may not differ depending on the adapter
# The main verification is that both configurations work
...@@ -49,7 +49,18 @@ class WorkerLoRAManager: ...@@ -49,7 +49,18 @@ class WorkerLoRAManager:
# Use get_text_config() in case of multimodal models # Use get_text_config() in case of multimodal models
text_config = vllm_config.model_config.hf_config.get_text_config() text_config = vllm_config.model_config.hf_config.get_text_config()
self.max_position_embeddings = text_config.max_position_embeddings # For encoder-decoder models (e.g., Whisper), use max_target_positions
# instead of max_position_embeddings
# TODO: Generalize max_position_embeddings handling for
# out-of-tree (OOT) encoder-decoder models
if vllm_config.model_config.is_encoder_decoder:
self.max_position_embeddings = getattr(
text_config, "max_target_positions", None
)
else:
self.max_position_embeddings = getattr(
text_config, "max_position_embeddings", None
)
self.device = device self.device = device
# Lazily initialized by create_lora_manager. # Lazily initialized by create_lora_manager.
self._adapter_manager: LoRAModelManager self._adapter_manager: LoRAModelManager
......
...@@ -31,6 +31,7 @@ from vllm.model_executor.layers.attention import ( ...@@ -31,6 +31,7 @@ from vllm.model_executor.layers.attention import (
) )
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear, RowParallelLinear,
) )
...@@ -66,6 +67,7 @@ from vllm.v1.attention.backend import ( ...@@ -66,6 +67,7 @@ from vllm.v1.attention.backend import (
from .interfaces import ( from .interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal, SupportsMultiModal,
SupportsTranscription, SupportsTranscription,
) )
...@@ -279,11 +281,12 @@ class WhisperCrossAttention(WhisperAttention): ...@@ -279,11 +281,12 @@ class WhisperCrossAttention(WhisperAttention):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.q_proj", prefix=f"{prefix}.q_proj",
) )
self.kv_proj = QKVParallelLinear( # Use MergedColumnParallelLinear for K and V projections.
hidden_size=embed_dim, # This enables LoRA support via MergedColumnParallelLinearWithLoRA
head_size=self.head_dim, # which handles 2-slice configurations.
total_num_heads=0, self.kv_proj = MergedColumnParallelLinear(
total_num_kv_heads=self.total_num_heads, input_size=embed_dim,
output_sizes=[embed_dim, embed_dim],
bias=bias, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.kv_proj", prefix=f"{prefix}.kv_proj",
...@@ -615,8 +618,9 @@ class WhisperModel(nn.Module): ...@@ -615,8 +618,9 @@ class WhisperModel(nn.Module):
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"), (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"), (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"), (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
(".encoder_attn.kv_proj", ".encoder_attn.k_proj", "k"), # MergedColumnParallelLinear uses integer indices (0, 1)
(".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"), (".encoder_attn.kv_proj", ".encoder_attn.k_proj", 0),
(".encoder_attn.kv_proj", ".encoder_attn.v_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set() loaded_params: set[str] = set()
...@@ -790,14 +794,12 @@ class WhisperForConditionalGeneration( ...@@ -790,14 +794,12 @@ class WhisperForConditionalGeneration(
nn.Module, nn.Module,
SupportsTranscription, SupportsTranscription,
SupportsMultiModal, SupportsMultiModal,
SupportsLoRA,
): ):
# LoRA-specific attributes
packed_modules_mapping = { packed_modules_mapping = {
"self_attn.qkv_proj": [ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
"self_attn.q_proj", "kv_proj": ["k_proj", "v_proj"],
"self_attn.k_proj",
"self_attn.v_proj",
],
"encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
} }
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment