"vscode:/vscode.git/clone" did not exist on "dff0a2b39475096f5456721bfc8df3c7fea3cc57"
Unverified Commit 37e8182b authored by Russell Bryant's avatar Russell Bryant Committed by GitHub
Browse files

[v1] Add Whisper model support (encoder-decoder) (#21088)


Signed-off-by: default avatarRussell Bryant <rbryant@redhat.com>
Co-authored-by: default avatarNickLucche <nlucches@redhat.com>
parent 4db44264
......@@ -321,7 +321,6 @@ steps:
- python3 offline_inference/vision_language_pooling.py --seed 0
- python3 offline_inference/vision_language_multi_image.py --seed 0
- VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference/encoder_decoder.py
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
- python3 offline_inference/basic/classify.py
- python3 offline_inference/basic/embed.py
......@@ -644,7 +643,7 @@ steps:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pip freeze | grep -E 'torch'
- pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
- cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
- cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
- label: Multi-Modal Models Test (Extended) 1
mirror_hardwares: [amdexperimental]
......@@ -818,7 +817,8 @@ steps:
# Avoid importing model tests that cause CUDA reinitialization error
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/language -v -s -m 'distributed(num_gpus=2)'
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)'
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
- VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)'
# test sequence parallel
- pytest -v -s distributed/test_sequence_parallel.py
# this test fails consistently.
......
......@@ -5,6 +5,8 @@ Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART and mBART.
This script is refactored to allow model selection via command-line arguments.
NOTE: This example is not yet supported in V1.
"""
import argparse
......
......@@ -5,6 +5,7 @@ This example shows how to use vLLM for running offline inference with
the explicit/implicit prompt format on enc-dec LMMs for text generation.
"""
import os
import time
from collections.abc import Sequence
from dataclasses import asdict
......@@ -130,6 +131,8 @@ def run_mllama():
def run_whisper():
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
engine_args = EngineArgs(
model="openai/whisper-large-v3-turbo",
max_model_len=448,
......
......@@ -63,6 +63,7 @@ def clear_cache():
current_platform.is_cpu(),
reason="CPU backend is not currently supported with encoder/decoder models"
)
@pytest.mark.skip(reason="bart not supported in V1")
def test_encoder_decoder_e2e(
hf_runner,
vllm_runner,
......
......@@ -30,6 +30,7 @@ async def client(server):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="bart is not yet supported in V1")
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
completion = await client.completions.create(model=model_name,
prompt="Hello, my name is",
......
......@@ -178,6 +178,7 @@ def run_test(
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
@pytest.mark.skip(reason="bart not supported in V1")
def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None:
......@@ -201,6 +202,7 @@ def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM])
@pytest.mark.skip(reason="bart not supported in V1")
def test_models_distributed(hf_runner, vllm_runner,
example_encoder_decoder_prompts,
distributed_executor_backend, model, dtype,
......
......@@ -122,8 +122,7 @@ def run_test(
@pytest.mark.core_model
@pytest.mark.parametrize(
"model", ["openai/whisper-small", "openai/whisper-large-v3-turbo"])
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
@create_new_process_for_each_test()
def test_models(vllm_runner, model) -> None:
run_test(
......
......@@ -31,6 +31,7 @@ from ...utils import dummy_hf_overrides
ARCH_TO_SKIP = {
"MolmoForCausalLM": "incompatible requirements",
"Florence2ForConditionalGeneration": "not supported in V1",
}
ARCH_NEEDS_EXTRAS = [
"InternVLChatModel",
......
......@@ -68,6 +68,12 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
# has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
# L4 supports FA3.
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
if model_arch == "Florence2ForConditionalGeneration":
# An encoder-decoder model that's V0-only. Just skip it
# since V0 is about to be removed.
pytest.skip("Skipping Florence2ForConditionalGeneration")
if model_arch == "WhisperForConditionalGeneration":
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
LLM(
model_info.default,
tokenizer=model_info.tokenizer,
......
......@@ -10,7 +10,6 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
UNSUPPORTED_MODELS_V1 = [
"openai/whisper-large-v3", # transcription
"facebook/bart-large-cnn", # encoder decoder
]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy
from typing import Optional
import numpy as np
import torch
from transformers import CacheConfig
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
subclass_attention_backend)
from vllm.v1.kv_cache_interface import CrossAttentionSpec
logger = init_logger(__name__)
def _get_max_encoder_len(vllm_config: VllmConfig) -> int:
return MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(
vllm_config.model_config)
def _get_cross_slot_mapping(encoder_seq_lens: np.ndarray,
block_table_tensor: torch.Tensor,
kv_cache_spec: CrossAttentionSpec,
device: torch.device) -> torch.Tensor:
"""Get cross-attention slot mappings."""
block_size = kv_cache_spec.block_size
slot_mappings = []
# Find indices with non-zero encoder sequence lengths
# The majority of parallel requests will be running the
# decoder, so this list should be relatively small.
active_indices = np.nonzero(encoder_seq_lens)[0]
for req_index in active_indices:
encoder_seq_len = encoder_seq_lens[req_index].item()
# Calculate the number of blocks needed for this request
num_blocks_needed = cdiv(encoder_seq_len, block_size)
# Get the block IDs for this request from the tensor
req_block_ids = block_table_tensor[req_index]
# Get only the blocks we need (first num_blocks_needed blocks)
needed_block_ids = req_block_ids[:num_blocks_needed]
# All needed blocks are allocated
i_values = torch.arange(encoder_seq_len,
dtype=torch.int64,
device=device)
block_indices = i_values // block_size
block_offsets = i_values % block_size
block_numbers = needed_block_ids[block_indices]
slot_mapping = block_numbers * block_size + block_offsets
slot_mappings.append(slot_mapping)
if slot_mappings:
return torch.cat(slot_mappings)
else:
return torch.empty(0, dtype=torch.int64, device=device)
@functools.lru_cache
def create_cross_attention_backend(
underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
prefix = "CrossAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls()
class CrossAttentionBuilder(underlying_builder): # type: ignore
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> AttentionMetadata:
new_metadata = copy(common_attn_metadata)
new_metadata.causal = False
max_encoder_len = _get_max_encoder_len(self.vllm_config)
new_metadata.max_seq_len = max_encoder_len
new_metadata.seq_lens = torch.full(
(new_metadata.num_reqs, ),
max_encoder_len,
dtype=torch.int32,
device=self.device,
)
new_metadata.seq_lens_cpu = torch.full(
(new_metadata.num_reqs, ),
max_encoder_len,
dtype=torch.int32,
device="cpu",
)
new_metadata.slot_mapping = _get_cross_slot_mapping(
new_metadata.encoder_seq_lens, new_metadata.block_table_tensor,
self.kv_cache_spec, self.device)
return super().build(common_prefix_len, new_metadata, fast_build)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=CrossAttentionBuilder)
return attn_backend
class CrossAttention(Attention):
"""
Cross-attention for encoder-decoder models.
Handles attention between decoder queries and encoder keys/values.
"""
def __init__(self,
num_heads: int,
head_size: int,
scale: float,
cache_config: Optional[CacheConfig] = None,
attn_type: Optional[str] = None,
**kwargs):
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if envs.VLLM_USE_V1:
underlying_attn_backend = get_attn_backend(head_size, dtype,
kv_cache_dtype,
block_size)
attn_backend = create_cross_attention_backend(
underlying_attn_backend)
else:
# in v0 cross attention is handled inside the backends
attn_backend = None
if attn_type is not None:
assert attn_type == AttentionType.ENCODER_DECODER, (
"CrossAttention only supports AttentionType.ENCODER_DECODER")
super().__init__(num_heads=num_heads,
head_size=head_size,
scale=scale,
cache_config=cache_config,
attn_backend=attn_backend,
attn_type=AttentionType.ENCODER_DECODER,
**kwargs)
......@@ -8,6 +8,7 @@ import enum
import hashlib
import inspect
import json
import os
import textwrap
import warnings
from collections.abc import Mapping
......@@ -41,6 +42,7 @@ from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
from vllm.config.utils import ConfigType, config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
......@@ -3509,16 +3511,33 @@ class VllmConfig:
disable_chunked_prefill_reasons: list[str] = []
if self.model_config and self.model_config.pooler_config:
pooling_type = self.model_config.pooler_config.pooling_type
if pooling_type is None or pooling_type.lower() != "last":
disable_chunked_prefill_reasons.append(
"Only \"last\" pooling supports chunked "
"prefill and prefix caching; disabling both.")
elif not getattr(self.model_config.hf_config, "is_causal", True):
if self.model_config:
if self.model_config.pooler_config:
pooling_type = self.model_config.pooler_config.pooling_type
if pooling_type is None or pooling_type.lower() != "last":
disable_chunked_prefill_reasons.append(
"Only \"last\" pooling supports chunked "
"prefill and prefix caching; disabling both.")
elif self.model_config.is_encoder_decoder:
self.scheduler_config.max_num_encoder_input_tokens = \
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
logger.debug(
"Encoder-decoder model detected: setting "
"`max_num_encoder_input_tokens` to encoder length (%s)",
self.scheduler_config.max_num_encoder_input_tokens)
self.scheduler_config.disable_chunked_mm_input = True
disable_chunked_prefill_reasons.append(
"Only models using causal attention supports chunked "
"prefill and prefix caching; disabling both.")
"Encoder-decoder models do not support chunked prefill nor"
" prefix caching; disabling both.")
if (self.model_config.architecture
== "WhisperForConditionalGeneration"
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD")
!= "spawn"):
logger.warning(
"Whisper is known to have issues with "
"forked workers. If startup is hanging, "
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'.")
if disable_chunked_prefill_reasons:
for reason in disable_chunked_prefill_reasons:
......
......@@ -600,7 +600,6 @@ class VoxtralEncoderModel(nn.Module):
self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "whisper_encoder"),
is_standalone_encoder=True,
init_in_fp32=True)
mel_filters = mel_filter_bank(
num_frequency_bins=1 + self.config.window_size // 2,
......
......@@ -15,6 +15,7 @@ from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention import Attention, AttentionType
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.cross_attention import CrossAttention
from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
VllmConfig)
from vllm.distributed import get_tensor_model_parallel_world_size
......@@ -43,7 +44,7 @@ from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
SupportsTranscription, SupportsV0Only)
SupportsTranscription)
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
make_layers)
......@@ -124,6 +125,34 @@ class WhisperAudioInputs(TensorSchema):
TensorShape("b", "nmb", "t")]
class WhisperEncoderAttention(MultiHeadAttention):
"""Multi-headed attention for Whisper encoder with 2D tensor support."""
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
"""
Input shape: batch_size x seq_len x hidden_size
or seq_len x hidden_size
"""
is_2d = query.dim() == 2
if is_2d:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
# Call the parent forward method
out = super().forward(query, key, value)
if is_2d:
out = out.squeeze(0)
return out
class WhisperPositionalEmbedding(nn.Embedding):
def __init__(self, num_positions: int, embedding_dim: int):
......@@ -144,7 +173,6 @@ class WhisperAttention(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
standalone_encoder: bool = False,
):
super().__init__()
self.embed_dim = embed_dim
......@@ -180,14 +208,25 @@ class WhisperAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
if standalone_encoder:
self.attn = MultiHeadAttention(
if attn_type == AttentionType.ENCODER:
self.attn = WhisperEncoderAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
)
else:
elif self.attn_type == AttentionType.ENCODER_DECODER:
self.attn = CrossAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=self.attn_type,
)
else: # AttentionType.DECODER (regular decoder self-attention)
self.attn = Attention(
self.num_heads,
self.head_dim,
......@@ -332,11 +371,7 @@ class WhisperMLP(nn.Module):
class WhisperEncoderLayer(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
is_standalone_encoder: bool = False):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
......@@ -350,7 +385,6 @@ class WhisperEncoderLayer(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
standalone_encoder=is_standalone_encoder,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.mlp = WhisperMLP(
......@@ -446,12 +480,10 @@ class WhisperEncoder(nn.Module):
*,
vllm_config: VllmConfig,
prefix: str = "",
is_standalone_encoder: bool = False,
init_in_fp32: bool = False):
super().__init__()
config = vllm_config.model_config.hf_config
embed_dim = config.d_model
self.is_standalone_encoder = is_standalone_encoder
self.num_mel_bins = config.num_mel_bins
self.max_source_positions = config.max_source_positions
self.embed_scale = (math.sqrt(embed_dim)
......@@ -469,9 +501,7 @@ class WhisperEncoder(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.encoder_layers,
lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config,
prefix=f"{prefix}.layers",
is_standalone_encoder=
is_standalone_encoder),
prefix=f"{prefix}.layers"),
prefix=f"{prefix}.layers",
)
self.layer_norm = nn.LayerNorm(config.d_model)
......@@ -752,7 +782,7 @@ class WhisperMultiModalProcessor(
info=WhisperProcessingInfo,
dummy_inputs=WhisperDummyInputsBuilder)
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
SupportsMultiModal, SupportsV0Only):
SupportsMultiModal):
packed_modules_mapping = {
"self_attn.qkv_proj": [
"self_attn.q_proj",
......@@ -880,19 +910,17 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
# TODO: This method does not obey the interface for SupportsMultiModal.
# Refactor this once encoder/decoder support is implemented in V1.
# Required as part of SupportsMultiModal interface.
audio_input = self._parse_and_validate_audio_input(**kwargs)
return self.model.get_encoder_outputs(audio_input["input_features"])
return [self.model.get_encoder_outputs(audio_input["input_features"])]
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
# TODO: This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens. Refactor this once
# encoder/decoder support is implemented in V1.
# This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens.
return self.model.decoder.get_input_embeddings(input_ids)
def _parse_and_validate_audio_input(
......
......@@ -157,6 +157,7 @@ def _remap_mistral_audio_args(config: dict) -> dict:
encoder_attention_heads=encoder_args["n_heads"],
vocab_size=encoder_args["vocab_size"],
max_source_positions=encoder_args["max_source_positions"],
is_encoder_decoder=False, # Override WhisperConfig default
)
}
if quant_config:
......
......@@ -317,8 +317,8 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device) -> None:
self.kv_cache_spec = kv_cache_spec
self.vllm_config = vllm_config
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.scheduler_config = vllm_config.scheduler_config
# For reorder
......
......@@ -177,12 +177,11 @@ class FlashAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.device = device
self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config)
......
......@@ -163,11 +163,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.device = device
self.vllm_config = vllm_config
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.cache_config = vllm_config.cache_config
self.model_config = vllm_config.model_config
self.kv_cache_spec = kv_cache_spec
self._workspace_buffer = None
self._prefill_wrapper = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode (general shape)
......
......@@ -516,10 +516,11 @@ class FlexAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config
self.device = device
self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config)
......
......@@ -39,8 +39,8 @@ class LinearAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert isinstance(kv_cache_spec, MambaSpec)
self.kv_cache_spec = kv_cache_spec
def build(self,
common_prefix_len: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment