Unverified Commit 83f3c3bd authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Model] Refactor Phi-4-multimodal to use merged processor and support V1 (#15477)


Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent d9737ca1
...@@ -1004,7 +1004,7 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -1004,7 +1004,7 @@ See [this page](#generative-models) for more information on how to use generativ
* `microsoft/Phi-4-multimodal-instruct`, etc. * `microsoft/Phi-4-multimodal-instruct`, etc.
* ✅︎ * ✅︎
* *
* * ✅︎
- * `PixtralForConditionalGeneration` - * `PixtralForConditionalGeneration`
* Pixtral * Pixtral
* T + I<sup>+</sup> * T + I<sup>+</sup>
......
...@@ -89,7 +89,7 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData: ...@@ -89,7 +89,7 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_path, model=model_path,
trust_remote_code=True, trust_remote_code=True,
max_model_len=4096, max_model_len=12800,
max_num_seqs=2, max_num_seqs=2,
enable_lora=True, enable_lora=True,
max_lora_rank=320, max_lora_rank=320,
......
...@@ -814,10 +814,13 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData: ...@@ -814,10 +814,13 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_path, model=model_path,
trust_remote_code=True, trust_remote_code=True,
max_model_len=4096, max_model_len=5120,
max_num_seqs=2, max_num_seqs=2,
max_num_batched_tokens=12800,
enable_lora=True, enable_lora=True,
max_lora_rank=320, max_lora_rank=320,
# Note - mm_processor_kwargs can also be passed to generate/chat calls
mm_processor_kwargs={"dynamic_hd": 16},
limit_mm_per_prompt={"image": 1}, limit_mm_per_prompt={"image": 1},
) )
......
...@@ -503,11 +503,13 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -503,11 +503,13 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_path, model=model_path,
trust_remote_code=True, trust_remote_code=True,
max_model_len=10000, max_model_len=4096,
max_num_seqs=2, max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
enable_lora=True, enable_lora=True,
max_lora_rank=320, max_lora_rank=320,
# Note - mm_processor_kwargs can also be passed to generate/chat calls
mm_processor_kwargs={"dynamic_hd": 4},
) )
placeholders = "".join(f"<|image_{i}|>" placeholders = "".join(f"<|image_{i}|>"
......
...@@ -18,6 +18,7 @@ transformers ...@@ -18,6 +18,7 @@ transformers
mistral_common >= 1.5.4 mistral_common >= 1.5.4
aiohttp aiohttp
starlette starlette
scipy
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
fastapi # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args fastapi # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json import json
from typing import Optional from typing import Any, Optional
import numpy as np import numpy as np
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from transformers import AutoModel, AutoTokenizer from transformers import AutoModel, AutoTokenizer
from vllm.multimodal.audio import resample_audio from vllm.multimodal.audio import resample_audio_librosa
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from ....conftest import HfRunner, VllmRunner from ....conftest import HfRunner, VllmRunner
...@@ -43,6 +43,18 @@ def audio(request): ...@@ -43,6 +43,18 @@ def audio(request):
return AudioAsset(request.param) return AudioAsset(request.param)
def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]:
    """Translate a keyword-argument mapping into command-line flags.

    Underscores in every key are replaced by dashes. A boolean ``True``
    becomes a bare ``--flag``, a boolean ``False`` produces no argument at
    all, and any other value is rendered as ``--flag=value``.
    """
    cli_args: list[str] = []
    for name, setting in params_kwargs.items():
        flag = f"--{name.replace('_', '-')}"
        if not isinstance(setting, bool):
            cli_args.append(f"{flag}={setting}")
        elif setting:
            # Switch-style option: presence alone enables it.
            cli_args.append(flag)
    return cli_args
@pytest.fixture(params=[ @pytest.fixture(params=[
pytest.param({}, marks=pytest.mark.cpu_model), pytest.param({}, marks=pytest.mark.cpu_model),
pytest.param(CHUNKED_PREFILL_KWARGS), pytest.param(CHUNKED_PREFILL_KWARGS),
...@@ -52,10 +64,7 @@ def server(request, audio_assets): ...@@ -52,10 +64,7 @@ def server(request, audio_assets):
"--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager", "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
"--limit-mm-per-prompt", "--limit-mm-per-prompt",
json.dumps({"audio": len(audio_assets)}), "--trust-remote-code" json.dumps({"audio": len(audio_assets)}), "--trust-remote-code"
] + [ ] + params_kwargs_to_cli_args(request.param)
f"--{key.replace('_','-')}={value}"
for key, value in request.param.items()
]
with RemoteOpenAIServer(MODEL_NAME, with RemoteOpenAIServer(MODEL_NAME,
args, args,
...@@ -136,9 +145,9 @@ def run_test( ...@@ -136,9 +145,9 @@ def run_test(
[hf_prompt], [hf_prompt],
max_tokens, max_tokens,
num_logprobs=num_logprobs, num_logprobs=num_logprobs,
audios=[(resample_audio(audio[0], audios=[(resample_audio_librosa(audio[0],
orig_sr=audio[1], orig_sr=audio[1],
target_sr=16000), 16000)]) target_sr=16000), 16000)])
for _, hf_prompt, audio in prompts_and_audios for _, hf_prompt, audio in prompts_and_audios
] ]
......
...@@ -181,7 +181,7 @@ def run_test( ...@@ -181,7 +181,7 @@ def run_test(
], ],
) )
@pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_model_len", [4096]) @pytest.mark.parametrize("max_model_len", [12800])
@pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10]) @pytest.mark.parametrize("num_logprobs", [10])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
...@@ -225,7 +225,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, ...@@ -225,7 +225,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
], ],
) )
@pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_model_len", [10000]) @pytest.mark.parametrize("max_model_len", [25600])
@pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10]) @pytest.mark.parametrize("num_logprobs", [10])
def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
...@@ -258,7 +258,7 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, ...@@ -258,7 +258,7 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_model_len", [10000]) @pytest.mark.parametrize("max_model_len", [12800])
@pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10]) @pytest.mark.parametrize("num_logprobs", [10])
def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str, def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str,
......
...@@ -274,6 +274,7 @@ def _test_processing_correctness_mistral( ...@@ -274,6 +274,7 @@ def _test_processing_correctness_mistral(
"nvidia/NVLM-D-72B", "nvidia/NVLM-D-72B",
"google/paligemma-3b-mix-224", "google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448", "google/paligemma2-3b-ft-docci-448",
"microsoft/Phi-4-multimodal-instruct",
"mistralai/Pixtral-12B-2409", "mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b", "mistral-community/pixtral-12b",
"Qwen/Qwen-VL-Chat", "Qwen/Qwen-VL-Chat",
......
# SPDX-License-Identifier: Apache-2.0
"""Tests for phi4mm's multimodal preprocessing kwargs."""
import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import _ImageAssets
from ...utils import build_model_context
@pytest.mark.parametrize("model_id", ["microsoft/Phi-4-multimodal-instruct"])
# yapf: disable
@pytest.mark.parametrize(
    ("mm_processor_kwargs", "expected_toks_per_img"),
    [
        ({"dynamic_hd": 4}, 1329),
        ({"dynamic_hd": 16}, 4433),
        # the default num_crops of phi-4-multimodal is 36
        ({}, 9585),
    ])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
    image_assets: _ImageAssets,
    model_id: str,
    mm_processor_kwargs: dict[str, int],
    expected_toks_per_img: int,
    num_imgs: int,
    kwargs_on_init: bool,
):
    """Check that the phi4mm processor honours the ``dynamic_hd`` override.

    The override is supplied either when the model context is built or at
    ``apply`` time, depending on ``kwargs_on_init``; in both cases the number
    of image placeholder tokens in the processed prompt must match the
    expected per-image count times the number of images.
    """
    # Imported lazily to avoid initializing CUDA early
    from vllm.model_executor.models.phi4mm import _IMAGE_PLACEHOLDER_TOKEN_ID

    # Route the kwargs either to the model context or to the apply() call.
    if kwargs_on_init:
        init_kwargs, call_kwargs = mm_processor_kwargs, {}
    else:
        init_kwargs, call_kwargs = None, mm_processor_kwargs

    ctx = build_model_context(
        model_id,
        mm_processor_kwargs=init_kwargs,
        limit_mm_per_prompt={"image": num_imgs},
    )
    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)

    # One numbered image tag (1-based) per image, each followed by a newline.
    image_tags = "".join(f"<|image_{idx}|>\n"
                         for idx in range(1, num_imgs + 1))
    prompt = f"<|user|>\n{image_tags}<|end|>\n<|assistant|>\n"

    # Square dummy image seven crop-sizes wide and tall.
    crop_size = ctx.get_hf_config().embd_layer["image_embd_layer"]["crop_size"]
    side = crop_size * 7
    dummy_image = image_assets[0].pil_image.resize((side, side))
    mm_data = {"image": [dummy_image] * num_imgs}

    processed_inputs = processor.apply(prompt, mm_data, call_kwargs)

    # The placeholder count must scale with the requested dynamic_hd setting.
    token_ids = processed_inputs["prompt_token_ids"]
    img_tok_count = token_ids.count(_IMAGE_PLACEHOLDER_TOKEN_ID)
    assert img_tok_count == expected_toks_per_img * num_imgs
...@@ -482,11 +482,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -482,11 +482,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if modality in ("image", "image_embeds"): if modality in ("image", "image_embeds"):
if model_type == "chatglm": if model_type == "chatglm":
return "<|begin_of_image|><|endoftext|><|end_of_image|>" return "<|begin_of_image|><|endoftext|><|end_of_image|>"
if model_type == "phi3_v": if model_type in ("phi3_v", "phi4mm"):
# Workaround since this token is not defined in the tokenizer
return f"<|image_{current_count}|>" return f"<|image_{current_count}|>"
if model_type == "phi4mm":
return "<|endoftext10|>" # 200010 (see vocab.json in hf model)
if model_type in ("minicpmo", "minicpmv"): if model_type in ("minicpmo", "minicpmv"):
return "(<image>./</image>)" return "(<image>./</image>)"
if model_type in ("blip-2", "florence2", "fuyu", "paligemma", if model_type in ("blip-2", "florence2", "fuyu", "paligemma",
...@@ -522,7 +519,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -522,7 +519,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type == "ultravox": if model_type == "ultravox":
return "<|audio|>" return "<|audio|>"
if model_type == "phi4mm": if model_type == "phi4mm":
return "<|endoftext11|>" # 200011 (see vocab.json in hf model) return f"<|audio_{current_count}|>"
if model_type in ("qwen2_audio", "qwen2_5_omni"): if model_type in ("qwen2_audio", "qwen2_5_omni"):
return (f"Audio {current_count}: " return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>") f"<|audio_bos|><|AUDIO|><|audio_eos|>")
......
...@@ -327,7 +327,7 @@ class Phi3VProcessingInfo(BaseProcessingInfo): ...@@ -327,7 +327,7 @@ class Phi3VProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: Optional[ProcessorMixin], processor: Optional[ProcessorMixin] = None,
) -> int: ) -> int:
if processor is None: if processor is None:
processor = self.get_hf_processor() processor = self.get_hf_processor()
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math import math
import re from collections.abc import Iterable, Mapping, Sequence
from functools import lru_cache from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import numpy as np import numpy as np
import scipy.signal
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision.transforms as T from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin,
from PIL import Image SequenceFeatureExtractor, SiglipVisionConfig)
from transformers import PretrainedConfig, SiglipVisionConfig
from transformers.utils import logging
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.inputs.data import TokenInputs, token_inputs
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.sequence import IntermediateTensors, SequenceData MultiModalKwargs, NestedTensors)
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems,
ImageProcessorItems, ImageSize,
MultiModalDataItems, MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .idefics2_vision_model import Idefics2VisionTransformer from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsV0Only from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .phi4mm_audio import AudioEmbedding from .phi4mm_audio import AudioEmbedding
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
# <|endoftext10|> (see vocab.json in hf model) # <|endoftext10|> (see vocab.json in hf model)
_IMAGE_PLACEHOLDER_TOKEN_ID = 200010 _IMAGE_PLACEHOLDER_TOKEN_ID = 200010
...@@ -43,115 +44,19 @@ _IMAGE_PLACEHOLDER_TOKEN_ID = 200010 ...@@ -43,115 +44,19 @@ _IMAGE_PLACEHOLDER_TOKEN_ID = 200010
_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 _AUDIO_PLACEHOLDER_TOKEN_ID = 200011
_AUDIO_MAX_SOUNDFILE_SIZE = 241_000 _AUDIO_MAX_SOUNDFILE_SIZE = 241_000
DUMMY_SAMPLING_FREQUENCY = 16_000 # Hz (i.e. 16 kHz)
DYNAMIC_HD = 16
AUDIO_TOKEN_PATTERN = r"<\|audio_(\d+)\|>"
IMAGE_TOKEN_PATTERN = r"<\|image_(\d+)\|>"
SIGLIP_NAME = "siglip-so400m-patch14-448" SIGLIP_NAME = "siglip-so400m-patch14-448"
VISION_ENCODER_TO_PROCESSING_CONFIG = { VISION_ENCODER_TO_PROCESSING_CONFIG = {
'siglip-so400m-patch14-448': { 'siglip-so400m-patch14-448': {
'dynamic_hd': 16,
'vit_image_size': 448, 'vit_image_size': 448,
'vit_patch_size': 14, 'vit_patch_size': 14,
'token_compression_factor': 2, 'token_compression_factor': 2,
}, },
} }
logger = logging.get_logger(__name__)
# This is a workaround to prevent text (user input) + audio + image
# from being used in the same prompt.
# It includes token ids for "\n" and tokens in added_tokens_decoder
# from the tokenizer_config.json file.
NON_USER_INPUT_TOKENS = {
198, 200010, 200011, 199999, 200018, 200019, 200020, 200021, 200022,
200023, 200024, 200025, 200026, 200027, 200028
}
def get_max_dummy_image(ctx: InputContext): def _get_padding_size(orig_width: int, orig_height: int, target_height: int,
hf_config = ctx.get_hf_config() target_width: int):
vision_encoder_name = hf_config.img_processor
if vision_encoder_name is None:
vision_encoder_name = SIGLIP_NAME
prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name]
dynamic_hd_size = prepro_config['dynamic_hd']
vit_image_size = prepro_config['vit_image_size']
max_side = vit_image_size * dynamic_hd_size
dummy_image = dummy_image_for_phi4mm(vit_image_size, max_side)
return dummy_image
# image token length
def get_max_phi4mm_image_tokens(ctx: InputContext):
    """Return the token count computed for the maximum-size dummy image.

    The preprocessing settings are looked up by the vision encoder named in
    the HF config, falling back to ``SIGLIP_NAME`` when none is set.
    """
    dummy_image = get_max_dummy_image(ctx)

    encoder_name = ctx.get_hf_config().img_processor
    settings = VISION_ENCODER_TO_PROCESSING_CONFIG[
        SIGLIP_NAME if encoder_name is None else encoder_name]

    return _compute_num_image_tokens(
        dummy_image,
        settings['dynamic_hd'],
        settings['vit_image_size'],
        settings['vit_patch_size'],
        settings['token_compression_factor'],
    )
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
                              image_size):
    """Pick the candidate grid whose aspect ratio is nearest to the image's.

    Args:
        aspect_ratio: width / height of the source image.
        target_ratios: iterable of ``(cols, rows)`` candidates, scanned in
            order.
        width: source image width.
        height: source image height.
        image_size: side length of one square tile.

    Returns:
        The selected ``(cols, rows)`` candidate, or ``(1, 1)`` when no
        candidate is closer than infinity (e.g. an empty candidate list).
        On an exact tie, a later candidate replaces the current pick only if
        the image area exceeds half of the area that candidate's grid covers.
    """
    pixel_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue
        # Tie-break: only switch when the image is large enough for the grid.
        if (gap == smallest_gap
                and pixel_area > 0.5 * image_size * image_size * cols * rows):
            chosen = candidate
    return chosen
def _find_target_aspect_ratio(image, image_size, max_num, min_num):
orig_width, orig_height = image.size
w_crop_num = math.ceil(orig_width / float(image_size))
h_crop_num = math.ceil(orig_height / float(image_size))
if w_crop_num * h_crop_num > max_num:
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set((i, j) for i in range(1, max_num + 1)
for j in range(1, max_num + 1)
if i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
logger.debug("target_aspect_ratio: %s", target_aspect_ratio)
else:
target_width = image_size * w_crop_num
target_height = image_size * h_crop_num
target_aspect_ratio = (w_crop_num, h_crop_num)
return target_aspect_ratio, target_height, target_width
def _get_padding_size(image, target_height, target_width):
orig_width, orig_height = image.size
ratio_width = target_width / orig_width ratio_width = target_width / orig_width
ratio_height = target_height / orig_height ratio_height = target_height / orig_height
...@@ -164,181 +69,6 @@ def _get_padding_size(image, target_height, target_width): ...@@ -164,181 +69,6 @@ def _get_padding_size(image, target_height, target_width):
return padding_height, padding_width return padding_height, padding_width
def dynamic_preprocess(image,
                       min_num=1,
                       max_num=12,
                       image_size=384,
                       mask_size=27):
    """Resize and pad an image onto its tile grid and build its mask.

    Args:
        image: input image exposing ``size`` as ``(width, height)``.
        min_num: minimum number of tiles.
        max_num: maximum number of tiles.
        image_size: side length of one square tile.
        mask_size: mask cells per tile side.

    Returns:
        ``(resized_img, attention_mask)``: the aspect-preserving resized
        image padded (fill value 255) to the target size, and a 2D mask of
        ones whose trailing columns/rows are zeroed where whole 14-pixel
        strips of padding were added.

    Raises:
        ValueError: if the resized image is narrower or shorter than 10 px.
    """
    grid, target_height, target_width = _find_target_aspect_ratio(
        image, image_size, max_num, min_num)
    padding_height, padding_width = _get_padding_size(image, target_height,
                                                      target_width)

    # Scale by the smaller factor so the image fits inside the target box.
    orig_width, orig_height = image.size
    scale_w = target_width / orig_width
    scale_h = target_height / orig_height
    new_size = ((target_width, int(orig_height * scale_w))
                if scale_w < scale_h else
                (int(orig_width * scale_h), target_height))

    mask_rows = int(mask_size * grid[1])
    mask_cols = int(mask_size * grid[0])
    attention_mask = torch.ones((mask_rows, mask_cols))
    # Zero out mask cells that fall entirely in the padded border.
    if padding_width >= 14:
        attention_mask[:, -math.floor(padding_width / 14):] = 0
    if padding_height >= 14:
        attention_mask[-math.floor(padding_height / 14):, :] = 0
    assert attention_mask.sum(
    ) > 0, f'attention mask is empty {attention_mask}'

    too_short = min(new_size[1], target_height) < 10
    too_narrow = min(new_size[0], target_width) < 10
    if too_short or too_narrow:
        raise ValueError(f'the aspect ratio is very extreme {new_size}')

    scaled = T.functional.resize(
        image,
        [new_size[1], new_size[0]],
    )
    resized_img = T.functional.pad(scaled,
                                   [0, 0, padding_width, padding_height],
                                   fill=[255, 255, 255])
    return resized_img, attention_mask
def pad_to_max_num_crops(images, max_crops=5):
    """
    Append all-zero crops until the batch holds ``max_crops`` entries.

    images: B x 3 x H x W, B<=max_crops

    Returns the input unchanged when it already has at least ``max_crops``
    crops; otherwise a new tensor with zero crops (same dtype and device)
    concatenated along dim 0.
    """
    num_crops, _, height, width = images.shape
    if max_crops <= num_crops:
        return images
    filler = torch.zeros(max_crops - num_crops,
                         3,
                         height,
                         width,
                         dtype=images.dtype,
                         device=images.device)
    return torch.cat([images, filler], dim=0)
def pad_mask_to_max_num_crops(masks, max_crops=5):
    """Append all-ones masks until the batch holds ``max_crops`` entries.

    ``masks`` is a B x H x W tensor. It is returned unchanged when it already
    has at least ``max_crops`` masks; otherwise masks filled with ones (same
    dtype and device) are concatenated along dim 0.
    """
    num_masks, height, width = masks.shape
    if max_crops <= num_masks:
        return masks
    filler = torch.ones(max_crops - num_masks,
                        height,
                        width,
                        dtype=masks.dtype,
                        device=masks.device)
    return torch.cat([masks, filler], dim=0)
def preprocess(images, dynamic_hd_size, vit_resolution, vit_patch_size):
    """Turn a list of images into batched crop tensors, masks and token counts.

    Each image is converted to RGB, resized/padded onto a tile grid by
    ``dynamic_preprocess``, normalized, and split into square tiles of side
    ``vit_resolution``. A bicubically downscaled "global" view of the whole
    image is prepended to its tiles, and every image is then padded to the
    largest crop count in the batch so the results can be stacked.

    Args:
        images: sequence of objects supporting ``.convert('RGB')``
            (presumably PIL images -- TODO confirm against callers).
        dynamic_hd_size: maximum number of tiles per image, forwarded to
            ``dynamic_preprocess`` as ``max_num``.
        vit_resolution: side length of one tile in pixels.
        vit_patch_size: patch size; ``vit_resolution // vit_patch_size`` gives
            the per-tile mask resolution.

    Returns:
        dict with keys ``"pixel_values"`` (stacked crops), ``"image_sizes"``
        (long tensor of each padded image's height and width),
        ``"image_attention_mask"`` (stacked masks) and ``"num_img_tokens"``
        (list with one int per image).
    """
    # Basic settings.
    # ToTensor followed by per-channel normalization with mean 0.5, std 0.5.
    img_processor = T.Compose([
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Dynamic HD
    base_resolution = vit_resolution
    images = [image.convert('RGB') for image in images]
    # cover 384 and 448 resolution
    mask_resolution = base_resolution // vit_patch_size
    elems, image_attention_masks = [], []
    # Resize/pad every image onto its tile grid and collect its 2D mask.
    for im in images:
        elem, attention_mask = dynamic_preprocess(im,
                                                  max_num=dynamic_hd_size,
                                                  image_size=base_resolution,
                                                  mask_size=mask_resolution)
        elems.append(elem)
        image_attention_masks.append(attention_mask)
    hd_images = [img_processor(im) for im in elems]
    # Global view: the whole padded image downscaled to a single tile.
    global_image = [
        torch.nn.functional.interpolate(
            im.unsqueeze(0).float(),
            size=(base_resolution, base_resolution),
            mode='bicubic',
        ).to(im.dtype) for im in hd_images
    ]
    # Sizes of dims 1 and 2 of each normalized image and dims 0 and 1 of
    # each mask (used below as height and width).
    shapes = [[im.size(1), im.size(2)] for im in hd_images]
    mask_shapes = [[mask.size(0), mask.size(1)]
                   for mask in image_attention_masks]
    # The global view is fully attended (mask of ones).
    global_attention_mask = [
        torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images
    ]
    # Cut each padded image into (num_tiles, 3, base_resolution,
    # base_resolution) tiles.
    hd_images_reshape = [
        im.reshape(1, 3, h // base_resolution, base_resolution,
                   w // base_resolution, base_resolution).permute(
                       0, 2, 4, 1, 3, 5).reshape(-1, 3, base_resolution,
                                                 base_resolution).contiguous()
        for im, (h, w) in zip(hd_images, shapes)
    ]
    # Cut each mask into matching (num_tiles, mask_resolution,
    # mask_resolution) tiles.
    attention_masks_reshape = [
        mask.reshape(1, h // mask_resolution, mask_resolution,
                     w // mask_resolution, mask_resolution).permute(
                         0, 1, 3, 2, 4).reshape(-1, mask_resolution,
                                                mask_resolution).contiguous()
        for mask, (h, w) in zip(image_attention_masks, mask_shapes)
    ]
    # NOTE token compression is hard coded here, and odd numbers seems to fail
    # Keep every second mask row/column, then regroup per tile-grid position.
    downsample_attention_masks = [
        mask[:, 0::2,
             0::2].reshape(1, h // mask_resolution, w // mask_resolution,
                           mask_resolution // 2 + mask_resolution % 2,
                           mask_resolution // 2 + mask_resolution % 2).permute(
                               0, 1, 3, 2, 4)
        for mask, (h, w) in zip(attention_masks_reshape, mask_shapes)
    ]
    # Flatten back to one 2D downsampled mask per image.
    downsample_attention_masks = [
        mask.reshape(mask.size(1) * mask.size(2),
                     mask.size(3) * mask.size(4))
        for mask in downsample_attention_masks
    ]
    # NOTE hard coded number of tokens
    # Sum of fixed constants (256, 1, 16), the number of attended cells and
    # the number of attended cells in the first column.
    # NOTE(review): the constants presumably correspond to global-view tokens
    # and separator tokens -- confirm against the HF processor.
    num_img_tokens = [
        256 + 1 + int(mask.sum().item()) + int(mask[:, 0].sum().item()) + 16
        for mask in downsample_attention_masks
    ]
    # Prepend the global view (and its mask) to each image's tiles.
    hd_images_reshape = [
        torch.cat([_global_image] + [_im], dim=0)
        for _global_image, _im in zip(global_image, hd_images_reshape)
    ]
    hd_masks_reshape = [
        torch.cat([_global_mask] + [_mask],
                  dim=0) for _global_mask, _mask in zip(
                      global_attention_mask, attention_masks_reshape)
    ]
    # Pad every image to the largest crop count so the batch can be stacked.
    max_crops = max([img.size(0) for img in hd_images_reshape])
    image_transformed = [
        pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape
    ]
    image_transformed = torch.stack(image_transformed, dim=0)
    mask_transformed = [
        pad_mask_to_max_num_crops(mask, max_crops) \
        for mask in hd_masks_reshape
    ]
    mask_transformed = torch.stack(mask_transformed, dim=0)
    returned_input_image_embeds = image_transformed
    returned_image_sizes = torch.tensor(shapes, dtype=torch.long)
    returned_image_attention_mask = mask_transformed
    returned_num_img_tokens = num_img_tokens
    data = {
        "pixel_values": returned_input_image_embeds,
        "image_sizes": returned_image_sizes,
        "image_attention_mask": returned_image_attention_mask,
        "num_img_tokens": returned_num_img_tokens,
    }
    return data
def get_navit_vision_model(layer_idx: int = -1, **kwargs): def get_navit_vision_model(layer_idx: int = -1, **kwargs):
vision_config = { vision_config = {
"hidden_size": 1152, "hidden_size": 1152,
...@@ -492,7 +222,7 @@ class Phi4MMImageEncoder(nn.Module): ...@@ -492,7 +222,7 @@ class Phi4MMImageEncoder(nn.Module):
def forward(self, pixel_values: torch.FloatTensor, def forward(self, pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor, image_sizes: torch.Tensor,
image_attention_mask: torch.Tensor) -> torch.FloatTensor: image_attention_mask: torch.Tensor) -> list[torch.FloatTensor]:
""" """
process image and return vision embeddings. process image and return vision embeddings.
...@@ -656,785 +386,528 @@ class Phi4MMImageEncoder(nn.Module): ...@@ -656,785 +386,528 @@ class Phi4MMImageEncoder(nn.Module):
for _output_img in output_imgs: for _output_img in output_imgs:
img_feature_proj = self.img_projection( img_feature_proj = self.img_projection(
_output_img.to(target_device).to(target_dtype)) _output_img.to(target_device).to(target_dtype))
img_set_tensor.append(img_feature_proj) img_set_tensor.append(img_feature_proj.squeeze(0))
return img_set_tensor return img_set_tensor
class Phi4MMAudioFeatureInputs(TypedDict): class Phi4MMImagePixelInputs(TypedDict):
type: Literal["audio_features"] type: Literal["pixel_values"]
data: Tuple[NestedTensors] data: Union[torch.Tensor, List[torch.Tensor]]
"""Shape: `((batch_size, num_audios, 80, M), )"""
class Phi4MMAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"]
data: NestedTensors
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs]
def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
"""Create a Mel filter-bank the same as SpeechLib FbankFC.
Args:
sample_rate (int): Sample rate in Hz. number > 0 [scalar]
n_fft (int): FFT size. int > 0 [scalar]
n_mel (int): Mel filter size. int > 0 [scalar]
fmin (float): lowest frequency (in Hz). If None use 0.0.
float >= 0 [scalar]
fmax: highest frequency (in Hz). If None use sample_rate / 2.
float >= 0 [scalar]
Returns
out (numpy.ndarray): Mel transform matrix
[shape=(n_mels, 1 + n_fft/2)]
""" """
Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
bank_width = int(n_fft // 2 + 1) Note that `num_patches` may be different per batch and image,
if fmax is None: in which case the data is passed as a list instead of a batched tensor.
fmax = sample_rate / 2 """
if fmin is None:
fmin = 0
assert fmin >= 0, "fmin cannot be negative"
assert (fmin < fmax <=
sample_rate / 2), "fmax must be between (fmin, samplerate / 2]"
def mel(f):
return 1127.0 * np.log(1.0 + f / 700.0)
def bin2mel(fft_bin):
return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))
def f2bin(f):
return int((f * n_fft / sample_rate) + 0.5)
# Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
klo = f2bin(fmin) + 1
khi = f2bin(fmax)
khi = max(khi, klo)
# Spec 2: SpeechLib uses triangles in Mel space
mlo = mel(fmin)
mhi = mel(fmax)
m_centers = np.linspace(mlo, mhi, n_mels + 2)
ms = (mhi - mlo) / (n_mels + 1)
matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
for m in range(0, n_mels):
left = m_centers[m]
center = m_centers[m + 1]
right = m_centers[m + 2]
for fft_bin in range(klo, khi):
mbin = bin2mel(fft_bin)
if left < mbin < right:
matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms
return matrix
class LogFbankProcessor:
def __init__(self):
self._eightk_method = "fillzero"
self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T
self._hamming400 = np.hamming(400) # for 16k audio
self._hamming200 = np.hamming(200) # for 8k audio
def extract_spectrogram(self, wav, fs): image_sizes: torch.Tensor
"""Extract spectrogram features from waveform. """
Args: Shape: `(batch_size * num_images, 2)`
wav (1D array): waveform of the input
fs (int): sampling rate of the waveform, 16000 or 8000.
If fs=8000, the waveform will be resampled to 16000Hz.
Output:
log_fbank (2D array): a TxD matrix of log Mel filterbank features.
D=80, and T is the number of frames.
"""
if wav.ndim > 1:
wav = np.squeeze(wav)
# by default, we extract the mean if stereo This should be in `(height, width)` format.
if len(wav.shape) == 2: """
wav = wav.mean(1)
# Resample to 16000 or 8000 if needed num_img_tokens: list[int]
if fs > 16000: """Shape: `(batch_size * num_images)`"""
wav = scipy.signal.resample_poly(wav, 1, fs // 16000)
fs = 16000
elif 8000 < fs < 16000:
wav = scipy.signal.resample_poly(wav, 1, fs // 8000)
fs = 8000
elif fs < 8000:
raise RuntimeError(f"Unsupported sample rate {fs}")
if fs == 8000:
if self._eightk_method == "resample":
# Input audio is 8 kHz. Convert to 16 kHz before feature
# extraction
wav = scipy.signal.resample_poly(wav, 2, 1)
fs = 16000
# Do nothing here for fillzero method
elif fs != 16000:
# Input audio is not a supported sample rate.
raise RuntimeError(
f"Input data using an unsupported sample rate: {fs}")
preemphasis = 0.97
if fs == 8000:
n_fft = 256
win_length = 200
hop_length = 80
fft_window = self._hamming200
elif fs == 16000:
n_fft = 512
win_length = 400
hop_length = 160
fft_window = self._hamming400
# Spec 1: SpeechLib cut remaining sample insufficient for a hop
n_batch = (wav.shape[0] - win_length) // hop_length + 1
# Here we don't use stride_tricks since the input array may not satisfy
# memory layout requirement and we need writeable output
# Here we only use list of views before copy to destination
# so it is more efficient than broadcasting
y_frames = np.array(
[
wav[_stride:_stride + win_length]
for _stride in range(0, hop_length * n_batch, hop_length)
],
dtype=np.float32,
)
# Spec 2: SpeechLib applies preemphasis within each batch image_attention_mask: torch.Tensor
y_frames_prev = np.roll(y_frames, 1, axis=1) """Shape: `(batch_size * num_images, H_mask, W_mask)`"""
y_frames_prev[:, 0] = y_frames_prev[:, 1]
y_frames = (y_frames - preemphasis * y_frames_prev) * 32768
S = np.fft.rfft(fft_window * y_frames, n=n_fft,
axis=1).astype(np.complex64)
if fs == 8000: class Phi4MMImageEmbeddingInputs(TypedDict):
# Need to pad the output to look like 16 kHz data but with zeros in type: Literal["image_embeds"]
# the 4 to 8 kHz bins. data: Union[torch.Tensor, List[torch.Tensor]]
frames, bins = S.shape """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
padarray = np.zeros((frames, bins))
S = np.concatenate((S[:, 0:-1], padarray),
axis=1) # Nyquist bin gets set to zero
spec = np.abs(S).astype(np.float32) `hidden_size` must match the hidden size of language model backbone.
return spec """
def extract_features(self, wav, fs):
"""Extract log filterbank features from waveform.
Args:
wav (1D array): waveform of the input
fs (int): sampling rate of the waveform, 16000 or 8000.
If fs=8000, the waveform will be resampled to 16000Hz.
Output:
log_fbank (2D array): a TxD matrix of log Mel filterbank features.
D=80, and T is the number of frames.
"""
spec = self.extract_spectrogram(wav, fs)
spec_power = spec**2
fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None) class Phi4MMAudioFeatureInputs(TypedDict):
log_fbank = np.log(fbank_power).astype(np.float32) type: Literal["audio_features"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""Shape: `(batch_size * num_audios, 80, M)"""
return log_fbank
class Phi4MMAudioEmbeddingInputs(TypedDict):
    """Typed dictionary for audio inputs tagged `"audio_embeds"`."""

    # Discriminator: always the literal string "audio_embeds".
    type: Literal["audio_embeds"]
    data: NestedTensors
    """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`"""
@lru_cache
def audio_feature_extractor() -> LogFbankProcessor:
    """Return the audio processor used to extract features from sound files.

    The function takes no arguments, so the LRU cache keeps exactly one
    `LogFbankProcessor` and hands the same instance to every caller.
    """
    processor = LogFbankProcessor()
    return processor
# Either form of image input: pixel inputs or image embeddings.
Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs]
# Either form of audio input: audio features or audio embeddings.
Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs]
def _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size,
vit_patch_size, token_compression_factor):
"""
compute the number of tokens an image is expected to take up considering
the image encoder architecture and exclude output features containing
only padding pixels
for siglip, vit_image_size=448, vit_patch_size=14, so output will be def cat_with_pad(tensors, dim, padding_value=0):
32x32 feature map
NOTE right now, Phi4MM uses hard-coded token_compression_factor=2
"""
assert vit_image_size % vit_patch_size == 0, \
"vit_image_size must be divisible by vit_patch_size"
assert vit_image_size // vit_patch_size % token_compression_factor == 0, \
"vit_image_size // vit_patch_size must be divisible by "\
"token_compression_factor"
target_aspect_ratio, target_height, target_width = (
_find_target_aspect_ratio(image,
vit_image_size,
dynamic_hd_size,
min_num=1))
assert target_aspect_ratio[
0] * vit_image_size == target_width, \
f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}"
assert target_aspect_ratio[
1] * vit_image_size == target_height, \
f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}"
assert (target_height % vit_image_size == 0
and target_width % vit_image_size == 0)
padding_height, padding_width = _get_padding_size(image, target_height,
target_width)
assert padding_width == 0 or padding_height == 0, \
"padding_width or padding_height must be 0"
target_feat_width = target_width // vit_patch_size
target_feat_height = target_height // vit_patch_size
if padding_width >= vit_patch_size:
assert padding_height == 0, "padding_height not 0"
non_pad_feat_width = target_feat_width - math.floor(
padding_width / vit_patch_size)
non_pad_feat_height = target_feat_height
elif padding_height >= vit_patch_size:
assert padding_width == 0, "padding_width not 0"
non_pad_feat_height = target_feat_height - math.floor(
padding_height / vit_patch_size)
non_pad_feat_width = target_feat_width
else:
# small padding shorter than a vit patch
non_pad_feat_width = target_feat_width
non_pad_feat_height = target_feat_height
feat_width = non_pad_feat_width // token_compression_factor
feat_height = non_pad_feat_height // token_compression_factor
# NOTE it's possible that the non-padding feature is not divisible
if non_pad_feat_width % token_compression_factor != 0:
feat_width += 1
if non_pad_feat_height % token_compression_factor != 0:
feat_height += 1
num_hd_patch_tokens = feat_width * feat_height
num_hd_newline_tokens = feat_height
vit_feature_size = vit_image_size // vit_patch_size
num_global_image_tokens = (vit_feature_size // token_compression_factor)**2
num_sep_tokens = 1
num_global_image_newline_tokens = \
vit_feature_size // token_compression_factor
return (num_global_image_tokens + num_sep_tokens + num_hd_patch_tokens +
num_hd_newline_tokens + num_global_image_newline_tokens)
def compute_logfbank_output_size(wav_length: int, fs: int) -> Tuple[int, int]:
""" """
Compute the output size of the `extract_features` method. cat along dim, while pad to max for all other dims
Args:
wav_length (int): Length of the input waveform in samples.
fs (int): Sampling rate of the waveform, either 16000 or 8000.
Returns:
tuple (int, int): Output size as (T, D), where:
T: Number of time frames.
D: Number of Mel filterbank bins (80).
""" """
ndim = tensors[0].dim()
assert all(
t.dim() == ndim for t in
tensors[1:]), "All tensors must have the same number of dimensions"
# Resample to 16000 or 8000 if needed out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
if fs > 16000: out_size[dim] = sum(t.shape[dim] for t in tensors)
wav_length //= fs // 16000 output = tensors[0].new_full(out_size, padding_value)
fs = 16000
elif 8000 <= fs < 16000:
# We'll resample to 16K from 8K
wav_length *= 2
fs = 16000
elif fs < 8000:
raise RuntimeError(f"Unsupported sample rate {fs}")
# Spectrogram parameters for 16 kHz index = 0
win_length = 400 # Frame length in samples for t in tensors:
hop_length = 160 # Frame shift in samples # Create a slice list where every dimension except dim is full slice
mel_bins = 80 # Number of mel filterbank bins slices = [slice(0, t.shape[d]) for d in range(ndim)]
# Update only the concat dimension slice
slices[dim] = slice(index, index + t.shape[dim])
# Calculate number of frames (T) output[slices] = t
T = (wav_length - win_length) // hop_length + 1 index += t.shape[dim]
if T < 1:
raise ValueError("Waveform too short for given parameters.")
# Return time frames (T) and mel bins (D) return output
return T, mel_bins
def _get_audio_embed_sizes(audios, ctx: InputContext): class Phi4MMProcessingInfo(BaseProcessingInfo):
"""
Get the audio embedding sizes for each audio file.
Args: def get_hf_processor(
audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of self,
waveform and sample rate. *,
ctx (InputContext): Input context. dynamic_hd: Optional[int] = None,
**kwargs: object,
) -> ProcessorMixin:
if dynamic_hd is not None:
kwargs["dynamic_hd"] = dynamic_hd
Returns: return self.ctx.get_hf_processor(**kwargs)
List[int]: List of audio embedding sizes.
"""
audio_embed_sizes = []
for audio in audios:
audio_data, sf = audio
audio_frames, _ = compute_logfbank_output_size(len(audio_data), sf)
audio_embed_size = _compute_audio_embed_size(ctx.get_hf_config(),
audio_frames)
audio_embed_sizes.append(audio_embed_size)
return audio_embed_sizes
@property
def image_tokens(self) -> list[str]:
    """Image placeholder strings `<|image_1|>` through `<|image_100|>`."""
    placeholders = []
    # Placeholder numbering is 1-based.
    for number in range(1, 101):
        placeholders.append(f"<|image_{number}|>")
    return placeholders
def _get_audio_id_to_input_ids(audios, ctx: InputContext, prompt_str=""): @property
""" def audio_tokens(self) -> list[str]:
The following will search for `<|audio_{idx}|>` tokens and return [f"<|audio_{i+1}|>" for i in range(100)]
return a mapping of audio placeholder tokens to audio placeholder token ids
based on the size of the audio embeddings.
Args: def get_dynamic_hd(
audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of self,
waveform and sample rate. processor: Optional[ProcessorMixin] = None,
ctx (InputContext): Input context. ) -> int:
prompt_str (str): The prompt string. if processor is None:
processor = self.get_hf_processor()
image_processor = processor.image_processor
return image_processor.dynamic_hd
Returns: def get_feature_extractor(self) -> SequenceFeatureExtractor:
Dict[str, List[int]]: Mapping of audio placeholder tokens to audio return self.get_hf_processor().audio_processor
placeholder token ids.
""" def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
if len(audios) == 0: return {"audio": None, "image": None}
return {}
audio_embed_sizes = _get_audio_embed_sizes(audios, ctx)
audio_ids = re.findall(AUDIO_TOKEN_PATTERN, prompt_str)
audio_ids = [int(audio_id) for audio_id in audio_ids]
assert len(audio_ids) == len(
audio_embed_sizes
), "Number of audio tokens and audio features do not match"
assert tuple(audio_ids) == tuple(range(1,
len(audio_ids) +
1)), "Audio ids are not in order!"
audio_id_to_input_ids = {
f"<|audio_{audio_id}|>":
[_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size
for audio_id, audio_embed_size in zip(audio_ids, audio_embed_sizes)
}
return audio_id_to_input_ids def get_mm_max_tokens_per_item(
self,
seq_len: int,
def _count_image_tokens(images, ctx: InputContext): mm_counts: Mapping[str, int],
hf_config = ctx.get_hf_config() ) -> Mapping[str, int]:
vision_encoder_name = hf_config.img_processor return {
if vision_encoder_name is None: "image": self.get_max_image_tokens(),
vision_encoder_name = SIGLIP_NAME "audio": self.get_max_audio_tokens(),
prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] }
dynamic_hd_size = prepro_config['dynamic_hd']
vit_image_size = prepro_config['vit_image_size']
vit_patch_size = prepro_config['vit_patch_size']
token_compression_factor = prepro_config['token_compression_factor']
image_token_counts = [
_compute_num_image_tokens(image, dynamic_hd_size, vit_image_size,
vit_patch_size, token_compression_factor)
for image in images
]
return image_token_counts
def _get_image_id_to_input_ids(images, prompt, ctx: InputContext):
    """Map each `<|image_N|>` tag in the prompt to its placeholder token ids.

    Args:
        images: The images attached to the prompt.
        prompt (str): Prompt text searched with `IMAGE_TOKEN_PATTERN`.
        ctx (InputContext): Input context, forwarded to
            `_count_image_tokens`.

    Returns:
        dict: For every image tag found, `_IMAGE_PLACEHOLDER_TOKEN_ID`
        repeated once per token counted for the matching image. Empty
        when no images are given.
    """
    if len(images) == 0:
        return {}

    found_ids = re.findall(IMAGE_TOKEN_PATTERN, prompt)
    image_ids = [int(found) for found in found_ids]

    unique_ids = set(image_ids)
    assert len(image_ids) == len(unique_ids), \
        "Duplicate image tokens in prompt"
    assert len(images) == len(image_ids), \
        "Number of images and image tokens in prompt do not match"
    # NOTE the following assertion is not strictly necessary
    expected_ids = tuple(range(1, len(image_ids) + 1))
    assert tuple(image_ids) == expected_ids, "Image ids are not in order"

    token_counts = _count_image_tokens(images, ctx)

    image_id_to_input_ids = {}
    for image_id, num_tokens in zip(image_ids, token_counts):
        tag = f"<|image_{image_id}|>"
        image_id_to_input_ids[tag] = [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_tokens
    return image_id_to_input_ids
def get_max_audio_tokens(self) -> int:
    """Return the audio embedding size for the maximum sound-file length.

    The frame count is computed for a waveform of
    `_AUDIO_MAX_SOUNDFILE_SIZE` samples at the feature extractor's sampling
    rate, then converted to an embedding size.
    """
    feature_extractor = self.get_feature_extractor()
    max_frames = self.get_audio_num_frames(
        _AUDIO_MAX_SOUNDFILE_SIZE,
        feature_extractor.sampling_rate,
    )
    return self._compute_audio_embed_size(max_frames)
def input_processor_for_phi4mm(ctx: InputContext, def get_max_image_tokens(self) -> int:
inputs: DecoderOnlyInputs) -> TokenInputs: target_width, target_height = self.get_image_size_with_most_features()
""" return self.get_num_image_tokens(image_width=target_width,
Implements the input processor, which transforms the input prompt ids image_height=target_height)
to include the audio placeholder token. This will become the `input_ids`
in `forward` for the model.
Args: def _find_target_aspect_ratio(
ctx (InputContext): Input context. self,
inputs (DecoderOnlyInputs): The inputs (e.g. prompt, prompt_token_ids) orig_width: int,
to process. orig_height: int,
image_size: int,
max_num: int,
min_num: int,
):
w_crop_num = math.ceil(orig_width / float(image_size))
h_crop_num = math.ceil(orig_height / float(image_size))
if w_crop_num * h_crop_num > max_num:
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set((i, j) for i in range(1, max_num + 1)
for j in range(1, max_num + 1)
if i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
image_processor = self.get_hf_processor().image_processor
target_aspect_ratio = image_processor.find_closest_aspect_ratio(
aspect_ratio,
target_ratios,
orig_width,
orig_height,
image_size,
)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
else:
target_width = image_size * w_crop_num
target_height = image_size * h_crop_num
target_aspect_ratio = (w_crop_num, h_crop_num)
return target_aspect_ratio, target_height, target_width
Returns: def _compute_num_image_tokens(
TokenInputs: Processed inputs self,
""" orig_width: int,
multi_modal_data = inputs.get("multi_modal_data") orig_height: int,
if (multi_modal_data is None or dynamic_hd_size: int,
("audio" not in multi_modal_data and "image" not in multi_modal_data)): vit_image_size: int,
# pure text input, so no need to do pre-processing vit_patch_size: int,
return inputs token_compression_factor: int = 2,
):
prompt_str = inputs.get("prompt") """
prompt_token_ids = inputs.get("prompt_token_ids") compute the number of tokens an image is expected to take up considering
# for offline_inference, we will get str input and we parse MM special the image encoder architecture and exclude output features containing
# tokens from it only padding pixels
# (ignore prompt_token_ids)
# for OAI server, we will get prompt_token_ids, where MM special tokens
# are already parsed
if 'audio' in multi_modal_data:
audios = multi_modal_data["audio"]
if not isinstance(audios, list):
audios = [audios]
if prompt_str is not None:
audio_id_to_input_ids = _get_audio_id_to_input_ids(
audios, ctx, prompt_str=prompt_str)
audio_embed_sizes = []
elif prompt_token_ids is not None:
audio_id_to_input_ids = {}
audio_embed_sizes = _get_audio_embed_sizes(audios, ctx)
else:
audio_id_to_input_ids = {}
audio_embed_sizes = []
if 'image' in multi_modal_data:
# PIL Image or list of PIL Images
images = multi_modal_data["image"]
if not isinstance(images, list):
images = [images]
if prompt_str is not None:
image_id_to_input_ids = _get_image_id_to_input_ids(
images, prompt_str, ctx)
image_token_counts = []
elif prompt_token_ids is not None:
image_id_to_input_ids = {}
image_token_counts = _count_image_tokens(images, ctx)
else:
image_id_to_input_ids = {}
image_token_counts = []
# Handle the case where the prompt is a string and we need to manually
# tokenize it.
# In this case, the `audio_id_to_input_ids` dict will be mapping from
# an audio placeholder
# string (e.g. `<|audio_1|>`) to the audio placeholder tokens for the
# given audio length.
if prompt_str:
pattern = r"(<\|image_\d+\|>|<\|audio_\d+\|>)"
prompt_chunk_strings = re.split(pattern, prompt_str)
prompt_chunk_strings = [s for s in prompt_chunk_strings if s != ""]
# Create the new input_ids with the placeholder image and audio
# tokens inserted
tokenizer = cached_tokenizer_from_config(ctx.model_config)
input_ids = []
has_imag, has_audio, has_user_text_input = False, False, False
for prompt_chunk_string in prompt_chunk_strings:
if re.match(IMAGE_TOKEN_PATTERN, prompt_chunk_string):
input_ids.extend(image_id_to_input_ids[prompt_chunk_string])
has_imag = True
elif re.match(AUDIO_TOKEN_PATTERN, prompt_chunk_string):
input_ids.extend(audio_id_to_input_ids[prompt_chunk_string])
has_audio = True
else:
curr_token_ids = tokenizer(prompt_chunk_string).input_ids
if not has_user_text_input:
for token_id in curr_token_ids:
if token_id not in NON_USER_INPUT_TOKENS:
has_user_text_input = True
break
input_ids.extend(curr_token_ids)
if has_audio and has_imag and has_user_text_input:
raise ValueError(
"Phi4MMForCausalLM does not support text + audio + image" +
" inputs in the same prompt")
# Handle the case where the prompt is already tokenized
else:
assert prompt_token_ids is not None, \
"If string prompt isn't provided, prompt_token_ids must be"
i = 0
input_ids = prompt_token_ids
# only needed for later assertion
img_cnt, audio_cnt, user_text_input_cnt = 0, 0, 0
image_token_count_iter = iter(image_token_counts)
audio_embed_size_iter = iter(audio_embed_sizes)
while i < len(input_ids):
token_id = input_ids[i]
if token_id == _AUDIO_PLACEHOLDER_TOKEN_ID:
token_count = next(audio_embed_size_iter)
audio_cnt += 1
elif token_id == _IMAGE_PLACEHOLDER_TOKEN_ID:
token_count = next(image_token_count_iter)
img_cnt += 1
else:
user_text_input_cnt += 1 if token_id not in \
NON_USER_INPUT_TOKENS else 0
i += 1
continue
tokens = [token_id] * token_count
input_ids = input_ids[:i] + tokens + input_ids[i + 1:]
i += token_count
if audio_cnt > 0 and img_cnt > 0 and user_text_input_cnt > 0:
raise ValueError(
"Phi4MMForCausalLM does not support text + audio + image" +
" inputs in the same prompt")
# If the below assertion fails, it might be that input pure-text
# messages contain image/audio special tokens literally
# (<|endoftext10|>, <|endoftext11|>).
assert (img_cnt == len(image_token_counts)), (
f"Number of image tokens in prompt_token_ids ({img_cnt}) "
f"does not match number of images ({len(image_token_counts)})")
assert (audio_cnt == len(audio_embed_sizes)), (
f"Number of audio tokens in prompt_token_ids ({audio_cnt}) "
f"does not match number of audios ({len(audio_embed_sizes)})")
# NOTE: Create a defensive copy of the original inputs
return token_inputs(
prompt_token_ids=input_ids,
prompt=prompt_str,
multi_modal_data=multi_modal_data,
)
for siglip, vit_image_size=448, vit_patch_size=14, so output will be
32x32 feature map
NOTE right now, Phi4MM uses hard-coded token_compression_factor=2
"""
assert vit_image_size % vit_patch_size == 0, (
"vit_image_size must be divisible by vit_patch_size")
assert (vit_image_size // vit_patch_size %
token_compression_factor == 0), (
"vit_image_size // vit_patch_size must be divisible by "
"token_compression_factor")
target_aspect_ratio, target_height, target_width = (
self._find_target_aspect_ratio(orig_width,
orig_height,
vit_image_size,
dynamic_hd_size,
min_num=1))
assert target_aspect_ratio[0] * vit_image_size == target_width, (
f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}")
assert target_aspect_ratio[1] * vit_image_size == target_height, (
f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}")
assert (target_height % vit_image_size == 0
and target_width % vit_image_size == 0)
padding_height, padding_width = _get_padding_size(
orig_width, orig_height, target_height, target_width)
assert padding_width == 0 or padding_height == 0, \
"padding_width or padding_height must be 0"
target_feat_width = target_width // vit_patch_size
target_feat_height = target_height // vit_patch_size
if padding_width >= vit_patch_size:
assert padding_height == 0, "padding_height not 0"
non_pad_feat_width = target_feat_width - math.floor(
padding_width / vit_patch_size)
non_pad_feat_height = target_feat_height
elif padding_height >= vit_patch_size:
assert padding_width == 0, "padding_width not 0"
non_pad_feat_height = target_feat_height - math.floor(
padding_height / vit_patch_size)
non_pad_feat_width = target_feat_width
else:
# small padding shorter than a vit patch
non_pad_feat_width = target_feat_width
non_pad_feat_height = target_feat_height
feat_width = non_pad_feat_width // token_compression_factor
feat_height = non_pad_feat_height // token_compression_factor
# NOTE it's possible that the non-padding feature is not divisible
if non_pad_feat_width % token_compression_factor != 0:
feat_width += 1
if non_pad_feat_height % token_compression_factor != 0:
feat_height += 1
num_hd_patch_tokens = feat_width * feat_height
num_hd_newline_tokens = feat_height
vit_feature_size = vit_image_size // vit_patch_size
num_global_image_tokens = (vit_feature_size //
token_compression_factor)**2
num_sep_tokens = 1
num_global_image_newline_tokens = \
vit_feature_size // token_compression_factor
return (num_global_image_tokens + num_sep_tokens +
num_hd_patch_tokens + num_hd_newline_tokens +
num_global_image_newline_tokens)
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: Optional[ProcessorMixin] = None,
) -> int:
hf_config = self.get_hf_config()
vision_encoder_name = hf_config.img_processor
if vision_encoder_name is None:
vision_encoder_name = SIGLIP_NAME
prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[
vision_encoder_name]
vit_image_size = prepro_config['vit_image_size']
vit_patch_size = prepro_config['vit_patch_size']
token_compression_factor = prepro_config['token_compression_factor']
dynamic_hd_size = self.get_dynamic_hd(processor=processor)
image_num_tokens = self._compute_num_image_tokens(
image_width,
image_height,
dynamic_hd_size=dynamic_hd_size,
vit_image_size=vit_image_size,
vit_patch_size=vit_patch_size,
token_compression_factor=token_compression_factor,
)
def _compute_audio_embed_size(hf_config, audio_frames): return image_num_tokens
"""
Compute the audio embedding size based on the audio frames and
compression rate.
"""
compression_rate = hf_config.embd_layer['audio_embd_layer'][
'compression_rate']
# NOTE: this is a hard-coded value but might be configurable in the future
qformer_compression_rate = 1
integer = audio_frames // compression_rate
remainder = audio_frames % compression_rate
result = integer if remainder == 0 else integer + 1 def get_image_size_with_most_features(
self,
processor: Optional[ProcessorMixin] = None,
) -> ImageSize:
hf_config = self.get_hf_config()
vision_encoder_name = hf_config.img_processor
if vision_encoder_name is None:
vision_encoder_name = SIGLIP_NAME
prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[
vision_encoder_name]
vit_image_size = prepro_config['vit_image_size']
max_side = vit_image_size * self.get_dynamic_hd(processor=processor)
return ImageSize(height=max_side, width=vit_image_size)
def get_audio_num_frames(self, audio_len: int, sr: float) -> int:
"""
Compute the output size of the `extract_features` method.
integer = result // qformer_compression_rate Args:
remainder = result % qformer_compression_rate audio_len (int): Length of the input waveform in samples.
result = integer if remainder == 0 else integer + 1 # qformer compression sr (float): Sampling rate of the waveform, either 16000 or 8000.
return result Returns:
tuple (int, int): Output size as (T, D), where:
T: Number of time frames.
D: Number of Mel filterbank bins (80).
"""
# Resample to 16000 or 8000 if needed
if sr > 16000:
audio_len //= sr // 16000
elif 8000 <= sr < 16000:
# We'll resample to 16K from 8K
audio_len *= 2
elif sr < 8000:
raise RuntimeError(f"Unsupported sample rate {sr}")
# Spectrogram parameters for 16 kHz
win_length = 400 # Frame length in samples
hop_length = 160 # Frame shift in samples
# Calculate number of frames (T)
num_frames = (audio_len - win_length) // hop_length + 1
if num_frames < 1:
raise ValueError("Waveform too short for given parameters.")
# Return time frames (T)
return num_frames
def _compute_audio_embed_size(self, audio_frames: int) -> int:
"""
Compute the audio embedding size based on the audio frames and
compression rate.
"""
hf_config = self.get_hf_config()
compression_rate = hf_config.embd_layer['audio_embd_layer'][
'compression_rate']
# NOTE: this is a hard-coded value but might be configurable
# in the future
qformer_compression_rate = 1
integer = audio_frames // compression_rate
remainder = audio_frames % compression_rate
def get_max_phi4mm_audio_tokens(ctx: InputContext) -> int: result = integer if remainder == 0 else integer + 1
return 10000
integer = result // qformer_compression_rate
remainder = result % qformer_compression_rate
# qformer compression
result = integer if remainder == 0 else integer + 1
def dummy_audio_for_phi4mm(audio_count: int) -> dict: return result
"""
Create dummy audio data for the Phi4MM model, which is used for profiling.
Args:
audio_count (int): Number of audio samples.
Returns: class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]):
dict: Dummy audio data.
"""
dummy_audio = np.full((_AUDIO_MAX_SOUNDFILE_SIZE, ), 0.0)
return [(dummy_audio, DUMMY_SAMPLING_FREQUENCY)] * audio_count
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
num_images = mm_counts.get("image", 0)
def dummy_image_for_phi4mm(width: int, height: int): image_tokens: list[str] = self.info.image_tokens[:num_images]
image = Image.new('RGB', (width, height), color='black') audio_tokens: list[str] = self.info.audio_tokens[:num_audios]
return image
return "".join(image_tokens + audio_tokens)
def dummy_data_for_phi4mm(ctx: InputContext, seq_len: int, def get_dummy_mm_data(
mm_counts: Mapping[str, int]) -> DummyData: self,
""" seq_len: int,
Create dummy sequence (input_ids) and audio data for the Phi4MM model, mm_counts: Mapping[str, int],
which is used for profiling. ) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0)
num_images = mm_counts.get("image", 0)
In this case, the sequence data is a bunch of 0s with a number of audio target_width, target_height = \
tokens that correspond to the audio embed size of the self.info.get_image_size_with_most_features()
_AUDIO_MAX_SOUNDFILE_SIZE.
Args: target_width, target_height = \
ctx (InputContext): Input context. self.info.get_image_size_with_most_features()
seq_len (int): Length of the sequence.
mm_counts (Mapping[str, int]): Multi-modal counts.
Returns:
Tuple: Dummy sequence data and dummy audio data.
"""
audio_count = mm_counts["audio"]
audio_frames, _ = compute_logfbank_output_size(_AUDIO_MAX_SOUNDFILE_SIZE,
DUMMY_SAMPLING_FREQUENCY)
audio_feature_size = _compute_audio_embed_size(ctx.get_hf_config(),
audio_frames)
image_count = mm_counts["image"]
dummy_image = get_max_dummy_image(ctx)
max_image_tokens = get_max_phi4mm_image_tokens(ctx)
total_image_tokens = image_count * max_image_tokens
if seq_len - audio_feature_size * audio_count - total_image_tokens < 0:
raise RuntimeError(
f"Phi4MM cannot process {audio_count} audios and {image_count}"
f"images in a prompt, please increase max_model_len to be at"
f" larger than "
f"{audio_feature_size * audio_count + total_image_tokens}"
" or reduce audio/image limit by --limit-mm-per-prompt.")
if audio_feature_size * audio_count > total_image_tokens:
seq_data = SequenceData.from_prompt_token_counts(
(_AUDIO_PLACEHOLDER_TOKEN_ID, audio_feature_size * audio_count),
(0, seq_len - audio_feature_size * audio_count),
)
mm_data = { mm_data = {
"audio": dummy_audio_for_phi4mm(audio_count), "image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
"audio":
self._get_dummy_audios(length=_AUDIO_MAX_SOUNDFILE_SIZE,
num_audios=num_audios),
} }
else:
seq_data = SequenceData.from_prompt_token_counts(
(_IMAGE_PLACEHOLDER_TOKEN_ID, total_image_tokens),
(0, seq_len - total_image_tokens),
)
mm_data = {
"image": [dummy_image] * image_count,
}
return DummyData(seq_data, mm_data)
return mm_data
def input_mapper_for_phi4mm_audio(ctx: InputContext,
                                  data: object) -> MultiModalKwargs:
    """
    Create the MultiModalKwargs for the Phi4MM (audio) model.

    For each audio, the audio features are extracted from the waveform and
    paired with the audio embed length (the latter is used to repeat the
    audio placeholder token in the input prompt IDs).

    These pairs are used, downstream, in `_audio_features_to_embeddings`
    (via `_process_audio_input`).

    Note that the incoming audio data (each entry in `data`) is a tuple of
    the audio data and the sampling frequency (e.g. from soundfile.read).

    Args:
        ctx (InputContext): Input context.
        data (object): One audio tuple, or a list of audio tuples.

    Returns:
        MultiModalKwargs: Multi-modal inputs holding `audio_features`, a
        list of `(features, [embed_size])` pairs. Empty when `data` is an
        empty list.

    Raises:
        NotImplementedError: If an entry of `data` is not a tuple.
    """
    if not isinstance(data, list):
        data = [data]
    if not data:
        return MultiModalKwargs()

    # Loop-invariant values: the extractor comes from a cached factory and
    # the stride is a property of that extractor, so resolve them once.
    # The HF config is assumed not to change between audios of one call.
    feature_extractor = audio_feature_extractor()
    feat_stride = getattr(feature_extractor, "stride", 1)
    hf_config = ctx.get_hf_config()

    audio_features = []
    for audio_input in data:
        if not isinstance(audio_input, tuple):
            raise NotImplementedError(
                f"Unsupported data type: {type(audio_input)}")

        audio, sf = audio_input
        single_audio_features = feature_extractor.extract_features(audio, sf)
        audio_frames = len(single_audio_features) * feat_stride
        single_audio_embed_size = _compute_audio_embed_size(
            hf_config, audio_frames)
        audio_features.append((
            single_audio_features,
            [single_audio_embed_size],
        ))
    return MultiModalKwargs({"audio_features": audio_features})
def input_mapper_for_phi4mm_image(ctx: InputContext, data: object):
    """Create the MultiModalKwargs for the Phi4MM (image) model.

    Args:
        ctx (InputContext): Input context; its HF config names the vision
            encoder (`SIGLIP_NAME` is used when the name is `None`).
        data (object): One image or a list of images.

    Returns:
        MultiModalKwargs: The `pixel_values`, `image_sizes`,
        `image_attention_mask` and `num_img_tokens` entries produced by
        `preprocess`. Empty when `data` is an empty list.
    """
    # data: list of PIL images
    images = data if isinstance(data, list) else [data]
    if len(images) == 0:
        return MultiModalKwargs()

    encoder_name = ctx.get_hf_config().img_processor
    if encoder_name is None:
        encoder_name = SIGLIP_NAME
    prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[encoder_name]

    image_input_dict = preprocess(
        images,
        prepro_config['dynamic_hd'],
        prepro_config['vit_image_size'],
        prepro_config['vit_patch_size'],
    )

    forwarded_keys = (
        "pixel_values",
        "image_sizes",
        "image_attention_mask",
        "num_img_tokens",
    )
    return MultiModalKwargs(
        {key: image_input_dict[key]
         for key in forwarded_keys})
class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
def cat_with_pad(tensors, dim, padding_value=0): def _get_data_parser(self) -> MultiModalDataParser:
""" feature_extractor = self.info.get_feature_extractor()
cat along dim, while pad to max for all other dims return MultiModalDataParser(target_sr=feature_extractor.sampling_rate,
""" audio_resample_method="scipy")
ndim = tensors[0].dim()
assert all(
t.dim() == ndim for t in
tensors[1:]), "All tensors must have the same number of dimensions"
out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] def _call_hf_processor(
out_size[dim] = sum(t.shape[dim] for t in tensors) self,
output = tensors[0].new_full(out_size, padding_value) prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data:
prompt_ids = self.info.get_tokenizer().encode(prompt)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
sr = self.info.get_feature_extractor().sampling_rate
if (audio_data := mm_data.get("audios", [])):
mm_data['audios'] = [(data, sr) for data in audio_data]
processed_outputs = super()._call_hf_processor(prompt, mm_data,
mm_kwargs)
num_img_tokens = [
self.info.get_num_image_tokens(image_width=img_size[0],
image_height=img_size[1])
for img_size in processed_outputs["image_sizes"]
]
processed_outputs["num_img_tokens"] = num_img_tokens
index = 0 audio_features = processed_outputs['input_audio_embeds']
for t in tensors: feature_sizes = [
# Create a slice list where every dimension except dim is full slice self.info.get_audio_num_frames(len(audio), sr)
slices = [slice(0, t.shape[d]) for d in range(ndim)] for audio in audio_data
# Update only the concat dimension slice ]
slices[dim] = slice(index, index + t.shape[dim]) processed_outputs['input_audio_embeds'] = [
audio_features[idx, :size]
for idx, size in enumerate(feature_sizes)
]
output[slices] = t return processed_outputs
index += t.shape[dim]
return output def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
input_image_embeds=MultiModalFieldConfig.batched("image"),
image_attention_mask=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"),
num_img_tokens=MultiModalFieldConfig.batched("image"),
input_audio_embeds=MultiModalFieldConfig.batched("audio"),
)
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, Any],
    out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
    """Build the prompt replacements for image and audio placeholders.

    One `PromptReplacement` is created for each of the first `num_images`
    image tags and the first `num_audios` audio tags. Each replacement
    expands its tag into the placeholder token id repeated once per
    feature of the corresponding item.

    Args:
        mm_items: Parsed multi-modal items of the request.
        hf_processor_mm_kwargs: Keyword arguments forwarded to
            `self.info.get_hf_processor`.
        out_mm_kwargs: Not read by this method.

    Returns:
        The image replacements followed by the audio replacements.
    """
    image_tokens: list[str] = self.info.image_tokens  # type: ignore
    audio_tokens: list[str] = self.info.audio_tokens  # type: ignore
    feature_extractor = self.info.get_feature_extractor()
    hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

    def get_image_replacement_phi4mm(item_idx: int):
        images = mm_items.get_items(
            "image", (ImageEmbeddingItems, ImageProcessorItems))

        # Embedding inputs already carry their feature size.
        if isinstance(images, ImageEmbeddingItems):
            token_count = images.get_feature_size(item_idx)
            return [_IMAGE_PLACEHOLDER_TOKEN_ID] * token_count

        size = images.get_image_size(item_idx)
        token_count = self.info.get_num_image_tokens(
            image_width=size.width,
            image_height=size.height,
            processor=hf_processor,
        )
        return [_IMAGE_PLACEHOLDER_TOKEN_ID] * token_count

    def get_audio_replacement_phi4mm(item_idx: int):
        audios = mm_items.get_items("audio", AudioProcessorItems)
        # TODO(Isotr0py): support embedding inputs
        num_frames = self.info.get_audio_num_frames(
            audios.get_audio_length(item_idx),
            feature_extractor.sampling_rate,
        )
        embed_size = self.info._compute_audio_embed_size(num_frames)
        return [_AUDIO_PLACEHOLDER_TOKEN_ID] * embed_size

    num_images = mm_items.get_count("image", strict=False)
    num_audios = mm_items.get_count("audio", strict=False)

    image_repl = []
    for image_token in image_tokens[:num_images]:
        image_repl.append(
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=get_image_replacement_phi4mm,
            ))

    audio_repl = []
    for audio_token in audio_tokens[:num_audios]:
        audio_repl.append(
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_audio_replacement_phi4mm,
            ))

    return image_repl + audio_repl
@MULTIMODAL_REGISTRY.register_input_mapper("audio", @MULTIMODAL_REGISTRY.register_processor(
input_mapper_for_phi4mm_audio) Phi4MMMultiModalProcessor,
@MULTIMODAL_REGISTRY.register_input_mapper("image", info=Phi4MMProcessingInfo,
input_mapper_for_phi4mm_image) dummy_inputs=Phi4MMDummyInputsBuilder,
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( )
"audio", get_max_phi4mm_audio_tokens) class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"image", get_max_phi4mm_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi4mm)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi4mm)
class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal,
SupportsV0Only):
""" """
Implements the Phi-4-multimodal-instruct model in vLLM. Implements the Phi-4-multimodal-instruct model in vLLM.
""" """
...@@ -1518,48 +991,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, ...@@ -1518,48 +991,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal,
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale) config.vocab_size, logit_scale)
self.sampler = Sampler() self.sampler = get_sampler()
def _audio_features_to_embeddings(
self,
input_ids: torch.Tensor,
input_features: List[torch.Tensor],
audio_input_sizes: torch.Tensor,
audio_projection_mode: str,
) -> torch.Tensor:
"""
Convert audio features to embeddings, which are used as input to the
model (via `inputs_embeds`).
Args:
input_ids (torch.Tensor): Input IDs (the prompt in this case).
input_features (list[torch.Tensor]): Input features (the audio
embeddings).
audio_input_sizes (list[torch.Tensor]): Audio input sizes (the
audio embed lengths to use for padding the audio placeholder token
in the input prompt IDs).
"""
# The audio projection can either be a single linear or Sequential,
# so handle both cases
if isinstance(self.embed_tokens_extend.audio_projection,
nn.Sequential):
target_dtype = self.embed_tokens_extend.audio_projection[
0].bias.dtype
else:
target_dtype = self.embed_tokens_extend.audio_projection.bias.dtype
audio_input = [
input.unsqueeze(0).to(target_dtype) for input in input_features
]
kwargs = {
"wte": self.model.embed_tokens,
'audio_projection_mode': audio_projection_mode
}
audio_embeddings = self.embed_tokens_extend(input_ids, audio_input,
audio_input_sizes,
**kwargs)
audio_embeddings = audio_embeddings.to(target_dtype)
return audio_embeddings
def _parse_and_validate_audio_input( def _parse_and_validate_audio_input(
self, **kwargs: object) -> Optional[Phi4MMAudioInputs]: self, **kwargs: object) -> Optional[Phi4MMAudioInputs]:
...@@ -1574,7 +1006,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, ...@@ -1574,7 +1006,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal,
Returns: Returns:
Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs. Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs.
""" """
audio_features = kwargs.pop("audio_features", None) audio_features = kwargs.pop("input_audio_embeds", None)
audio_embeds = kwargs.pop("audio_embeds", None) audio_embeds = kwargs.pop("audio_embeds", None)
if audio_features is None and audio_embeds is None: if audio_features is None and audio_embeds is None:
...@@ -1586,7 +1018,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, ...@@ -1586,7 +1018,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal,
f"Got type: {type(audio_features)}") f"Got type: {type(audio_features)}")
return Phi4MMAudioFeatureInputs(type="audio_features", return Phi4MMAudioFeatureInputs(type="audio_features",
data=audio_features) data=flatten_bn(audio_features))
if audio_embeds is not None: if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)): if not isinstance(audio_embeds, (torch.Tensor, list)):
...@@ -1598,8 +1030,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, ...@@ -1598,8 +1030,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal,
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
def _process_audio_input(self, input_ids: torch.Tensor, def _process_audio_input(self, audio_input: Phi4MMAudioInputs,
audio_input: Phi4MMAudioInputs,
audio_projection_mode: str) -> NestedTensors: audio_projection_mode: str) -> NestedTensors:
""" """
Create the audio embeddings from the audio input, where the audio input Create the audio embeddings from the audio input, where the audio input
...@@ -1607,8 +1038,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, ...@@ -1607,8 +1038,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal,
created by `input_mapper_for_phi4mm_audio`. created by `input_mapper_for_phi4mm_audio`.
Args: Args:
input_ids (torch.Tensor): Input IDs (the prompt in this case,
before the audio token replication).
audio_input (Phi4MMAudioInputs): Audio input. audio_input (Phi4MMAudioInputs): Audio input.
Returns: Returns:
...@@ -1620,21 +1049,20 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, ...@@ -1620,21 +1049,20 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal,
audio_features = audio_input["data"] audio_features = audio_input["data"]
# (e.g. multiple examples) and the second dim is the multi-audio dim # (e.g. multiple examples) and the second dim is the multi-audio dim
# (e.g. multiple audios in the same example) # (e.g. multiple audios in the same example)
audio_feature = [i[0] for j in audio_features for i in j]
audio_feature_len = [i[1].item() for j in audio_features for i in j]
# Add the batch dim via `squeeze`
return self._audio_features_to_embeddings( dtype = next(self.embed_tokens_extend.parameters()).dtype
input_ids.unsqueeze(0), audio_embeds = [
audio_feature, self.embed_tokens_extend(
audio_feature_len, features.to(dtype),
audio_projection_mode, audio_projection_mode=audio_projection_mode,
).squeeze(0) ) for features in audio_features
]
return audio_embeds
def _parse_and_validate_image_input(self, def _parse_and_validate_image_input(self,
**kwargs: object) -> Optional[Dict]: **kwargs: object) -> Optional[Dict]:
pixel_values: Optional[Dict] = kwargs.get("pixel_values") input_image_embeds: NestedTensors = kwargs.get("input_image_embeds")
if pixel_values is None: if input_image_embeds is None:
return None return None
image_sizes = kwargs.get("image_sizes") image_sizes = kwargs.get("image_sizes")
...@@ -1643,23 +1071,24 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, ...@@ -1643,23 +1071,24 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal,
assert image_sizes is not None and image_attention_mask is not None\ assert image_sizes is not None and image_attention_mask is not None\
and num_img_tokens is not None, "Missing image inputs" and num_img_tokens is not None, "Missing image inputs"
if isinstance(pixel_values, list): if is_list_of(input_image_embeds, torch.Tensor):
assert pixel_values[0].dim() == 5, "Incorrect image inputs" assert all(p.dim() == 5
for p in input_image_embeds), "Incorrect image inputs"
# list len is batch_size. # list len is batch_size.
# each tensor has dimension: num_img_per_example, num_hd_patches, # each tensor has dimension: num_img_per_example, num_hd_patches,
# channels, height, width. # channels, height, width.
# need to pad along num_hd_patches. # need to pad along num_hd_patches.
# mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w. # mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w.
pixel_values = cat_with_pad(pixel_values, dim=0) input_image_embeds = cat_with_pad(input_image_embeds, dim=0)
elif isinstance(pixel_values, torch.Tensor): elif isinstance(input_image_embeds, torch.Tensor):
# dimension: batch_size, num_img_per_example, num_hd_patches, # dimension: batch_size, num_img_per_example, num_hd_patches,
# channels, height, width. # channels, height, width.
# we flatten first 2 dims to make it a single large batch for # we flatten first 2 dims to make it a single large batch for
# SigLIP Encoder. # SigLIP Encoder.
assert pixel_values.dim() == 6, "Incorrect image inputs" assert input_image_embeds.dim() == 6, "Incorrect image inputs"
pixel_values = pixel_values.flatten(0, 1) input_image_embeds = input_image_embeds.flatten(0, 1)
else: else:
raise ValueError("Incorrect pixel_values inputs") raise ValueError("Incorrect input_image_embeds inputs")
if isinstance(image_attention_mask, list): if isinstance(image_attention_mask, list):
image_attention_mask = cat_with_pad(image_attention_mask, dim=0) image_attention_mask = cat_with_pad(image_attention_mask, dim=0)
...@@ -1685,80 +1114,140 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, ...@@ -1685,80 +1114,140 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal,
else: else:
raise ValueError("Incorrect image_attention_mask inputs") raise ValueError("Incorrect image_attention_mask inputs")
return { return Phi4MMImagePixelInputs(
'pixel_values': pixel_values, type="pixel_values",
'image_sizes': image_sizes, data=input_image_embeds,
'image_attention_mask': image_attention_mask, image_sizes=image_sizes,
'num_img_tokens': num_img_tokens, image_attention_mask=image_attention_mask,
} num_img_tokens=num_img_tokens,
)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key in ("input_image_embeds",
"image_embeds") and "images" not in modalities:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if input_key in ("input_audio_embeds",
"audio_embeds") and "audios" not in modalities:
modalities["audios"] = self._parse_and_validate_audio_input(
**kwargs)
return modalities
def _process_image_input(
        self, image_input: Phi4MMImagePixelInputs) -> list[torch.Tensor]:
    """Compute image embeddings for a parsed image input.

    Inputs tagged `"image_embeds"` are only cast to `self.visual.dtype`;
    every other input is run through `self.vision_encoder` after casting
    its `data` to the encoder's parameter dtype.
    """
    if image_input["type"] == "image_embeds":
        # NOTE(review): `self.visual` is not referenced anywhere else in
        # the visible code (the other branch uses `self.vision_encoder`)
        # - confirm this attribute exists before relying on this branch.
        return image_input["image_embeds"].type(self.visual.dtype)

    # Match the dtype of the vision encoder weights.
    encoder_dtype = next(self.vision_encoder.parameters()).dtype
    return self.vision_encoder(
        image_input['data'].to(encoder_dtype),
        image_input['image_sizes'],
        image_input['image_attention_mask'],
    )
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return None
def merge_image_features_to_inputs_embeds( # The result multimodal_embeddings is tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or audio).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
audio_projection_mode = 'speech'
for modality in modalities:
# make sure process images first
if modality == "images":
audio_projection_mode = "vision"
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
multimodal_embeddings += tuple(vision_embeddings)
if modality == "audios":
audio_input = modalities["audios"]
audio_embeddings = self._process_audio_input(
audio_input, audio_projection_mode=audio_projection_mode)
multimodal_embeddings += tuple(audio_embeddings)
return multimodal_embeddings
def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
inputs_embeds: torch.Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
image_set_tensors: List[torch.Tensor], ) -> torch.Tensor:
): inputs_embeds = self.model.embed_tokens(input_ids)
position_tuple = (input_ids == _IMAGE_PLACEHOLDER_TOKEN_ID).nonzero( if multimodal_embeddings is not None:
as_tuple=True) inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
assert all([t.shape[0] == 1 for t in image_set_tensors [_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID])
]), 'img_set_tensor should have shape (1, N_tokens, C)' return inputs_embeds
# Shape: (merged_N_tokens, C)
image_set_tensor = torch.cat(image_set_tensors, dim=1).squeeze(0) def get_input_embeddings_v0(
image_set_tensor = image_set_tensor.to(inputs_embeds.dtype).to( self,
inputs_embeds.device) input_ids: torch.Tensor,
merged_embeds = inputs_embeds.index_put( image_input: Optional[Phi4MMImagePixelInputs] = None,
indices=position_tuple, audio_input: Optional[Phi4MMAudioFeatureInputs] = None,
values=image_set_tensor, ) -> torch.Tensor:
accumulate=False, audio_projection_mode = 'speech'
) inputs_embeds = self.get_input_embeddings(input_ids)
return merged_embeds if image_input is not None:
image_embeds = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
image_embeds,
placeholder_token_id=_IMAGE_PLACEHOLDER_TOKEN_ID,
)
audio_projection_mode = 'vision'
if audio_input is not None:
audio_embeds = self._process_audio_input(
audio_input, audio_projection_mode=audio_projection_mode)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
audio_embeds,
placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN_ID,
)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
) -> torch.Tensor: ) -> torch.Tensor:
if intermediate_tensors is not None: if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None inputs_embeds = None
else:
# Each entry in this is a pair of audio_features and audio_embed # NOTE: In v1, inputs_embeds is always generated at model runner from
# lengths # `get_multimodal_embeddings` and `get_input_embeddings`, this
# condition is only for v0 compatibility.
elif inputs_embeds is None:
image_input = self._parse_and_validate_image_input(**kwargs)
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
image_inputs = self._parse_and_validate_image_input(**kwargs)
if image_input is None and audio_input is None:
has_audio = audio_input is not None
has_image = image_inputs is not None
if has_audio:
audio_projection_mode = 'vision' if has_image else 'speech'
inputs_embeds = self._process_audio_input(
input_ids, audio_input, audio_projection_mode)
if has_image:
dtype = self.vision_encoder.img_processor.embeddings.\
patch_embedding.weight.dtype
pixel_values = image_inputs['pixel_values'].to(dtype)
image_sizes = image_inputs['image_sizes']
image_attention_mask = image_inputs['image_attention_mask']
image_set_tensors = self.vision_encoder(
pixel_values, image_sizes, image_attention_mask)
if not has_audio:
inputs_embeds = self.model.embed_tokens(input_ids)
inputs_embeds = self.merge_image_features_to_inputs_embeds(
input_ids, inputs_embeds, image_set_tensors)
if has_image or has_audio:
# multi-modal input, we have set inputs_embeds properly in
# previous steps
input_ids = None
else:
# text-only, we keep using original input_ids
inputs_embeds = None inputs_embeds = None
else:
inputs_embeds = self.get_input_embeddings_v0(
input_ids,
image_input=image_input,
audio_input=audio_input)
input_ids = None
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
......
...@@ -1159,8 +1159,11 @@ class AudioEmbedding(nn.Module): ...@@ -1159,8 +1159,11 @@ class AudioEmbedding(nn.Module):
input_embeds: torch.FloatTensor, input_embeds: torch.FloatTensor,
audio_attention_mask: torch.Tensor = None, audio_attention_mask: torch.Tensor = None,
audio_projection_mode: str = "speech", audio_projection_mode: str = "speech",
): ) -> torch.FloatTensor:
"""
arguments:
input_embeds: audio features (B, T, D) B: num audios in a sequence
"""
if self.freeze_audio_processor: if self.freeze_audio_processor:
with torch.no_grad(): with torch.no_grad():
audio_features, masks = self.encoder(input_embeds, audio_features, masks = self.encoder(input_embeds,
...@@ -1210,62 +1213,20 @@ class AudioEmbedding(nn.Module): ...@@ -1210,62 +1213,20 @@ class AudioEmbedding(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.LongTensor, audio_features: torch.FloatTensor,
input_embeds: torch.FloatTensor, audio_attention_mask: torch.Tensor = None,
audio_embed_sizes, audio_projection_mode: str = "speech",
**kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
arguments: arguments:
input_ids: input text ids (B, U) audio_features: audio features (T, D)
input_embeds: audio features (B, T, D) B: num audios in a sequence
returns:
audio_embeds: audio embeddings (num_audio_tokens, hidden_dim)
""" """
assert input_embeds is not None and len(input_embeds) == len( audio_embeds = self.get_audio_features(
audio_embed_sizes) audio_features.unsqueeze(0),
audio_attention_mask=audio_attention_mask,
input_shape = input_ids.size() audio_projection_mode=audio_projection_mode,
input_ids = input_ids.view(-1, input_shape[-1]) )
return audio_embeds.squeeze(0)
with torch.no_grad():
positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero(
as_tuple=False)
if not isinstance(input_embeds, list):
input_embeds = [input_embeds]
audio_projection_mode = kwargs.get("audio_projection_mode", "speech")
audio_set_tensor = [
self.get_audio_features(
input_embed, audio_projection_mode=audio_projection_mode)
for input_embed in input_embeds
]
with torch.no_grad():
input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
if "wte" in kwargs:
# we use the token embedding layer from the huggingface model, this
# is REQUIRED to make sure we are using the loaded weights.
hidden_states = kwargs["wte"](input_ids)
else:
# otherwise, we use token embedding in pretrained mixformer from
# phi team
hidden_states = self.wte(input_ids)
if len(positions.tolist()) > 0:
assert sum(audio_embed_sizes) == len(
positions
), "please ensure the encoder outputs have the same length as"\
" defined in input_ids!"
idx = 0
for i in range(len(audio_embed_sizes)):
cnt = audio_embed_sizes[i]
assert audio_set_tensor[i].shape[0] == 1
hidden_states[
positions[idx, 0],
positions[idx, 1]:positions[idx, 1] + cnt,
] = (audio_set_tensor[i][0, :audio_embed_sizes[i], :].to(
hidden_states.dtype).to(hidden_states.device))
idx += cnt
return hidden_states
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import base64 import base64
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import Literal, Optional
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
...@@ -43,7 +43,7 @@ class AudioPlugin(MultiModalPlugin): ...@@ -43,7 +43,7 @@ class AudioPlugin(MultiModalPlugin):
"There is no default maximum multimodal tokens") "There is no default maximum multimodal tokens")
def resample_audio( def resample_audio_librosa(
audio: npt.NDArray[np.floating], audio: npt.NDArray[np.floating],
*, *,
orig_sr: float, orig_sr: float,
...@@ -52,6 +52,55 @@ def resample_audio( ...@@ -52,6 +52,55 @@ def resample_audio(
return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr) return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
def resample_audio_scipy(
    audio: npt.NDArray[np.floating],
    *,
    orig_sr: float,
    target_sr: float,
):
    """Resample `audio` from `orig_sr` to `target_sr` with SciPy.

    Uses polyphase filtering (`scipy.signal.resample_poly`) with the exact
    reduced ratio `target_sr / orig_sr`, so rates that are not integer
    multiples of each other (e.g. 44100 -> 16000) land on the requested
    rate instead of on a floor-divided approximation of it.

    Args:
        audio: Audio samples to resample.
        orig_sr: Sampling rate of `audio`.
        target_sr: Desired sampling rate.

    Returns:
        `audio` itself when both rates are equal, otherwise the
        resampled array.
    """
    # lazy import scipy.signal, otherwise it will crash doc build.
    import scipy.signal
    from fractions import Fraction

    if orig_sr == target_sr:
        return audio

    # Reduce target/orig to lowest terms to get the integer up/down
    # factors. The denominator cap only matters for non-integer rates,
    # whose exact binary fractions would otherwise yield enormous factors.
    ratio = (Fraction(float(target_sr)) /
             Fraction(float(orig_sr))).limit_denominator(1_000_000)
    return scipy.signal.resample_poly(audio, ratio.numerator,
                                      ratio.denominator)
class AudioResampler:
    """Resamples audio to a fixed target rate using a selectable backend.

    The target rate and the backend name are stored as given; both are
    only checked when `resample` is called.
    """

    def __init__(
        self,
        target_sr: Optional[float] = None,
        method: Literal["librosa", "scipy"] = "librosa",
    ):
        self.target_sr = target_sr
        self.method = method

    def resample(
        self,
        audio: npt.NDArray[np.floating],
        *,
        orig_sr: float,
    ) -> npt.NDArray[np.floating]:
        """Resample `audio` from `orig_sr` to `self.target_sr`.

        Raises:
            RuntimeError: If no target sampling rate was configured.
            ValueError: If `self.method` is neither 'librosa' nor 'scipy'.
        """
        target_sr = self.target_sr
        if target_sr is None:
            raise RuntimeError("Audio resampling is not supported when "
                               "`target_sr` is not provided")

        method = self.method
        if method != "librosa" and method != "scipy":
            raise ValueError(f"Invalid resampling method: {self.method}. "
                             "Supported methods are 'librosa' and 'scipy'.")

        # Only the selected backend function is looked up.
        backend = (resample_audio_librosa
                   if method == "librosa" else resample_audio_scipy)
        return backend(audio, orig_sr=orig_sr, target_sr=target_sr)
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]: def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import UserDict from collections import UserDict
from collections.abc import Callable, Iterator, Mapping, Sequence from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar, from typing import (TYPE_CHECKING, Any, Generic, Literal, NamedTuple, Optional,
Union) TypeVar, Union)
import numpy as np import numpy as np
import torch import torch
...@@ -14,7 +14,7 @@ from typing_extensions import TypeAlias, TypeGuard, assert_never ...@@ -14,7 +14,7 @@ from typing_extensions import TypeAlias, TypeGuard, assert_never
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .audio import resample_audio from .audio import AudioResampler
from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
ImageItem, ModalityData, MultiModalDataDict, ImageItem, ModalityData, MultiModalDataDict,
MultiModalFieldConfig, MultiModalKwargs, VideoItem) MultiModalFieldConfig, MultiModalKwargs, VideoItem)
...@@ -308,10 +308,18 @@ class MultiModalDataParser: ...@@ -308,10 +308,18 @@ class MultiModalDataParser:
items to the model's expected sampling rate. items to the model's expected sampling rate.
""" """
def __init__(self, *, target_sr: Optional[float] = None) -> None: def __init__(
self,
*,
target_sr: Optional[float] = None,
audio_resample_method: Literal["librosa", "scipy"] = "librosa",
) -> None:
super().__init__() super().__init__()
self.target_sr = target_sr self.audio_resampler = AudioResampler(
target_sr=target_sr,
method=audio_resample_method,
)
def _is_embeddings( def _is_embeddings(
self, data: object self, data: object
...@@ -374,15 +382,8 @@ class MultiModalDataParser: ...@@ -374,15 +382,8 @@ class MultiModalDataParser:
if orig_sr is None: if orig_sr is None:
new_audio = audio new_audio = audio
else: else:
target_sr = self.target_sr new_audio = self.audio_resampler.resample(audio,
if target_sr is None: orig_sr=orig_sr)
raise RuntimeError(
"Audio resampling is not supported when "
"`target_sr` is not provided")
new_audio = resample_audio(audio,
orig_sr=orig_sr,
target_sr=target_sr)
new_audios.append(new_audio) new_audios.append(new_audio)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment