Unverified Commit 2c1bd848 authored by Yang Fan's avatar Yang Fan Committed by GitHub
Browse files

[Model][VLM] Add Qwen2.5-Omni model support (thinker only) (#15130)


Signed-off-by: default avatarfyabc <suyang.fy@alibaba-inc.com>
Signed-off-by: default avatarRoger Wang <ywang@roblox.com>
Co-authored-by: default avatarRoger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: default avatarRoger Wang <ywang@roblox.com>
Co-authored-by: default avatarXiong Wang <wangxiongts@163.com>
parent 5c912120
......@@ -1040,6 +1040,13 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
* ✅︎
* ✅︎
- * `Qwen2_5OmniThinkerForConditionalGeneration`
* Qwen2.5-Omni
* T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup>
* `Qwen/Qwen2.5-Omni-7B`
*
* ✅︎
* ✅︎\*
- * `SkyworkR1VChatModel`
* Skywork-R1V-38B
* T + I
......@@ -1109,6 +1116,14 @@ For more details, please see: <gh-pr:4087#issuecomment-2250397630>
Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
:::
:::{note}
To use Qwen2.5-Omni, you have to install a fork of Hugging Face Transformers library from source via
`pip install git+https://github.com/BakerBunker/transformers.git@qwen25omni`.
Read audio from video pre-processing is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.
`--mm-processor-kwargs '{"use_audio_in_video": True}'`.
:::
### Pooling Models
See [this page](pooling-models) for more information on how to use pooling models.
......
......@@ -130,6 +130,36 @@ def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
)
# Qwen2.5-Omni
def run_qwen2_5_omni(question: str, audio_count: int):
model_name = "Qwen/Qwen2.5-Omni-7B"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
)
audio_in_prompt = "".join([
"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)
])
default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech.")
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n")
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Ultravox 0.5-1B
def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
......@@ -182,6 +212,7 @@ model_example_map = {
"minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm,
"qwen2_audio": run_qwen2_audio,
"qwen2_5_omni": run_qwen2_5_omni,
"ultravox": run_ultravox,
"whisper": run_whisper,
}
......
# Qwen2.5-Omni Offline Inference Examples
This folder provides several example scripts on how to inference Qwen2.5-Omni offline.
## Thinker Only
```bash
# Audio + image + video
python examples/offline_inference/qwen2_5_omni/only_thinker.py -q mixed_modalities
# Read vision and audio inputs from a single video file
# NOTE: V1 engine does not support interleaved modalities yet.
VLLM_USE_V1=0 python examples/offline_inference/qwen2_5_omni/only_thinker.py -q use_audio_in_video
# Multiple audios
VLLM_USE_V1=0 python examples/offline_inference/qwen2_5_omni/only_thinker.py -q multi_audios
```
This script will run the thinker part of Qwen2.5-Omni, and generate text response.
You can also test Qwen2.5-Omni on a single modality:
```bash
# Process audio inputs
python examples/offline_inference/audio_language.py --model-type qwen2_5_omni
# Process image inputs
python examples/offline_inference/vision_language.py --modality image --model-type qwen2_5_omni
# Process video inputs
python examples/offline_inference/vision_language.py --modality video --model-type qwen2_5_omni
```
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on Qwen2.5-Omni (thinker only).
"""
from typing import NamedTuple
import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.utils import FlexibleArgumentParser
class QueryResult(NamedTuple):
inputs: dict
limit_mm_per_prompt: dict[str, int]
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech.")
def get_mixed_modalities_query() -> QueryResult:
question = ("What is recited in the audio? "
"What is the content of this image? Why is this video funny?")
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
"<|vision_bos|><|IMAGE|><|vision_eos|>"
"<|vision_bos|><|VIDEO|><|vision_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n")
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio":
AudioAsset("mary_had_lamb").audio_and_sample_rate,
"image":
ImageAsset("cherry_blossom").pil_image.convert("RGB"),
"video":
VideoAsset(name="sample_demo_1.mp4",
num_frames=16).np_ndarrays,
},
},
limit_mm_per_prompt={
"audio": 1,
"image": 1,
"video": 1
},
)
def get_use_audio_in_video_query() -> QueryResult:
question = ("Describe the content of the video, "
"then convert what the baby say into text.")
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n")
asset = VideoAsset(name="sample_demo_1.mp4", num_frames=16)
audio = asset.get_audio(sampling_rate=16000)
assert not envs.VLLM_USE_V1, ("V1 does not support use_audio_in_video. "
"Please launch this example with "
"`VLLM_USE_V1=0`.")
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"video": asset.np_ndarrays,
"audio": audio,
},
"mm_processor_kwargs": {
"use_audio_in_video": True,
},
},
limit_mm_per_prompt={
"audio": 1,
"video": 1
},
)
def get_multi_audios_query() -> QueryResult:
question = "Are these two audio clips the same?"
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
"<|audio_bos|><|AUDIO|><|audio_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n")
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio": [
AudioAsset("winning_call").audio_and_sample_rate,
AudioAsset("mary_had_lamb").audio_and_sample_rate,
],
},
},
limit_mm_per_prompt={
"audio": 2,
},
)
query_map = {
"mixed_modalities": get_mixed_modalities_query,
"use_audio_in_video": get_use_audio_in_video_query,
"multi_audios": get_multi_audios_query,
}
def main(args):
model_name = "Qwen/Qwen2.5-Omni-7B"
query_result = query_map[args.query_type]()
llm = LLM(model=model_name,
max_model_len=5632,
max_num_seqs=5,
limit_mm_per_prompt=query_result.limit_mm_per_prompt,
seed=args.seed)
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2, max_tokens=64)
outputs = llm.generate(query_result.inputs,
sampling_params=sampling_params)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with '
'audio language models')
parser.add_argument('--query-type',
'-q',
type=str,
default="mixed_modalities",
choices=query_map.keys(),
help='Query type.')
parser.add_argument("--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.")
args = parser.parse_args()
main(args)
......@@ -941,6 +941,42 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
)
# Qwen2.5-Omni
def run_qwen2_5_omni(questions: list[str], modality: str):
model_name = "Qwen/Qwen2.5-Omni-7B"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
mm_processor_kwargs={
"min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28,
"fps": [1],
},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
if modality == "image":
placeholder = "<|IMAGE|>"
elif modality == "video":
placeholder = "<|VIDEO|>"
default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech.")
prompts = [(f"<|im_start|>system\n{default_system}<|im_end|>\n"
f"<|im_start|>user\n<|vision_bos|>{placeholder}<|vision_eos|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n") for question in questions]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# SkyworkR1V
def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
......@@ -1010,6 +1046,7 @@ model_example_map = {
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
"qwen2_5_vl": run_qwen2_5_vl,
"qwen2_5_omni": run_qwen2_5_omni,
"skywork_chat": run_skyworkr1v,
"smolvlm": run_smolvlm,
}
......
......@@ -139,6 +139,23 @@ VLM_TEST_SETTINGS = {
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
"qwen2_5_omni": VLMTestInfo(
models=["Qwen/Qwen2.5-Omni-7B"],
test_type=(
VLMTestType.IMAGE,
VLMTestType.MULTI_IMAGE,
VLMTestType.VIDEO
),
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<|vision_bos|><|IMAGE|><|vision_eos|>", # noqa: E501
video_idx_to_prompt=lambda idx: "<|vision_bos|><|VIDEO|><|vision_eos|>", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
#### Extended model tests
"aria": VLMTestInfo(
models=["rhymes-ai/Aria"],
......
......@@ -280,6 +280,7 @@ def _test_processing_correctness_mistral(
"Qwen/Qwen2-VL-2B-Instruct",
"Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
"Qwen/Qwen2.5-Omni-7B",
"Skywork/Skywork-R1V-38B",
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",
......
......@@ -362,6 +362,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
min_transformers_version="4.49"), # noqa: E501
"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B", # noqa: E501
min_transformers_version="4.52"), # noqa: E501
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
......
......@@ -2,7 +2,7 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import Literal
from typing import Literal, Optional
import cv2
import numpy as np
......@@ -10,8 +10,15 @@ import numpy.typing as npt
from huggingface_hub import hf_hub_download
from PIL import Image
from vllm.utils import PlaceholderModule
from .base import get_cache_dir
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
@lru_cache
def download_video_asset(filename: str) -> str:
......@@ -85,3 +92,12 @@ class VideoAsset:
video_path = download_video_asset(self.name)
ret = video_to_ndarrays(video_path, self.num_frames)
return ret
def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
"""
Read audio data from the video asset, used in Qwen2.5-Omni examples.
See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
"""
video_path = download_video_asset(self.name)
return librosa.load(video_path, sr=sampling_rate)[0]
......@@ -506,6 +506,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "<|image|>"
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|image_pad|><|vision_end|>"
if model_type == "qwen2_5_omni":
return "<|vision_start|><|IMAGE|><|vision_end|>"
if model_type == "molmo":
return ""
if model_type == "aria":
......@@ -521,7 +523,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "<|audio|>"
if model_type == "phi4mm":
return "<|endoftext11|>" # 200011 (see vocab.json in hf model)
if model_type == "qwen2_audio":
if model_type in ("qwen2_audio", "qwen2_5_omni"):
return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>")
if model_type == "minicpmo":
......@@ -530,6 +532,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
elif modality == "video":
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type == "qwen2_5_omni":
return "<|vision_start|><|VIDEO|><|vision_end|>"
if model_type in ("minicpmo", "minicpmv"):
return "(<video>./</video>)"
if model_type.startswith("llava"):
......
......@@ -747,7 +747,7 @@ def compute_hash() -> str:
variables, ensure that it is included in the factors list if
it affects the computation graph. For example, different values
of VLLM_PP_LAYER_PARTITION will generate different computation
graphs, so it is included in the factors list. The env vars that
graphs, so it is included in the factors list. The env vars that
affect the choice of different kernels or attention backends should
also be included in the factors list.
"""
......
......@@ -988,8 +988,9 @@ class MRotaryEmbedding(RotaryEmbedding):
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
@staticmethod
@classmethod
def get_input_positions(
cls,
input_tokens: List[int],
hf_config: PretrainedConfig,
image_grid_thw: Optional[Union[List[List[int]], torch.Tensor]],
......@@ -997,6 +998,8 @@ class MRotaryEmbedding(RotaryEmbedding):
second_per_grid_ts: Optional[List[float]],
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value."""
......@@ -1006,7 +1009,7 @@ class MRotaryEmbedding(RotaryEmbedding):
second_per_grid_ts
llm_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
cls.get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
......@@ -1014,12 +1017,52 @@ class MRotaryEmbedding(RotaryEmbedding):
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
return llm_positions.tolist(), mrope_position_delta
@staticmethod
@classmethod
def get_input_positions_tensor(
cls,
input_tokens: List[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
second_per_grid_ts: List[float],
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> Tuple[torch.Tensor, int]:
from vllm.transformers_utils.config import thinker_uses_mrope
if thinker_uses_mrope(hf_config):
return cls._omni_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
else:
return cls._vl_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=context_len,
seq_len=seq_len,
)
@classmethod
def _vl_get_input_positions_tensor(
cls,
input_tokens: List[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
......@@ -1037,11 +1080,6 @@ class MRotaryEmbedding(RotaryEmbedding):
tokens_per_second = getattr(hf_config.vision_config,
"tokens_per_second", 1.0)
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
video_grid_thw = video_grid_thw.tolist()
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id).squeeze(1)
......@@ -1121,6 +1159,226 @@ class MRotaryEmbedding(RotaryEmbedding):
return llm_positions, mrope_position_delta
@classmethod
def _omni_get_input_positions_tensor(
cls,
input_tokens: List[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
second_per_grid_ts: Optional[List[float]] = None,
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
) -> Tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value (Qwen2.5-Omni version).
Differences from MRotaryEmbedding:
1. Add audio support (and related `audio_feature_lengths`).
2. Add `use_audio_in_video` option to read audio from video inputs.
In this case, audio and vision position ids will be split into
chunks and interleaved.
Example:
(V_i are vision position ids, A_i are audio position ids)
|V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
|vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
"""
# TODO(fyabc): refactor and share more code with
# _vl_get_input_positions_tensor.
thinker_config = hf_config.thinker_config
audio_token_id = thinker_config.audio_token_index
image_token_id = thinker_config.image_token_index
video_token_id = thinker_config.video_token_index
audio_start_token_id = thinker_config.audio_start_token_id
audio_end_token_id = thinker_config.audio_end_token_id
vision_end_token_id = thinker_config.vision_end_token_id
seconds_per_chunk = thinker_config.seconds_per_chunk
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
tokens_per_second = getattr(thinker_config.vision_config,
"tokens_per_second", 25)
if isinstance(image_grid_thw, list):
image_grid_thw = torch.tensor(image_grid_thw)
if isinstance(video_grid_thw, list):
video_grid_thw = torch.tensor(video_grid_thw)
src_item = input_tokens
audio_seqlens = audio_feature_lengths
if not second_per_grid_ts:
second_per_grid_ts = [1] * video_grid_thw.shape[0]
audio_idx = 0
video_idx = 0
image_idx = 0
new_src_item: list[int] = []
llm_pos_ids_list: list[torch.Tensor] = []
idx = 0
while idx < len(src_item):
new_src_item_len = len(new_src_item)
start_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
if src_item[idx] not in [
audio_token_id, video_token_id, image_token_id
]:
if src_item[idx] == vision_end_token_id and use_audio_in_video:
start_idx -= 1
new_src_item.append(src_item[idx])
llm_pos_ids = torch.tensor([start_idx],
dtype=torch.long).expand(3, -1)
llm_pos_ids_list.append(llm_pos_ids)
elif src_item[idx] == audio_token_id:
assert audio_seqlens is not None
audio_seqlen = audio_seqlens[audio_idx]
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1)
new_src_item.extend([audio_token_id] * place_num)
llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
llm_pos_ids_list.append(llm_pos_ids)
audio_idx += 1
elif src_item[idx] == image_token_id:
grid_t = image_grid_thw[image_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
start_idx, image_idx, spatial_merge_size, t_index, grid_hs,
grid_ws)
llm_pos_ids_list.append(llm_pos_ids)
vision_seqlen = image_grid_thw[image_idx].prod() // (
spatial_merge_size**2)
new_src_item.extend([image_token_id] * vision_seqlen)
image_idx += 1
elif src_item[idx] == video_token_id and not use_audio_in_video:
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (torch.arange(grid_t) *
second_per_grid_ts[video_idx] *
tokens_per_second).long()
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
start_idx, video_idx, spatial_merge_size, t_index, grid_hs,
grid_ws)
llm_pos_ids_list.append(llm_pos_ids)
vision_seqlen = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2)
new_src_item.extend([video_token_id] * vision_seqlen)
video_idx += 1
else:
# read audio from video
assert audio_seqlens is not None
audio_seqlen = audio_seqlens[audio_idx]
vision_seqlen = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2)
grid_t = video_grid_thw[video_idx][0]
grid_h = video_grid_thw[video_idx][1]
grid_w = video_grid_thw[video_idx][2]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
t_index = (torch.arange(grid_t) *
second_per_grid_ts[video_idx] *
tokens_per_second).long()
t_index_split_chunk = cls._split_list_into_ranges(
t_index, t_ntoken_per_chunk)
new_src_item.extend([audio_start_token_id])
start_idx -= 1
llm_pos_ids_list.extend([
torch.tensor([start_idx], dtype=torch.long).expand(3, -1)
] * 1)
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
pure_audio_len = place_num - 2
added_audio_len = 0
audio_llm_pos_ids_list: List[torch.Tensor] = []
for t_chunk in t_index_split_chunk:
vision_ntoken_per_chunk = len(
t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
new_src_item.extend([video_token_id] *
vision_ntoken_per_chunk)
vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
start_idx + 1, video_idx, spatial_merge_size, t_chunk,
grid_hs, grid_ws).split(1, dim=1)
llm_pos_ids_list.extend(vision_llm_pos_ids_list)
new_src_item.extend(
min(t_ntoken_per_chunk, pure_audio_len -
added_audio_len) * [audio_token_id])
audio_start_idx = start_idx if len(
audio_llm_pos_ids_list
) == 0 else audio_llm_pos_ids_list[-1][0].item()
if min(t_ntoken_per_chunk,
pure_audio_len - added_audio_len) > 0:
audio_llm_pos_ids_list = (torch.arange(
min(t_ntoken_per_chunk, pure_audio_len -
added_audio_len)).expand(3, -1) +
audio_start_idx + 1).split(
1, dim=1)
else:
audio_llm_pos_ids_list = []
added_audio_len += min(t_ntoken_per_chunk,
pure_audio_len - added_audio_len)
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
if added_audio_len < pure_audio_len:
new_src_item.extend(
(pure_audio_len - added_audio_len) * [audio_token_id])
audio_llm_pos_ids_list = (
torch.arange(pure_audio_len - added_audio_len).expand(
3, -1) + llm_pos_ids_list[-1].max() + 1).split(
1, dim=1)
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
llm_pos_ids_list.extend([
torch.tensor(
[llm_pos_ids_list[-1].max() + 1] * 3).unsqueeze(1)
] * 1)
new_src_item.extend([audio_end_token_id])
audio_idx += 1
video_idx += 1
# move to the next token
idx += len(new_src_item) - new_src_item_len
llm_positions = torch.cat(llm_pos_ids_list, dim=1)
mrope_position_delta = torch.cat(llm_pos_ids_list,
dim=1).max() + 1 - len(src_item)
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
@staticmethod
def _get_llm_pos_ids_for_vision(
start_idx: int,
vision_idx: int,
spatial_merge_size: int,
t_index: List[int],
grid_hs: torch.Tensor,
grid_ws: torch.Tensor,
) -> torch.Tensor:
llm_pos_ids_list = []
llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand(
len(t_index), -1, llm_grid_w).flatten())
w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand(
len(t_index), llm_grid_h, -1).flatten())
t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view(
-1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten()
_llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
llm_pos_ids_list.append(_llm_pos_ids + start_idx)
llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
return llm_pos_ids
@staticmethod
def _split_list_into_ranges(lst: torch.Tensor,
interval: int) -> List[List[int]]:
ranges: List[List[int]] = [[]
for _ in range((max(lst) // interval) + 1)]
for num in lst:
index = num // interval
ranges[index].append(num)
return ranges
@staticmethod
def get_next_input_positions(
mrope_position_delta: int,
......@@ -1144,6 +1402,58 @@ class MRotaryEmbedding(RotaryEmbedding):
mrope_position_delta + seq_len,
).expand(3, -1)
@classmethod
def omni_get_updates_use_audio_in_video(
cls,
thinker_config: PretrainedConfig,
audio_len: int,
video_grid_thw: Union[List[int], torch.Tensor],
video_second_per_grid_t: float,
) -> List[int]:
"""Get video prompt updates when `use_audio_in_video` is True.
In this case, audio and vision update ids will be split into
chunks and interleaved (details in `_omni_get_input_positions_tensor`).
<|video_bos|><|VIDEO|><|video_eos|> =>
<|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>
"""
audio_token_id = thinker_config.audio_token_index
video_token_id = thinker_config.video_token_index
audio_start_token_id = thinker_config.audio_start_token_id
audio_end_token_id = thinker_config.audio_end_token_id
seconds_per_chunk = thinker_config.seconds_per_chunk
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
tokens_per_second = getattr(thinker_config.vision_config,
"tokens_per_second", 25)
grid_t = video_grid_thw[0]
grid_h = video_grid_thw[1]
grid_w = video_grid_thw[2]
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
t_index = (torch.arange(grid_t) * video_second_per_grid_t *
tokens_per_second).long()
t_index_split_chunk = cls._split_list_into_ranges(
t_index, t_ntoken_per_chunk)
updates = [audio_start_token_id]
added_audio_len = 0
for t_chunk in t_index_split_chunk:
vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // (
spatial_merge_size**2)
updates.extend([video_token_id] * vision_ntoken_per_chunk)
audio_chunk_size = min(t_ntoken_per_chunk,
audio_len - added_audio_len)
updates.extend(audio_chunk_size * [audio_token_id])
added_audio_len += audio_chunk_size
if added_audio_len < audio_len:
updates.extend((audio_len - added_audio_len) * [audio_token_id])
updates.extend([audio_end_token_id])
return updates
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
......
......@@ -583,21 +583,21 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
mm_input_by_modality = {}
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key in ("pixel_values",
"image_embeds") and "images" not in modalities:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if input_key in ("pixel_values_videos",
"video_embeds") and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
if input_key in ("pixel_values", "image_embeds"
) and "image" not in mm_input_by_modality:
mm_input_by_modality[
"image"] = self._parse_and_validate_image_input(**kwargs)
if input_key in ("pixel_values_videos", "video_embeds"
) and "video" not in mm_input_by_modality:
mm_input_by_modality[
"video"] = self._parse_and_validate_video_input(**kwargs)
return modalities
return mm_input_by_modality
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
......@@ -848,8 +848,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
**kwargs)
if not mm_input_by_modality:
return None
# The result multimodal_embeddings is tuple of tensors, with each
......@@ -858,14 +859,13 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
for modality in mm_input_by_modality:
multimodal_input = mm_input_by_modality[modality]
if modality == "image":
vision_embeddings = self._process_image_input(multimodal_input)
multimodal_embeddings += tuple(vision_embeddings)
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_video_pixels(video_input)
if modality == "video":
video_embeddings = self._process_video_pixels(multimodal_input)
multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings
......
This diff is collapsed.
......@@ -38,13 +38,14 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.config import VllmConfig
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
......@@ -195,6 +196,23 @@ class Qwen2_5_VisionMLP(nn.Module):
return x_down
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
"""All-gather the input tensor interleavely across model parallel group."""
import torch.distributed as dist
gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
dist.all_gather(gathered_tensors, local_tensor)
gathered_tensors_split = [
torch.split(tensor, hidden_size // tp_size, -1)
for tensor in gathered_tensors
]
ordered_tensors = [
tensor for pair in zip(*gathered_tensors_split) for tensor in pair
]
result_tensor = torch.cat(ordered_tensors, dim=-1)
return result_tensor
class Qwen2_5_VisionAttention(nn.Module):
def __init__(
......@@ -214,10 +232,14 @@ class Qwen2_5_VisionAttention(nn.Module):
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, self.tp_size)
self.qkv = ColumnParallelLinear(input_size=embed_dim,
output_size=3 * projection_size,
quant_config=quant_config,
prefix=f"{prefix}.qkv")
self.qkv = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv")
self.proj = RowParallelLinear(input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
......@@ -236,7 +258,8 @@ class Qwen2_5_VisionAttention(nn.Module):
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
if self.tp_size > 1:
qkv = tensor_model_parallel_all_gather(qkv)
qkv = all_gather_interleave(qkv, self.qkv.hidden_size,
self.tp_size)
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
q, k, v = qkv.chunk(3, dim=2)
......@@ -694,9 +717,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("attn.qkv.", "attn.q.", "q"),
("attn.qkv.", "attn.k.", "k"),
("attn.qkv.", "attn.v.", "v"),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
......@@ -952,20 +975,20 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return video_embeds.split(sizes.tolist())
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
mm_input_by_modality = {}
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key in ("pixel_values",
"image_embeds") and "images" not in modalities:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if input_key in ("pixel_values_videos",
"video_embeds") and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
return modalities
if input_key in ("pixel_values", "image_embeds"
) and "image" not in mm_input_by_modality:
mm_input_by_modality[
"image"] = self._parse_and_validate_image_input(**kwargs)
if input_key in ("pixel_values_videos", "video_embeds"
) and "video" not in mm_input_by_modality:
mm_input_by_modality[
"video"] = self._parse_and_validate_video_input(**kwargs)
return mm_input_by_modality
def get_language_model(self) -> torch.nn.Module:
return self.language_model
......@@ -973,8 +996,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
**kwargs)
if not mm_input_by_modality:
return None
# The result multimodal_embeddings is tuple of tensors, with each
......@@ -983,14 +1007,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
for modality in mm_input_by_modality:
multimodal_input = mm_input_by_modality[modality]
if modality == "image":
vision_embeddings = self._process_image_input(multimodal_input)
multimodal_embeddings += vision_embeddings
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_video_input(video_input)
if modality == "video":
video_embeddings = self._process_video_input(multimodal_input)
multimodal_embeddings += video_embeddings
return multimodal_embeddings
......
......@@ -200,6 +200,7 @@ _MULTIMODAL_MODELS = {
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501
"Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501
"Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
"Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
# [Encoder-decoder]
......
......@@ -84,7 +84,7 @@ def replace_linear_class(
) -> Union[ColumnParallelLinear, RowParallelLinear]:
"""
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
Args:
linear (nn.Linear): `nn.Linear` to be replaced.
style (str): Tensor parallel style of the new linear, e.g. "colwise".
......
......@@ -320,7 +320,8 @@ class MultiModalFlatField(BaseMultiModalField):
:func:`MultiModalFieldConfig.flat`
:func:`MultiModalFieldConfig.flat_from_sizes`
"""
slices: Sequence[slice]
slices: Union[Sequence[slice], Sequence[Sequence[slice]]]
dim: int = 0
def build_elems(
self,
......@@ -329,7 +330,10 @@ class MultiModalFlatField(BaseMultiModalField):
data: NestedTensors,
) -> Sequence[MultiModalFieldElem]:
field_factory = self._field_factory(modality=modality, key=key)
return [field_factory(data[s]) for s in self.slices]
if not is_list_of(self.slices, slice, check="all"):
assert isinstance(data, torch.Tensor), \
"torch.Tensor is required for multiple slices"
return [field_factory(data[cast(slice, s)]) for s in self.slices]
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
......@@ -338,10 +342,16 @@ class MultiModalFlatField(BaseMultiModalField):
# - produce exactly same result as `torch.concat(batch)`
# - will achieve zero-copy if the tensor is contiguous
return batch[0].contiguous()
first_shape = batch[0].shape
if all(elem.shape[1:] == first_shape[1:] for elem in batch):
return torch.concat(batch)
def _expect_same_shape(tensor: torch.Tensor):
return tensor.shape[:self.dim] + tensor.shape[self.dim + 1:]
first_shape = _expect_same_shape(batch[0])
if all(_expect_same_shape(elem) == first_shape for elem in batch):
return torch.concat(batch, dim=self.dim)
assert self.dim == 0, "dim == 0 is required for nested list"
return [e for elem in batch for e in elem]
......@@ -398,7 +408,9 @@ class MultiModalFieldConfig:
)
@staticmethod
def flat(modality: str, slices: Sequence[slice]):
def flat(modality: str,
slices: Union[Sequence[slice], Sequence[Sequence[slice]]],
dim: int = 0):
"""
Defines a field where an element in the batch is obtained by
slicing along the first dimension of the underlying data.
......@@ -406,8 +418,10 @@ class MultiModalFieldConfig:
Args:
modality: The modality of the multi-modal item that uses this
keyword argument.
slices: For each multi-modal item, a slice that is used to extract
the data corresponding to it.
slices: For each multi-modal item, a slice (dim=0) or a tuple of
slices (dim>0) that is used to extract the data corresponding
to it.
dim: The dimension to extract data, default to 0.
Example:
......@@ -423,14 +437,33 @@ class MultiModalFieldConfig:
Element 1: [AAA]
Element 2: [BBBB]
Element 3: [CC]
.. code-block::
Given:
slices: [
(slice(None), slice(0, 3)),
(slice(None), slice(3, 7)),
(slice(None), slice(7, 9))]
dim: 1
Input:
Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]
Output:
Element 1: [[A],[A],[A]]
Element 2: [[B],[B],[B],[B]]
Element 3: [[C],[C]]
"""
return MultiModalFieldConfig(
field=MultiModalFlatField(slices=slices),
field=MultiModalFlatField(slices=slices, dim=dim),
modality=modality,
)
@staticmethod
def flat_from_sizes(modality: str, size_per_item: torch.Tensor):
def flat_from_sizes(modality: str,
size_per_item: torch.Tensor,
dim: int = 0):
"""
Defines a field where an element in the batch is obtained by
slicing along the first dimension of the underlying data.
......@@ -440,6 +473,7 @@ class MultiModalFieldConfig:
keyword argument.
slices: For each multi-modal item, the size of the slice that
is used to extract the data corresponding to it.
dim: The dimension to slice, default to 0.
Example:
......@@ -455,6 +489,21 @@ class MultiModalFieldConfig:
Element 1: [AAA]
Element 2: [BBBB]
Element 3: [CC]
.. code-block::
Given:
slices: [3, 4, 2]
dim: 1
Input:
Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]
Output:
Element 1: [[A],[A],[A]]
Element 2: [[B],[B],[B],[B]]
Element 3: [[C],[C]]
See also:
:func:`MultiModalFieldConfig.flat`
......@@ -465,12 +514,11 @@ class MultiModalFieldConfig:
f"but found shape: {size_per_item.shape}")
slice_idxs = [0, *accumulate(size_per_item)]
slices = [
slice(slice_idxs[i], slice_idxs[i + 1])
for i in range(len(size_per_item))
]
slices = [(slice(None, None, None), ) * dim +
(slice(slice_idxs[i], slice_idxs[i + 1]), )
for i in range(len(size_per_item))]
return MultiModalFieldConfig.flat(modality, slices)
return MultiModalFieldConfig.flat(modality, slices, dim=dim)
@staticmethod
def shared(modality: str, batch_size: int):
......
......@@ -222,8 +222,7 @@ def patch_rope_scaling_dict(rope_scaling: Dict[str, Any]) -> None:
logger.warning("Replacing legacy rope_type 'mrope' with 'default'")
def uses_mrope(config: PretrainedConfig) -> bool:
"""Detect if the model with this config uses M-ROPE."""
def _uses_mrope(config: PretrainedConfig) -> bool:
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is None:
return False
......@@ -231,6 +230,24 @@ def uses_mrope(config: PretrainedConfig) -> bool:
return "mrope_section" in rope_scaling
def uses_mrope(config: PretrainedConfig) -> bool:
"""Detect if the model with this config uses M-ROPE."""
return _uses_mrope(config) or thinker_uses_mrope(config)
def thinker_uses_mrope(config: PretrainedConfig) -> bool:
"""Detect if the model contains a thinker config and it uses M-ROPE."""
thinker_config = getattr(config, "thinker_config", None)
if thinker_config is None:
return False
thinker_text_config = getattr(thinker_config, "text_config", None)
if thinker_text_config is None:
return False
return uses_mrope(thinker_text_config)
def is_encoder_decoder(config: PretrainedConfig) -> bool:
"""Detect if the model with this config is used as an encoder/decoder."""
text_config = getattr(config, "text_config", None)
......@@ -740,6 +757,11 @@ def get_hf_text_config(config: PretrainedConfig):
# if transformers config doesn't align with this assumption.
assert hasattr(config.text_config, "num_attention_heads")
return config.text_config
elif hasattr(config, "thinker_config"):
# TODO(suyang.fy): Refactor code.
# For Qwen2.5-Omni, change hf_text_config to
# thinker_config.text_config.
return config.thinker_config.text_config
else:
return config
......
......@@ -111,6 +111,55 @@ def cached_processor_from_config(
)
def get_feature_extractor(
processor_name: str,
*args: Any,
trust_remote_code: bool = False,
**kwargs: Any,
):
"""Load an audio feature extractor for the given model name
via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoFeatureExtractor
from transformers.feature_extraction_utils import FeatureExtractionMixin
try:
feature_extractor = AutoFeatureExtractor.from_pretrained(
processor_name,
*args,
trust_remote_code=trust_remote_code,
**kwargs)
except ValueError as e:
# If the error pertains to the processor class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
# Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
if not trust_remote_code:
err_msg = (
"Failed to load the feature extractor. If the feature "
"extractor is a custom extractor not yet available in the "
"HuggingFace transformers library, consider setting "
"`trust_remote_code=True` in LLM or using the "
"`--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
return cast(FeatureExtractionMixin, feature_extractor)
cached_get_feature_extractor = lru_cache(get_feature_extractor)
def cached_feature_extractor_from_config(
model_config: "ModelConfig",
**kwargs: Any,
):
return cached_get_feature_extractor(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
**_merge_mm_kwargs(model_config, **kwargs),
)
def get_image_processor(
processor_name: str,
*args: Any,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment