Unverified commit d927dbcd, authored by Isotr0py, committed by GitHub
Browse files

[Model] Refactor Ultravox to use merged input processor (#11198)


Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent bddbbcb1
......@@ -25,16 +25,16 @@ def run_ultravox(question: str, audio_count: int):
tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{
'role':
'user',
'content':
"<|reserved_special_token_0|>\n" * audio_count + question
'role': 'user',
'content': "<|audio|>\n" * audio_count + question
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count})
llm = LLM(model=model_name,
trust_remote_code=True,
limit_mm_per_prompt={"audio": audio_count})
stop_token_ids = None
return llm, prompt, stop_token_ids
......
......@@ -214,7 +214,7 @@ MULTIMODAL_MODELS = {
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
"Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
"Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
"fixie-ai/ultravox-v0_3": PPTestSettings.fast(),
"fixie-ai/ultravox-v0_3": PPTestSettings.fast(trust_remote_code=True),
# [Encoder-decoder]
# TODO: Implement PP
# "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
......
......@@ -25,6 +25,7 @@ def server():
"--max-num-seqs",
"5",
"--enforce-eager",
"--trust-remote-code",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
......
......@@ -16,7 +16,7 @@ MODEL_NAME = "fixie-ai/ultravox-v0_3"
AudioTuple = Tuple[np.ndarray, int]
VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
VLLM_PLACEHOLDER = "<|audio|>"
HF_PLACEHOLDER = "<|audio|>"
CHUNKED_PREFILL_KWARGS = {
......@@ -46,7 +46,8 @@ def audio(request):
def server(request, audio_assets):
args = [
"--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
f"--limit-mm-per-prompt=audio={len(audio_assets)}"
f"--limit-mm-per-prompt=audio={len(audio_assets)}",
"--trust-remote-code"
] + [
f"--{key.replace('_','-')}={value}"
for key, value in request.param.items()
......
......@@ -418,7 +418,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio":
if model_type == "ultravox":
return "<|reserved_special_token_0|>"
return "<|audio|>"
if model_type == "qwen2_audio":
return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>")
......
......@@ -3,41 +3,39 @@
import math
from functools import cached_property, lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union, cast)
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from transformers import BatchFeature
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataDict,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings_from_map)
_AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25
......@@ -72,64 +70,18 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)
def dummy_seq_data_for_ultravox(
ctx: InputContext,
seq_len: int,
audio_count: int,
):
audio_length = min(get_ultravox_max_audio_tokens(ctx),
seq_len // audio_count)
class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
return SequenceData.from_prompt_token_counts(
(_AUDIO_PLACEHOLDER_TOKEN, audio_length * audio_count),
(0, seq_len - audio_length * audio_count)), {
"audio":
consecutive_placeholder_ranges(num_items=audio_count,
item_size=audio_length)
}
def dummy_audio_for_ultravox(
ctx: InputContext,
audio_count: int,
):
feature_extractor = whisper_feature_extractor(ctx)
audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
return {"audio": [audio_and_sr] * audio_count}
def dummy_data_for_ultravox(
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
):
audio_count = mm_counts["audio"]
seq_data, ranges = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
mm_dict = dummy_audio_for_ultravox(ctx, audio_count)
return DummyData(seq_data, mm_dict, ranges)
def input_mapper_for_ultravox(ctx: InputContext, data: object):
if not isinstance(data, list):
data = [data]
if len(data) == 0:
return MultiModalKwargs()
# If the audio inputs are embeddings, no need for preprocessing
if is_list_of(data, torch.Tensor, check="all"):
return MultiModalKwargs({"audio_embeds": data})
audio_features = []
for audio_input in data:
if not isinstance(audio_input, tuple):
raise NotImplementedError(
f"Unsupported data type: {type(audio_input)}")
(audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
feature_extractor = whisper_feature_extractor(ctx)
def _get_feature_extractor(self) -> WhisperFeatureExtractor:
    """Return the feature extractor nested inside the HF processor.

    The extractor is read from ``audio_processor.feature_extractor`` on the
    object returned by ``self._get_hf_processor()``.
    """
    hf_processor = self._get_hf_processor()
    audio_processor = hf_processor.audio_processor
    return audio_processor.feature_extractor
def _resample_audio(
self,
audio: np.ndarray,
sr: int,
) -> Dict[str, Union[np.ndarray, int]]:
# resample audio to the model's sampling rate
feature_extractor = self._get_feature_extractor()
if sr != feature_extractor.sampling_rate:
try:
import librosa
......@@ -140,79 +92,93 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
orig_sr=sr,
target_sr=feature_extractor.sampling_rate)
sr = feature_extractor.sampling_rate
return {"audio": audio, "sampling_rate": sr}
minimum_audio_length = feature_extractor.n_fft // 2 + 1
if len(audio) < minimum_audio_length:
# Not enough audio; pad it.
audio = np.pad(audio, (0, minimum_audio_length - len(audio)))
single_audio_features = feature_extractor(
audio, sampling_rate=sr, padding="longest",
return_tensors="pt")["input_features"]
# Remove the batch dimension because we're wrapping it in a list.
audio_features.append(single_audio_features.squeeze(0))
return MultiModalKwargs({"audio_features": audio_features})
def _apply_hf_processor(
    self,
    prompt: str,
    mm_data: MultiModalDataDict,
    mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Run the HF processor on ``prompt``, feeding audio items one at a time.

    When ``mm_data`` is empty or has no audio, the call is delegated
    unchanged to the parent implementation.  Otherwise each ``(audio, sr)``
    item is resampled and processed individually, and the per-item outputs
    are collected.

    Args:
        prompt: Prompt text passed to the HF processor.
        mm_data: Multi-modal inputs; only the ``"audio"`` entry is consumed
            here.  It may be a single ``(audio, sr)`` item or a list of them.
        mm_processor_kwargs: Extra keyword arguments forwarded to the parent
            implementation on every call.

    Returns:
        A ``BatchFeature`` holding whatever remains of the last processor
        output after ``audio_values`` and ``audio_token_len`` are popped,
        plus ``audio_features`` (list with one entry per audio item, batch
        dimension squeezed out) and ``audio_token_len`` (list with one
        Python scalar per audio item).
    """
    if not mm_data or not mm_data.get("audio"):
        return super()._apply_hf_processor(prompt, mm_data,
                                           mm_processor_kwargs)

    audio_data = mm_data["audio"]
    if not isinstance(audio_data, list):
        audio_data = [audio_data]

    # Ultravox processor doesn't support multiple inputs,
    # therefore we need to input text and audio one by one
    tokenizer = self._get_tokenizer()
    audio_features, audio_token_len = [], []
    processed_inputs = {}
    for audio, sr in audio_data:
        data = self._resample_audio(audio, sr)
        processed_inputs = super()._apply_hf_processor(
            prompt, data, mm_processor_kwargs)
        # Decode the processed token ids back to text and use that as the
        # prompt for the next audio item.
        # NOTE(review): presumably this keeps the expansion already done for
        # earlier items so only the next placeholder is handled -- confirm.
        prompt = tokenizer.decode(processed_inputs["input_ids"][0],
                                  skip_special_tokens=False)
        audio_features.append(
            processed_inputs.pop("audio_values").squeeze(0))
        audio_token_len.append(
            processed_inputs.pop("audio_token_len").item())

    # Wrap the merged mapping so the declared BatchFeature return type is
    # honoured on the audio path as well (previously a plain dict).
    return BatchFeature(
        dict(
            **processed_inputs,
            audio_features=audio_features,
            audio_token_len=audio_token_len,
        ))
def _get_processor_data(
    self,
    mm_data: MultiModalDataDict,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split ``mm_data`` via the parent class, renaming ``audios`` to ``audio``.

    Returns the same ``(processor_data, passthrough_data)`` pair produced by
    the parent implementation, except that an ``"audios"`` entry in the
    processor data, when present, is moved to the ``"audio"`` key.
    """
    hf_kwargs, passthrough = super()._get_processor_data(mm_data)

    # Ultravox uses "audio" instead of "audios" as calling keyword
    try:
        audio_items = hf_kwargs.pop("audios")
    except KeyError:
        pass
    else:
        hf_kwargs["audio"] = audio_items

    return hf_kwargs, passthrough
def _get_prompt_replacements(
    self,
    mm_items: MultiModalDataItems,
    hf_inputs: BatchFeature,
    mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
    """Build the prompt replacement rule for ``<|audio|>`` placeholders.

    Each ``<|audio|>`` target is replaced by the HF processor's
    ``audio_token_replacement`` repeated
    ``hf_inputs["audio_token_len"][item_idx]`` times, where ``item_idx`` is
    the index handed to the replacement callback.

    ``mm_items`` and ``mm_processor_kwargs`` are accepted for interface
    compatibility and are not read here.
    """
    audio_token = self._get_hf_processor().audio_token_replacement

    def expand_placeholder(item_idx: int):
        # Look up this item's token count at call time, then repeat.
        repeat_count = hf_inputs["audio_token_len"][item_idx]
        return audio_token * repeat_count

    audio_replacement = PromptReplacement(
        modality="audio",
        target="<|audio|>",
        replacement=expand_placeholder,
    )
    return [audio_replacement]
def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "audio" not in multi_modal_data:
return inputs
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
if "multi_modal_placeholders" in inputs and "audio" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
audio_count = mm_counts["audio"]
audio = np.zeros(audio_len)
data = {"audio": [(audio, sampling_rate)] * audio_count}
feature_extractor = whisper_feature_extractor(ctx)
audios = multi_modal_data["audio"]
if not isinstance(audios, list):
audios = [audios]
audio_token_counts = []
for audio in audios:
if isinstance(audio, torch.Tensor):
audio_num_tokens = audio.shape[1]
audio_token_counts.append(audio_num_tokens)
else:
audio_data, sample_rate = audio
audio_length = audio_data.shape[0]
if sample_rate != feature_extractor.sampling_rate:
# Account for resampling.
adjustment = feature_extractor.sampling_rate / sample_rate
audio_length = math.ceil(adjustment * audio_length)
feature_extractor_output_length = math.ceil(
(audio_length - (feature_extractor.hop_length - 1)) /
feature_extractor.hop_length)
uv_config = ctx.get_hf_config(UltravoxConfig)
audio_num_tokens = min(
max(
1,
math.ceil(feature_extractor_output_length /
(uv_config.stack_factor * 2))),
get_ultravox_max_audio_tokens(ctx))
audio_token_counts.append(audio_num_tokens)
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
repeat_count=audio_token_counts,
return ProcessorInputs(
prompt_text="<|audio|>" * audio_count,
mm_data=data,
mm_processor_kwargs={},
)
# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"audio": ranges})
class StackAudioFrames(nn.Module):
"""
......@@ -332,11 +298,9 @@ class ModifiedWhisperEncoder(WhisperEncoder):
return hidden_states
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_ultravox_max_audio_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox)
@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
......@@ -594,14 +594,10 @@ class BaseMultiModalProcessor(ABC):
return list(
iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))
def _apply_hf_processor(
def _get_processor_data(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
hf_processor = self._get_hf_processor(**mm_processor_kwargs)
processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]()
for k, v in mm_data.items():
......@@ -619,6 +615,19 @@ class BaseMultiModalProcessor(ABC):
processor_data[f"{k}s"] = v
else:
processor_data[k] = v
return processor_data, passthrough_data
def _apply_hf_processor(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
# some mm_processor_kwargs may be used in processor initialization
# instead of processor call
hf_processor = self._get_hf_processor(**mm_processor_kwargs)
processor_data, passthrough_data = self._get_processor_data(mm_data)
assert callable(hf_processor)
mm_processor_kwargs = self.ctx.resolve_hf_processor_call_kwargs(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment