Commit d927dbcd (unverified), authored by Isotr0py, committed by GitHub
Browse files

[Model] Refactor Ultravox to use merged input processor (#11198)


Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent bddbbcb1
...@@ -25,16 +25,16 @@ def run_ultravox(question: str, audio_count: int): ...@@ -25,16 +25,16 @@ def run_ultravox(question: str, audio_count: int):
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{ messages = [{
'role': 'role': 'user',
'user', 'content': "<|audio|>\n" * audio_count + question
'content':
"<|reserved_special_token_0|>\n" * audio_count + question
}] }]
prompt = tokenizer.apply_chat_template(messages, prompt = tokenizer.apply_chat_template(messages,
tokenize=False, tokenize=False,
add_generation_prompt=True) add_generation_prompt=True)
llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count}) llm = LLM(model=model_name,
trust_remote_code=True,
limit_mm_per_prompt={"audio": audio_count})
stop_token_ids = None stop_token_ids = None
return llm, prompt, stop_token_ids return llm, prompt, stop_token_ids
......
...@@ -214,7 +214,7 @@ MULTIMODAL_MODELS = { ...@@ -214,7 +214,7 @@ MULTIMODAL_MODELS = {
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True), "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
"Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(), "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
"Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(), "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
"fixie-ai/ultravox-v0_3": PPTestSettings.fast(), "fixie-ai/ultravox-v0_3": PPTestSettings.fast(trust_remote_code=True),
# [Encoder-decoder] # [Encoder-decoder]
# TODO: Implement PP # TODO: Implement PP
# "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(), # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
......
...@@ -25,6 +25,7 @@ def server(): ...@@ -25,6 +25,7 @@ def server():
"--max-num-seqs", "--max-num-seqs",
"5", "5",
"--enforce-eager", "--enforce-eager",
"--trust-remote-code",
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
......
...@@ -16,7 +16,7 @@ MODEL_NAME = "fixie-ai/ultravox-v0_3" ...@@ -16,7 +16,7 @@ MODEL_NAME = "fixie-ai/ultravox-v0_3"
AudioTuple = Tuple[np.ndarray, int] AudioTuple = Tuple[np.ndarray, int]
VLLM_PLACEHOLDER = "<|reserved_special_token_0|>" VLLM_PLACEHOLDER = "<|audio|>"
HF_PLACEHOLDER = "<|audio|>" HF_PLACEHOLDER = "<|audio|>"
CHUNKED_PREFILL_KWARGS = { CHUNKED_PREFILL_KWARGS = {
...@@ -46,7 +46,8 @@ def audio(request): ...@@ -46,7 +46,8 @@ def audio(request):
def server(request, audio_assets): def server(request, audio_assets):
args = [ args = [
"--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager", "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
f"--limit-mm-per-prompt=audio={len(audio_assets)}" f"--limit-mm-per-prompt=audio={len(audio_assets)}",
"--trust-remote-code"
] + [ ] + [
f"--{key.replace('_','-')}={value}" f"--{key.replace('_','-')}={value}"
for key, value in request.param.items() for key, value in request.param.items()
......
...@@ -418,7 +418,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -418,7 +418,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
raise TypeError(f"Unknown {modality} model type: {model_type}") raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio": elif modality == "audio":
if model_type == "ultravox": if model_type == "ultravox":
return "<|reserved_special_token_0|>" return "<|audio|>"
if model_type == "qwen2_audio": if model_type == "qwen2_audio":
return (f"Audio {current_count}: " return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>") f"<|audio_bos|><|AUDIO|><|audio_eos|>")
......
...@@ -3,41 +3,39 @@ ...@@ -3,41 +3,39 @@
import math import math
from functools import cached_property, lru_cache from functools import cached_property, lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
TypedDict, Union, cast) Tuple, TypedDict, Union)
import numpy as np import numpy as np
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from transformers import BatchFeature
from transformers.models.whisper import WhisperFeatureExtractor from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder from transformers.models.whisper.modeling_whisper import WhisperEncoder
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import InputContext
InputContext, token_inputs)
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
NestedTensors) from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.utils import (cached_get_tokenizer, MultiModalDataDict,
consecutive_placeholder_ranges, MultiModalDataItems, ProcessorInputs,
repeat_and_pad_placeholder_tokens) PromptReplacement)
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings_from_map) merge_multimodal_embeddings_from_map)
_AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25 _AUDIO_TOKENS_PER_SECOND = 6.25
...@@ -72,64 +70,18 @@ def get_ultravox_max_audio_tokens(ctx: InputContext): ...@@ -72,64 +70,18 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)
def dummy_seq_data_for_ultravox( class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
ctx: InputContext,
seq_len: int,
audio_count: int,
):
audio_length = min(get_ultravox_max_audio_tokens(ctx),
seq_len // audio_count)
return SequenceData.from_prompt_token_counts( def _get_feature_extractor(self) -> WhisperFeatureExtractor:
(_AUDIO_PLACEHOLDER_TOKEN, audio_length * audio_count), return self._get_hf_processor().audio_processor.feature_extractor
(0, seq_len - audio_length * audio_count)), {
"audio":
consecutive_placeholder_ranges(num_items=audio_count,
item_size=audio_length)
}
def dummy_audio_for_ultravox(
ctx: InputContext,
audio_count: int,
):
feature_extractor = whisper_feature_extractor(ctx)
audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
return {"audio": [audio_and_sr] * audio_count}
def dummy_data_for_ultravox(
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
):
audio_count = mm_counts["audio"]
seq_data, ranges = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
mm_dict = dummy_audio_for_ultravox(ctx, audio_count)
return DummyData(seq_data, mm_dict, ranges)
def input_mapper_for_ultravox(ctx: InputContext, data: object):
if not isinstance(data, list):
data = [data]
if len(data) == 0:
return MultiModalKwargs()
# If the audio inputs are embeddings, no need for preprocessing
if is_list_of(data, torch.Tensor, check="all"):
return MultiModalKwargs({"audio_embeds": data})
audio_features = []
for audio_input in data:
if not isinstance(audio_input, tuple):
raise NotImplementedError(
f"Unsupported data type: {type(audio_input)}")
(audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
feature_extractor = whisper_feature_extractor(ctx)
def _resample_audio(
self,
audio: np.ndarray,
sr: int,
) -> Dict[str, Union[np.ndarray, int]]:
# resample audio to the model's sampling rate
feature_extractor = self._get_feature_extractor()
if sr != feature_extractor.sampling_rate: if sr != feature_extractor.sampling_rate:
try: try:
import librosa import librosa
...@@ -140,79 +92,93 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object): ...@@ -140,79 +92,93 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
orig_sr=sr, orig_sr=sr,
target_sr=feature_extractor.sampling_rate) target_sr=feature_extractor.sampling_rate)
sr = feature_extractor.sampling_rate sr = feature_extractor.sampling_rate
return {"audio": audio, "sampling_rate": sr}
minimum_audio_length = feature_extractor.n_fft // 2 + 1 def _apply_hf_processor(
if len(audio) < minimum_audio_length: self,
# Not enough audio; pad it. prompt: str,
audio = np.pad(audio, (0, minimum_audio_length - len(audio))) mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
single_audio_features = feature_extractor( ) -> BatchFeature:
audio, sampling_rate=sr, padding="longest", if not mm_data or not mm_data.get("audio", None):
return_tensors="pt")["input_features"] return super()._apply_hf_processor(prompt, mm_data,
mm_processor_kwargs)
# Remove the batch dimension because we're wrapping it in a list.
audio_features.append(single_audio_features.squeeze(0)) audio_data = mm_data["audio"]
if not isinstance(audio_data, list):
return MultiModalKwargs({"audio_features": audio_features}) audio_data = [audio_data]
# Ultravox processor doesn't support multiple inputs,
# therefore we need to input text and audio one by one
tokenizer = self._get_tokenizer()
audio_features, audio_token_len = [], []
processed_inputs = {}
for audio, sr in audio_data:
data = self._resample_audio(audio, sr)
processed_inputs = super()._apply_hf_processor(
prompt, data, mm_processor_kwargs)
prompt = tokenizer.decode(processed_inputs["input_ids"][0],
skip_special_tokens=False)
audio_features.append(
processed_inputs.pop("audio_values").squeeze(0))
audio_token_len.append(
processed_inputs.pop("audio_token_len").item())
return dict(
**processed_inputs,
audio_features=audio_features,
audio_token_len=audio_token_len,
)
def _get_processor_data(
    self,
    mm_data: MultiModalDataDict,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split ``mm_data`` via the base class, then rename the audio keyword.

    The parent implementation produces the processor data and the
    passthrough data. If the processor data carries an ``"audios"`` entry,
    it is moved to the ``"audio"`` key, because Ultravox uses "audio"
    instead of "audios" as its calling keyword. The passthrough data is
    returned unchanged.

    Returns:
        A ``(processor_data, passthrough_data)`` pair of dictionaries.
    """
    hf_kwargs, passthrough = super()._get_processor_data(mm_data)
    try:
        audios = hf_kwargs.pop("audios")
    except KeyError:
        # Nothing to rename: the parent produced no "audios" entry.
        pass
    else:
        hf_kwargs["audio"] = audios
    return hf_kwargs, passthrough
def _get_prompt_replacements(
    self,
    mm_items: MultiModalDataItems,
    hf_inputs: BatchFeature,
    mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
    """Describe how each ``<|audio|>`` marker in the prompt is expanded.

    A single audio-modality replacement is returned. Its replacement text
    is computed per item: the HF processor's ``audio_token_replacement``
    repeated ``hf_inputs["audio_token_len"][item_idx]`` times.

    ``mm_items`` and ``mm_processor_kwargs`` are accepted but not read by
    this implementation.
    """
    audio_token = self._get_hf_processor().audio_token_replacement

    def expand_audio_placeholder(item_idx: int):
        # Looked up at call time so every audio item uses its own length.
        repeat_count = hf_inputs["audio_token_len"][item_idx]
        return audio_token * repeat_count

    audio_replacement = PromptReplacement(
        modality="audio",
        target="<|audio|>",
        replacement=expand_audio_placeholder,
    )
    return [audio_replacement]
def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs): def _get_dummy_mm_inputs(
multi_modal_data = inputs.get("multi_modal_data") self,
if multi_modal_data is None or "audio" not in multi_modal_data: mm_counts: Mapping[str, int],
return inputs ) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
if "multi_modal_placeholders" in inputs and "audio" in inputs[ audio_count = mm_counts["audio"]
"multi_modal_placeholders"]: audio = np.zeros(audio_len)
# The inputs already have placeholders. data = {"audio": [(audio, sampling_rate)] * audio_count}
return inputs
feature_extractor = whisper_feature_extractor(ctx) return ProcessorInputs(
audios = multi_modal_data["audio"] prompt_text="<|audio|>" * audio_count,
if not isinstance(audios, list): mm_data=data,
audios = [audios] mm_processor_kwargs={},
audio_token_counts = []
for audio in audios:
if isinstance(audio, torch.Tensor):
audio_num_tokens = audio.shape[1]
audio_token_counts.append(audio_num_tokens)
else:
audio_data, sample_rate = audio
audio_length = audio_data.shape[0]
if sample_rate != feature_extractor.sampling_rate:
# Account for resampling.
adjustment = feature_extractor.sampling_rate / sample_rate
audio_length = math.ceil(adjustment * audio_length)
feature_extractor_output_length = math.ceil(
(audio_length - (feature_extractor.hop_length - 1)) /
feature_extractor.hop_length)
uv_config = ctx.get_hf_config(UltravoxConfig)
audio_num_tokens = min(
max(
1,
math.ceil(feature_extractor_output_length /
(uv_config.stack_factor * 2))),
get_ultravox_max_audio_tokens(ctx))
audio_token_counts.append(audio_num_tokens)
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
repeat_count=audio_token_counts,
) )
# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"audio": ranges})
class StackAudioFrames(nn.Module): class StackAudioFrames(nn.Module):
""" """
...@@ -332,11 +298,9 @@ class ModifiedWhisperEncoder(WhisperEncoder): ...@@ -332,11 +298,9 @@ class ModifiedWhisperEncoder(WhisperEncoder):
return hidden_states return hidden_states
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( @MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_ultravox_max_audio_tokens) "audio", get_ultravox_max_audio_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox) @MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
...@@ -594,14 +594,10 @@ class BaseMultiModalProcessor(ABC): ...@@ -594,14 +594,10 @@ class BaseMultiModalProcessor(ABC):
return list( return list(
iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts)) iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))
def _apply_hf_processor( def _get_processor_data(
self, self,
prompt: str,
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
hf_processor = self._get_hf_processor(**mm_processor_kwargs)
processor_data = dict[str, Any]() processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]() passthrough_data = dict[str, Any]()
for k, v in mm_data.items(): for k, v in mm_data.items():
...@@ -619,6 +615,19 @@ class BaseMultiModalProcessor(ABC): ...@@ -619,6 +615,19 @@ class BaseMultiModalProcessor(ABC):
processor_data[f"{k}s"] = v processor_data[f"{k}s"] = v
else: else:
processor_data[k] = v processor_data[k] = v
return processor_data, passthrough_data
def _apply_hf_processor(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
# some mm_processor_kwargs may be used in processor initialization
# instead of processor call
hf_processor = self._get_hf_processor(**mm_processor_kwargs)
processor_data, passthrough_data = self._get_processor_data(mm_data)
assert callable(hf_processor) assert callable(hf_processor)
mm_processor_kwargs = self.ctx.resolve_hf_processor_call_kwargs( mm_processor_kwargs = self.ctx.resolve_hf_processor_call_kwargs(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment