Unverified Commit 37e8182b authored by Russell Bryant's avatar Russell Bryant Committed by GitHub
Browse files

[v1] Add Whisper model support (encoder-decoder) (#21088)


Signed-off-by: default avatarRussell Bryant <rbryant@redhat.com>
Co-authored-by: default avatarNickLucche <nlucches@redhat.com>
parent 4db44264
...@@ -321,7 +321,6 @@ steps: ...@@ -321,7 +321,6 @@ steps:
- python3 offline_inference/vision_language_pooling.py --seed 0 - python3 offline_inference/vision_language_pooling.py --seed 0
- python3 offline_inference/vision_language_multi_image.py --seed 0 - python3 offline_inference/vision_language_multi_image.py --seed 0
- VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference/encoder_decoder.py
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
- python3 offline_inference/basic/classify.py - python3 offline_inference/basic/classify.py
- python3 offline_inference/basic/embed.py - python3 offline_inference/basic/embed.py
...@@ -644,7 +643,7 @@ steps: ...@@ -644,7 +643,7 @@ steps:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pip freeze | grep -E 'torch' - pip freeze | grep -E 'torch'
- pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
- cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work - cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
- label: Multi-Modal Models Test (Extended) 1 - label: Multi-Modal Models Test (Extended) 1
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental]
...@@ -818,7 +817,8 @@ steps: ...@@ -818,7 +817,8 @@ steps:
# Avoid importing model tests that cause CUDA reinitialization error # Avoid importing model tests that cause CUDA reinitialization error
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)' - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/language -v -s -m 'distributed(num_gpus=2)' - pytest models/language -v -s -m 'distributed(num_gpus=2)'
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
- VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)'
# test sequence parallel # test sequence parallel
- pytest -v -s distributed/test_sequence_parallel.py - pytest -v -s distributed/test_sequence_parallel.py
# this test fails consistently. # this test fails consistently.
......
...@@ -5,6 +5,8 @@ Demonstrate prompting of text-to-text ...@@ -5,6 +5,8 @@ Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART and mBART. encoder/decoder models, specifically BART and mBART.
This script is refactored to allow model selection via command-line arguments. This script is refactored to allow model selection via command-line arguments.
NOTE: This example is not yet supported in V1.
""" """
import argparse import argparse
......
...@@ -5,6 +5,7 @@ This example shows how to use vLLM for running offline inference with ...@@ -5,6 +5,7 @@ This example shows how to use vLLM for running offline inference with
the explicit/implicit prompt format on enc-dec LMMs for text generation. the explicit/implicit prompt format on enc-dec LMMs for text generation.
""" """
import os
import time import time
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import asdict from dataclasses import asdict
...@@ -130,6 +131,8 @@ def run_mllama(): ...@@ -130,6 +131,8 @@ def run_mllama():
def run_whisper(): def run_whisper():
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
engine_args = EngineArgs( engine_args = EngineArgs(
model="openai/whisper-large-v3-turbo", model="openai/whisper-large-v3-turbo",
max_model_len=448, max_model_len=448,
......
...@@ -63,6 +63,7 @@ def clear_cache(): ...@@ -63,6 +63,7 @@ def clear_cache():
current_platform.is_cpu(), current_platform.is_cpu(),
reason="CPU backend is not currently supported with encoder/decoder models" reason="CPU backend is not currently supported with encoder/decoder models"
) )
@pytest.mark.skip(reason="bart not supported in V1")
def test_encoder_decoder_e2e( def test_encoder_decoder_e2e(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
......
...@@ -30,6 +30,7 @@ async def client(server): ...@@ -30,6 +30,7 @@ async def client(server):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="bart is not yet supported in V1")
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
completion = await client.completions.create(model=model_name, completion = await client.completions.create(model=model_name,
prompt="Hello, my name is", prompt="Hello, my name is",
......
...@@ -178,6 +178,7 @@ def run_test( ...@@ -178,6 +178,7 @@ def run_test(
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
@pytest.mark.skip(reason="bart not supported in V1")
def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None:
...@@ -201,6 +202,7 @@ def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, ...@@ -201,6 +202,7 @@ def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) @pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM])
@pytest.mark.skip(reason="bart not supported in V1")
def test_models_distributed(hf_runner, vllm_runner, def test_models_distributed(hf_runner, vllm_runner,
example_encoder_decoder_prompts, example_encoder_decoder_prompts,
distributed_executor_backend, model, dtype, distributed_executor_backend, model, dtype,
......
...@@ -122,8 +122,7 @@ def run_test( ...@@ -122,8 +122,7 @@ def run_test(
@pytest.mark.core_model @pytest.mark.core_model
@pytest.mark.parametrize( @pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
"model", ["openai/whisper-small", "openai/whisper-large-v3-turbo"])
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_models(vllm_runner, model) -> None: def test_models(vllm_runner, model) -> None:
run_test( run_test(
......
...@@ -31,6 +31,7 @@ from ...utils import dummy_hf_overrides ...@@ -31,6 +31,7 @@ from ...utils import dummy_hf_overrides
ARCH_TO_SKIP = { ARCH_TO_SKIP = {
"MolmoForCausalLM": "incompatible requirements", "MolmoForCausalLM": "incompatible requirements",
"Florence2ForConditionalGeneration": "not supported in V1",
} }
ARCH_NEEDS_EXTRAS = [ ARCH_NEEDS_EXTRAS = [
"InternVLChatModel", "InternVLChatModel",
......
...@@ -68,6 +68,12 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, ...@@ -68,6 +68,12 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
# has cc==8.9 which hasn't supported FA3 yet. Remove this hack when # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
# L4 supports FA3. # L4 supports FA3.
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1") m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
if model_arch == "Florence2ForConditionalGeneration":
# An encoder-decoder model that's V0-only. Just skip it
# since V0 is about to be removed.
pytest.skip("Skipping Florence2ForConditionalGeneration")
if model_arch == "WhisperForConditionalGeneration":
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
LLM( LLM(
model_info.default, model_info.default,
tokenizer=model_info.tokenizer, tokenizer=model_info.tokenizer,
......
...@@ -10,7 +10,6 @@ from vllm.engine.arg_utils import AsyncEngineArgs ...@@ -10,7 +10,6 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
UNSUPPORTED_MODELS_V1 = [ UNSUPPORTED_MODELS_V1 = [
"openai/whisper-large-v3", # transcription
"facebook/bart-large-cnn", # encoder decoder "facebook/bart-large-cnn", # encoder decoder
] ]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy
from typing import Optional
import numpy as np
import torch
from transformers import CacheConfig
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
subclass_attention_backend)
from vllm.v1.kv_cache_interface import CrossAttentionSpec
logger = init_logger(__name__)
def _get_max_encoder_len(vllm_config: VllmConfig) -> int:
return MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(
vllm_config.model_config)
def _get_cross_slot_mapping(encoder_seq_lens: np.ndarray,
block_table_tensor: torch.Tensor,
kv_cache_spec: CrossAttentionSpec,
device: torch.device) -> torch.Tensor:
"""Get cross-attention slot mappings."""
block_size = kv_cache_spec.block_size
slot_mappings = []
# Find indices with non-zero encoder sequence lengths
# The majority of parallel requests will be running the
# decoder, so this list should be relatively small.
active_indices = np.nonzero(encoder_seq_lens)[0]
for req_index in active_indices:
encoder_seq_len = encoder_seq_lens[req_index].item()
# Calculate the number of blocks needed for this request
num_blocks_needed = cdiv(encoder_seq_len, block_size)
# Get the block IDs for this request from the tensor
req_block_ids = block_table_tensor[req_index]
# Get only the blocks we need (first num_blocks_needed blocks)
needed_block_ids = req_block_ids[:num_blocks_needed]
# All needed blocks are allocated
i_values = torch.arange(encoder_seq_len,
dtype=torch.int64,
device=device)
block_indices = i_values // block_size
block_offsets = i_values % block_size
block_numbers = needed_block_ids[block_indices]
slot_mapping = block_numbers * block_size + block_offsets
slot_mappings.append(slot_mapping)
if slot_mappings:
return torch.cat(slot_mappings)
else:
return torch.empty(0, dtype=torch.int64, device=device)
@functools.lru_cache
def create_cross_attention_backend(
underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
prefix = "CrossAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls()
class CrossAttentionBuilder(underlying_builder): # type: ignore
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> AttentionMetadata:
new_metadata = copy(common_attn_metadata)
new_metadata.causal = False
max_encoder_len = _get_max_encoder_len(self.vllm_config)
new_metadata.max_seq_len = max_encoder_len
new_metadata.seq_lens = torch.full(
(new_metadata.num_reqs, ),
max_encoder_len,
dtype=torch.int32,
device=self.device,
)
new_metadata.seq_lens_cpu = torch.full(
(new_metadata.num_reqs, ),
max_encoder_len,
dtype=torch.int32,
device="cpu",
)
new_metadata.slot_mapping = _get_cross_slot_mapping(
new_metadata.encoder_seq_lens, new_metadata.block_table_tensor,
self.kv_cache_spec, self.device)
return super().build(common_prefix_len, new_metadata, fast_build)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=CrossAttentionBuilder)
return attn_backend
class CrossAttention(Attention):
"""
Cross-attention for encoder-decoder models.
Handles attention between decoder queries and encoder keys/values.
"""
def __init__(self,
num_heads: int,
head_size: int,
scale: float,
cache_config: Optional[CacheConfig] = None,
attn_type: Optional[str] = None,
**kwargs):
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if envs.VLLM_USE_V1:
underlying_attn_backend = get_attn_backend(head_size, dtype,
kv_cache_dtype,
block_size)
attn_backend = create_cross_attention_backend(
underlying_attn_backend)
else:
# in v0 cross attention is handled inside the backends
attn_backend = None
if attn_type is not None:
assert attn_type == AttentionType.ENCODER_DECODER, (
"CrossAttention only supports AttentionType.ENCODER_DECODER")
super().__init__(num_heads=num_heads,
head_size=head_size,
scale=scale,
cache_config=cache_config,
attn_backend=attn_backend,
attn_type=AttentionType.ENCODER_DECODER,
**kwargs)
...@@ -8,6 +8,7 @@ import enum ...@@ -8,6 +8,7 @@ import enum
import hashlib import hashlib
import inspect import inspect
import json import json
import os
import textwrap import textwrap
import warnings import warnings
from collections.abc import Mapping from collections.abc import Mapping
...@@ -41,6 +42,7 @@ from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy ...@@ -41,6 +42,7 @@ from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
from vllm.config.utils import ConfigType, config from vllm.config.utils import ConfigType, config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config, ConfigFormat, get_config, get_hf_image_processor_config,
...@@ -3509,16 +3511,33 @@ class VllmConfig: ...@@ -3509,16 +3511,33 @@ class VllmConfig:
disable_chunked_prefill_reasons: list[str] = [] disable_chunked_prefill_reasons: list[str] = []
if self.model_config and self.model_config.pooler_config: if self.model_config:
if self.model_config.pooler_config:
pooling_type = self.model_config.pooler_config.pooling_type pooling_type = self.model_config.pooler_config.pooling_type
if pooling_type is None or pooling_type.lower() != "last": if pooling_type is None or pooling_type.lower() != "last":
disable_chunked_prefill_reasons.append( disable_chunked_prefill_reasons.append(
"Only \"last\" pooling supports chunked " "Only \"last\" pooling supports chunked "
"prefill and prefix caching; disabling both.") "prefill and prefix caching; disabling both.")
elif not getattr(self.model_config.hf_config, "is_causal", True): elif self.model_config.is_encoder_decoder:
self.scheduler_config.max_num_encoder_input_tokens = \
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
logger.debug(
"Encoder-decoder model detected: setting "
"`max_num_encoder_input_tokens` to encoder length (%s)",
self.scheduler_config.max_num_encoder_input_tokens)
self.scheduler_config.disable_chunked_mm_input = True
disable_chunked_prefill_reasons.append( disable_chunked_prefill_reasons.append(
"Only models using causal attention supports chunked " "Encoder-decoder models do not support chunked prefill nor"
"prefill and prefix caching; disabling both.") " prefix caching; disabling both.")
if (self.model_config.architecture
== "WhisperForConditionalGeneration"
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD")
!= "spawn"):
logger.warning(
"Whisper is known to have issues with "
"forked workers. If startup is hanging, "
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'.")
if disable_chunked_prefill_reasons: if disable_chunked_prefill_reasons:
for reason in disable_chunked_prefill_reasons: for reason in disable_chunked_prefill_reasons:
......
...@@ -600,7 +600,6 @@ class VoxtralEncoderModel(nn.Module): ...@@ -600,7 +600,6 @@ class VoxtralEncoderModel(nn.Module):
self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config, self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config,
prefix=maybe_prefix( prefix=maybe_prefix(
prefix, "whisper_encoder"), prefix, "whisper_encoder"),
is_standalone_encoder=True,
init_in_fp32=True) init_in_fp32=True)
mel_filters = mel_filter_bank( mel_filters = mel_filter_bank(
num_frequency_bins=1 + self.config.window_size // 2, num_frequency_bins=1 + self.config.window_size // 2,
......
...@@ -15,6 +15,7 @@ from transformers.models.whisper.modeling_whisper import sinusoids ...@@ -15,6 +15,7 @@ from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention import Attention, AttentionType from vllm.attention import Attention, AttentionType
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.cross_attention import CrossAttention
from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig, from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
VllmConfig) VllmConfig)
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
...@@ -43,7 +44,7 @@ from vllm.transformers_utils.processor import cached_get_processor ...@@ -43,7 +44,7 @@ from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
SupportsTranscription, SupportsV0Only) SupportsTranscription)
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
make_layers) make_layers)
...@@ -124,6 +125,34 @@ class WhisperAudioInputs(TensorSchema): ...@@ -124,6 +125,34 @@ class WhisperAudioInputs(TensorSchema):
TensorShape("b", "nmb", "t")] TensorShape("b", "nmb", "t")]
class WhisperEncoderAttention(MultiHeadAttention):
"""Multi-headed attention for Whisper encoder with 2D tensor support."""
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
"""
Input shape: batch_size x seq_len x hidden_size
or seq_len x hidden_size
"""
is_2d = query.dim() == 2
if is_2d:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
# Call the parent forward method
out = super().forward(query, key, value)
if is_2d:
out = out.squeeze(0)
return out
class WhisperPositionalEmbedding(nn.Embedding): class WhisperPositionalEmbedding(nn.Embedding):
def __init__(self, num_positions: int, embedding_dim: int): def __init__(self, num_positions: int, embedding_dim: int):
...@@ -144,7 +173,6 @@ class WhisperAttention(nn.Module): ...@@ -144,7 +173,6 @@ class WhisperAttention(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
standalone_encoder: bool = False,
): ):
super().__init__() super().__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
...@@ -180,14 +208,25 @@ class WhisperAttention(nn.Module): ...@@ -180,14 +208,25 @@ class WhisperAttention(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.out_proj", prefix=f"{prefix}.out_proj",
) )
if standalone_encoder: if attn_type == AttentionType.ENCODER:
self.attn = MultiHeadAttention( self.attn = WhisperEncoderAttention(
self.num_heads, self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
) )
else: elif self.attn_type == AttentionType.ENCODER_DECODER:
self.attn = CrossAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=self.attn_type,
)
else: # AttentionType.DECODER (regular decoder self-attention)
self.attn = Attention( self.attn = Attention(
self.num_heads, self.num_heads,
self.head_dim, self.head_dim,
...@@ -332,11 +371,7 @@ class WhisperMLP(nn.Module): ...@@ -332,11 +371,7 @@ class WhisperMLP(nn.Module):
class WhisperEncoderLayer(nn.Module): class WhisperEncoderLayer(nn.Module):
def __init__(self, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
*,
vllm_config: VllmConfig,
prefix: str = "",
is_standalone_encoder: bool = False):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
...@@ -350,7 +385,6 @@ class WhisperEncoderLayer(nn.Module): ...@@ -350,7 +385,6 @@ class WhisperEncoderLayer(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
standalone_encoder=is_standalone_encoder,
) )
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.mlp = WhisperMLP( self.mlp = WhisperMLP(
...@@ -446,12 +480,10 @@ class WhisperEncoder(nn.Module): ...@@ -446,12 +480,10 @@ class WhisperEncoder(nn.Module):
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
is_standalone_encoder: bool = False,
init_in_fp32: bool = False): init_in_fp32: bool = False):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
embed_dim = config.d_model embed_dim = config.d_model
self.is_standalone_encoder = is_standalone_encoder
self.num_mel_bins = config.num_mel_bins self.num_mel_bins = config.num_mel_bins
self.max_source_positions = config.max_source_positions self.max_source_positions = config.max_source_positions
self.embed_scale = (math.sqrt(embed_dim) self.embed_scale = (math.sqrt(embed_dim)
...@@ -469,9 +501,7 @@ class WhisperEncoder(nn.Module): ...@@ -469,9 +501,7 @@ class WhisperEncoder(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.encoder_layers, config.encoder_layers,
lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config, lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config,
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers"),
is_standalone_encoder=
is_standalone_encoder),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
self.layer_norm = nn.LayerNorm(config.d_model) self.layer_norm = nn.LayerNorm(config.d_model)
...@@ -752,7 +782,7 @@ class WhisperMultiModalProcessor( ...@@ -752,7 +782,7 @@ class WhisperMultiModalProcessor(
info=WhisperProcessingInfo, info=WhisperProcessingInfo,
dummy_inputs=WhisperDummyInputsBuilder) dummy_inputs=WhisperDummyInputsBuilder)
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
SupportsMultiModal, SupportsV0Only): SupportsMultiModal):
packed_modules_mapping = { packed_modules_mapping = {
"self_attn.qkv_proj": [ "self_attn.qkv_proj": [
"self_attn.q_proj", "self_attn.q_proj",
...@@ -880,19 +910,17 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, ...@@ -880,19 +910,17 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
def get_multimodal_embeddings(self, def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings: **kwargs: object) -> MultiModalEmbeddings:
# TODO: This method does not obey the interface for SupportsMultiModal. # Required as part of SupportsMultiModal interface.
# Refactor this once encoder/decoder support is implemented in V1.
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
return self.model.get_encoder_outputs(audio_input["input_features"]) return [self.model.get_encoder_outputs(audio_input["input_features"])]
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# TODO: This method just returns the decoder sequence embeddings since # This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens. Refactor this once # Whisper does not have encoder text tokens.
# encoder/decoder support is implemented in V1.
return self.model.decoder.get_input_embeddings(input_ids) return self.model.decoder.get_input_embeddings(input_ids)
def _parse_and_validate_audio_input( def _parse_and_validate_audio_input(
......
...@@ -157,6 +157,7 @@ def _remap_mistral_audio_args(config: dict) -> dict: ...@@ -157,6 +157,7 @@ def _remap_mistral_audio_args(config: dict) -> dict:
encoder_attention_heads=encoder_args["n_heads"], encoder_attention_heads=encoder_args["n_heads"],
vocab_size=encoder_args["vocab_size"], vocab_size=encoder_args["vocab_size"],
max_source_positions=encoder_args["max_source_positions"], max_source_positions=encoder_args["max_source_positions"],
is_encoder_decoder=False, # Override WhisperConfig default
) )
} }
if quant_config: if quant_config:
......
...@@ -317,8 +317,8 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): ...@@ -317,8 +317,8 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device) -> None: vllm_config: VllmConfig, device: torch.device) -> None:
self.kv_cache_spec = kv_cache_spec super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config self.scheduler_config = vllm_config.scheduler_config
# For reorder # For reorder
......
...@@ -177,12 +177,11 @@ class FlashAttentionMetadataBuilder( ...@@ -177,12 +177,11 @@ class FlashAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.device = device
self.num_heads_q = self.model_config.get_num_attention_heads( self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config) self.parallel_config)
......
...@@ -163,11 +163,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -163,11 +163,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
self.device = device super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.vllm_config = vllm_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.kv_cache_spec = kv_cache_spec
self._workspace_buffer = None self._workspace_buffer = None
self._prefill_wrapper = None # Wrapper for prefill/append self._prefill_wrapper = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode (general shape) self._decode_wrapper = None # Wrapper for decode (general shape)
......
...@@ -516,10 +516,11 @@ class FlexAttentionMetadataBuilder( ...@@ -516,10 +516,11 @@ class FlexAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
self.device = device
self.num_heads_q = self.model_config.get_num_attention_heads( self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config) self.parallel_config)
......
...@@ -39,8 +39,8 @@ class LinearAttentionMetadataBuilder( ...@@ -39,8 +39,8 @@ class LinearAttentionMetadataBuilder(
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert isinstance(kv_cache_spec, MambaSpec) assert isinstance(kv_cache_spec, MambaSpec)
self.kv_cache_spec = kv_cache_spec
def build(self, def build(self,
common_prefix_len: int, common_prefix_len: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment