Unverified Commit ba5106e5 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[LMM] Implement merged multimodal processor for whisper (#13278)

parent d5ca2110
...@@ -83,11 +83,11 @@ def _test_processing_correctness( ...@@ -83,11 +83,11 @@ def _test_processing_correctness(
} }
tokenizer_encode_kwargs = {} tokenizer_encode_kwargs = {}
if model_config.hf_config.model_type == "mllama": if model_config.hf_config.model_type in ("mllama", "whisper"):
# For Mllama, tokenizer will always add bos_token at the beginning of # For some encoder-decoder models, tokenizer will always add bos_token
# prompt by default, causing hf_processor outputs incorrect token ids. # at the beginning of prompt by default, causing hf_processor outputs
# So we need use `add_special_tokens=False` here to leave bos_token # incorrect token ids. So we need use `add_special_tokens=False` here
# to be added by the processor. # to leave bos_token to be added by the processor.
tokenizer_encode_kwargs = {"add_special_tokens": False} tokenizer_encode_kwargs = {"add_special_tokens": False}
for batch_idx in range(num_batches): for batch_idx in range(num_batches):
...@@ -173,6 +173,7 @@ def _test_processing_correctness( ...@@ -173,6 +173,7 @@ def _test_processing_correctness(
"Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct",
"fixie-ai/ultravox-v0_5-llama-3_2-1b", "fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",
]) ])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32]) @pytest.mark.parametrize("num_batches", [32])
......
...@@ -4,15 +4,15 @@ import math ...@@ -4,15 +4,15 @@ import math
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union) Union)
import numpy as np
import torch import torch
from torch import nn from torch import nn
from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
WhisperProcessor)
from transformers.models.whisper.modeling_whisper import sinusoids from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, DummyData, InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -25,11 +25,14 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput ...@@ -25,11 +25,14 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
NestedTensors) from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.audio import resample_audio from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems,
from vllm.sequence import SequenceData MultiModalDataParser)
from vllm.transformers_utils.processor import cached_processor_from_config from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from .interfaces import SupportsMultiModal, SupportsTranscription from .interfaces import SupportsMultiModal, SupportsTranscription
from .utils import AutoWeightsLoader, WeightsMapper, make_layers from .utils import AutoWeightsLoader, WeightsMapper, make_layers
...@@ -571,72 +574,126 @@ class WhisperModel(nn.Module): ...@@ -571,72 +574,126 @@ class WhisperModel(nn.Module):
return loaded_params return loaded_params
def get_max_whisper_audio_tokens(ctx: InputContext) -> int: class WhisperProcessingInfo(BaseProcessingInfo):
return ctx.model_config.hf_config.max_source_positions
def get_hf_config(self) -> WhisperConfig:
return self.ctx.get_hf_config(WhisperConfig)
def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]): def get_hf_processor(self,
assert mm_counts["audio"] == 1 sampling_rate: Optional[int] = None
num_tokens = get_max_whisper_audio_tokens(ctx) ) -> WhisperProcessor:
processor = cached_processor_from_config(ctx.model_config) return self.ctx.get_hf_processor(WhisperProcessor)
chunk_length = processor.feature_extractor.chunk_length
sampling_rate = processor.feature_extractor.sampling_rate def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
num_samples = chunk_length * sampling_rate return {"audio": 1}
return DummyData(
SequenceData.from_prompt_token_counts((0, num_tokens)), def get_feature_extractor(self) -> WhisperFeatureExtractor:
{"audio": [(np.zeros(num_samples), sampling_rate)]}, hf_processor = self.get_hf_processor()
) feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
def input_processor_for_whisper(ctx: InputContext, inputs):
multi_modal_data = inputs["encoder"]["multi_modal_data"] def get_max_audio_tokens(self) -> int:
if isinstance(multi_modal_data["audio"], list): return self.get_hf_config().max_source_positions
assert len(multi_modal_data["audio"]) == 1
multi_modal_data["audio"] = multi_modal_data["audio"][0] def get_mm_max_tokens_per_item(
# Resample and process audio self,
audio, orig_sr = multi_modal_data["audio"] seq_len: int,
processor = cached_processor_from_config(ctx.model_config) mm_counts: Mapping[str, int],
target_sr = processor.feature_extractor.sampling_rate ) -> Mapping[str, int]:
audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr) return {"audio": self.get_max_audio_tokens()}
multi_modal_data["audio"] = (audio, target_sr)
# Pre-allocate placeholder tokens in encoder sequence
num_tokens = get_max_whisper_audio_tokens(ctx) class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
inputs["encoder"]["prompt_token_ids"] = [0] * num_tokens
return inputs def get_dummy_processor_inputs(
self,
seq_len: int,
def input_mapper_for_whisper( mm_counts: Mapping[str, int],
ctx: InputContext, ) -> ProcessorInputs:
multi_modal_data: Union[np.ndarray, List[np.ndarray]], feature_extractor = self.info.get_feature_extractor()
) -> MultiModalKwargs:
if not isinstance(multi_modal_data, list): sampling_rate = feature_extractor.sampling_rate
multi_modal_data = [multi_modal_data] audio_len = feature_extractor.chunk_length * sampling_rate
num_audios = mm_counts.get("audio", 0)
assert len(multi_modal_data) == 1
mm_data = {
if len(multi_modal_data) == 0: "audio":
return MultiModalKwargs() self._get_dummy_audios(length=audio_len, num_audios=num_audios)
}
processor = cached_processor_from_config(ctx.model_config)
sampling_rate = processor.feature_extractor.sampling_rate return ProcessorInputs(
prompt_text="<|startoftranscript|>" * num_audios,
audios = [audio for audio, _ in multi_modal_data] mm_data=mm_data,
)
kwargs = processor(audios,
sampling_rate=sampling_rate,
return_tensors="pt") class WhisperMultiModalProcessor(
kwargs["input_features"] = kwargs["input_features"].squeeze(0).to( EncDecMultiModalProcessor[WhisperProcessingInfo]):
ctx.model_config.dtype)
def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalKwargs(kwargs) feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper) def create_encoder_prompt(
@INPUT_REGISTRY.register_input_processor(input_processor_for_whisper) self,
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper) prompt: Union[str, list[int]],
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( mm_data: MultiModalDataDict,
"audio", get_max_whisper_audio_tokens) ) -> Union[str, list[int]]:
# Strictly speaking, whisper encoder only accept audio features.
# We create a dummy encoder prompt here which will be padded to
# num_audio_tokens. So that we can create dummy data from this
# for encoder profiling.
return [0]
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
mm_data = dict(audio=mm_data.pop("audios"))
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
)
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
if "labels" in processed_outputs:
processed_outputs["input_ids"] = processed_outputs.pop("labels")
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(input_features=MultiModalFieldConfig.batched("audio"))
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
num_tokens = self.info.get_max_audio_tokens()
return [
PromptReplacement(
modality="audio",
target=[0],
replacement=[0] * num_tokens,
)
]
@MULTIMODAL_REGISTRY.register_processor(WhisperMultiModalProcessor,
info=WhisperProcessingInfo,
dummy_inputs=WhisperDummyInputsBuilder)
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
SupportsMultiModal): SupportsMultiModal):
packed_modules_mapping = { packed_modules_mapping = {
...@@ -724,7 +781,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, ...@@ -724,7 +781,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
if not isinstance(input_features, (torch.Tensor, list)): if not isinstance(input_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio features. " raise ValueError("Incorrect type of audio features. "
f"Got type: {type(input_features)}") f"Got type: {type(input_features)}")
input_features = [feat.to(self.dtype) for feat in input_features] input_features = torch.cat(
[feat.to(self.dtype) for feat in input_features])
return WhisperAudioInputs(input_features=input_features) return WhisperAudioInputs(input_features=input_features)
......
...@@ -1297,7 +1297,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1297,7 +1297,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
) -> Union[str, list[int]]: ) -> Union[str, list[int]]:
"""Create input prompt for the encoder.""" """
Create input prompt for the encoder. HF processor will be applied on
this prompt during profiling and generation.
"""
raise NotImplementedError raise NotImplementedError
def apply( def apply(
......
...@@ -166,8 +166,12 @@ class MultiModalProfiler(Generic[_I]): ...@@ -166,8 +166,12 @@ class MultiModalProfiler(Generic[_I]):
f"({set(mm_max_tokens_per_item.keys())})") f"({set(mm_max_tokens_per_item.keys())})")
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
prompt_token_ids = mm_inputs["prompt_token_ids"]
placeholders_by_modality = mm_inputs["mm_placeholders"] placeholders_by_modality = mm_inputs["mm_placeholders"]
# For encoder-decoder models, use encoder prompt token ids instead of
# decoder prompt to construct dummy seq_data for encoder profiling.
prompt_token_ids = (
mm_inputs["prompt_token_ids"] if not is_encoder_data else
mm_inputs["encoder_prompt_token_ids"]) # type: ignore
total_placeholders_by_modality = { total_placeholders_by_modality = {
modality: sum(item["length"] for item in placeholders) modality: sum(item["length"] for item in placeholders)
...@@ -188,7 +192,7 @@ class MultiModalProfiler(Generic[_I]): ...@@ -188,7 +192,7 @@ class MultiModalProfiler(Generic[_I]):
# V0 does not support chunked prefill. # V0 does not support chunked prefill.
if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data: if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
if total_len > seq_len: if total_len > seq_len and not is_encoder_data:
logger.warning( logger.warning(
"The context length (%d) of the model is too short " "The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case " "to hold the multi-modal embeddings in the worst case "
...@@ -201,7 +205,8 @@ class MultiModalProfiler(Generic[_I]): ...@@ -201,7 +205,8 @@ class MultiModalProfiler(Generic[_I]):
total_placeholders_by_modality) total_placeholders_by_modality)
return DummyData( return DummyData(
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)), seq_data=SequenceData.from_prompt_token_counts(
(0, max(seq_len, total_len))),
multi_modal_data=None, multi_modal_data=None,
multi_modal_placeholders=None, multi_modal_placeholders=None,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment