Unverified commit 7af553ea, authored by Cyrus Leung, committed by GitHub
Browse files

[Misc] Abstract the logic for reading and writing media content (#11527)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 2c9b8ea2
...@@ -33,6 +33,7 @@ class MockModelConfig: ...@@ -33,6 +33,7 @@ class MockModelConfig:
hf_config = MockHFConfig() hf_config = MockHFConfig()
logits_processor_pattern = None logits_processor_pattern = None
diff_sampling_param: Optional[dict] = None diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = ""
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}
......
...@@ -2,7 +2,6 @@ import warnings ...@@ -2,7 +2,6 @@ import warnings
from typing import Optional from typing import Optional
import pytest import pytest
from PIL import Image
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -91,10 +90,7 @@ def _assert_mm_data_is_image_input( ...@@ -91,10 +90,7 @@ def _assert_mm_data_is_image_input(
image_data = mm_data.get("image") image_data = mm_data.get("image")
assert image_data is not None assert image_data is not None
if image_count == 1: assert isinstance(image_data, list) and len(image_data) == image_count
assert isinstance(image_data, Image.Image)
else:
assert isinstance(image_data, list) and len(image_data) == image_count
def test_parse_chat_messages_single_image( def test_parse_chat_messages_single_image(
......
...@@ -9,7 +9,7 @@ import pytest ...@@ -9,7 +9,7 @@ import pytest
from PIL import Image, ImageChops from PIL import Image, ImageChops
from transformers import AutoConfig, AutoTokenizer from transformers import AutoConfig, AutoTokenizer
from vllm.multimodal.utils import (async_fetch_image, fetch_image, from vllm.multimodal.utils import (MediaConnector,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
...@@ -23,7 +23,12 @@ TEST_IMAGE_URLS = [ ...@@ -23,7 +23,12 @@ TEST_IMAGE_URLS = [
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def url_images() -> Dict[str, Image.Image]: def url_images() -> Dict[str, Image.Image]:
return {image_url: fetch_image(image_url) for image_url in TEST_IMAGE_URLS} connector = MediaConnector()
return {
image_url: connector.fetch_image(image_url)
for image_url in TEST_IMAGE_URLS
}
def get_supported_suffixes() -> Tuple[str, ...]: def get_supported_suffixes() -> Tuple[str, ...]:
...@@ -43,8 +48,10 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool: ...@@ -43,8 +48,10 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_http(image_url: str): async def test_fetch_image_http(image_url: str):
image_sync = fetch_image(image_url) connector = MediaConnector()
image_async = await async_fetch_image(image_url)
image_sync = connector.fetch_image(image_url)
image_async = await connector.fetch_image_async(image_url)
assert _image_equals(image_sync, image_async) assert _image_equals(image_sync, image_async)
...@@ -53,6 +60,7 @@ async def test_fetch_image_http(image_url: str): ...@@ -53,6 +60,7 @@ async def test_fetch_image_http(image_url: str):
@pytest.mark.parametrize("suffix", get_supported_suffixes()) @pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: Dict[str, Image.Image], async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
image_url: str, suffix: str): image_url: str, suffix: str):
connector = MediaConnector()
url_image = url_images[image_url] url_image = url_images[image_url]
try: try:
...@@ -75,48 +83,49 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image], ...@@ -75,48 +83,49 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
base64_image = base64.b64encode(f.read()).decode("utf-8") base64_image = base64.b64encode(f.read()).decode("utf-8")
data_url = f"data:{mime_type};base64,{base64_image}" data_url = f"data:{mime_type};base64,{base64_image}"
data_image_sync = fetch_image(data_url) data_image_sync = connector.fetch_image(data_url)
if _image_equals(url_image, Image.open(f)): if _image_equals(url_image, Image.open(f)):
assert _image_equals(url_image, data_image_sync) assert _image_equals(url_image, data_image_sync)
else: else:
pass # Lossy format; only check that image can be opened pass # Lossy format; only check that image can be opened
data_image_async = await async_fetch_image(data_url) data_image_async = await connector.fetch_image_async(data_url)
assert _image_equals(data_image_sync, data_image_async) assert _image_equals(data_image_sync, data_image_async)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_local_files(image_url: str): async def test_fetch_image_local_files(image_url: str):
connector = MediaConnector()
with TemporaryDirectory() as temp_dir: with TemporaryDirectory() as temp_dir:
origin_image = fetch_image(image_url) local_connector = MediaConnector(allowed_local_media_path=temp_dir)
origin_image = connector.fetch_image(image_url)
origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)), origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)),
quality=100, quality=100,
icc_profile=origin_image.info.get('icc_profile')) icc_profile=origin_image.info.get('icc_profile'))
image_async = await async_fetch_image( image_async = await local_connector.fetch_image_async(
f"file://{temp_dir}/{os.path.basename(image_url)}", f"file://{temp_dir}/{os.path.basename(image_url)}")
allowed_local_media_path=temp_dir) image_sync = local_connector.fetch_image(
f"file://{temp_dir}/{os.path.basename(image_url)}")
image_sync = fetch_image(
f"file://{temp_dir}/{os.path.basename(image_url)}",
allowed_local_media_path=temp_dir)
# Check that the images are equal # Check that the images are equal
assert not ImageChops.difference(image_sync, image_async).getbbox() assert not ImageChops.difference(image_sync, image_async).getbbox()
with pytest.raises(ValueError): with pytest.raises(ValueError, match="must be a subpath"):
await async_fetch_image( await local_connector.fetch_image_async(
f"file://{temp_dir}/../{os.path.basename(image_url)}", f"file://{temp_dir}/../{os.path.basename(image_url)}")
allowed_local_media_path=temp_dir) with pytest.raises(RuntimeError, match="Cannot load local files"):
with pytest.raises(ValueError): await connector.fetch_image_async(
await async_fetch_image(
f"file://{temp_dir}/../{os.path.basename(image_url)}") f"file://{temp_dir}/../{os.path.basename(image_url)}")
with pytest.raises(ValueError): with pytest.raises(ValueError, match="must be a subpath"):
fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}", local_connector.fetch_image(
allowed_local_media_path=temp_dir) f"file://{temp_dir}/../{os.path.basename(image_url)}")
with pytest.raises(ValueError): with pytest.raises(RuntimeError, match="Cannot load local files"):
fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}") connector.fetch_image(
f"file://{temp_dir}/../{os.path.basename(image_url)}")
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
......
...@@ -21,12 +21,10 @@ class AudioAsset: ...@@ -21,12 +21,10 @@ class AudioAsset:
name: Literal["winning_call", "mary_had_lamb"] name: Literal["winning_call", "mary_had_lamb"]
@property @property
def audio_and_sample_rate(self) -> tuple[npt.NDArray, int]: def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg", audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
s3_prefix=ASSET_DIR) s3_prefix=ASSET_DIR)
y, sr = librosa.load(audio_path, sr=None) return librosa.load(audio_path, sr=None)
assert isinstance(sr, int)
return y, sr
@property @property
def url(self) -> str: def url(self) -> str:
......
...@@ -6,7 +6,7 @@ from collections import defaultdict, deque ...@@ -6,7 +6,7 @@ from collections import defaultdict, deque
from functools import lru_cache, partial from functools import lru_cache, partial
from pathlib import Path from pathlib import Path
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List, from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
Literal, Mapping, Optional, Tuple, TypeVar, Union, cast) Literal, Optional, Tuple, TypeVar, Union, cast)
import jinja2.nodes import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils import transformers.utils.chat_template_utils as hf_chat_utils
...@@ -23,6 +23,8 @@ from openai.types.chat import ( ...@@ -23,6 +23,8 @@ from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
from openai.types.chat import (ChatCompletionMessageToolCallParam, from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam) ChatCompletionToolMessageParam)
from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio)
# yapf: enable # yapf: enable
# pydantic needs the TypedDict from typing_extensions # pydantic needs the TypedDict from typing_extensions
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
...@@ -31,11 +33,7 @@ from typing_extensions import Required, TypeAlias, TypedDict ...@@ -31,11 +33,7 @@ from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_audio, from vllm.multimodal.utils import MediaConnector
async_get_and_parse_image,
async_get_and_parse_video,
get_and_parse_audio, get_and_parse_image,
get_and_parse_video)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
...@@ -368,14 +366,17 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -368,14 +366,17 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self._tokenizer = tokenizer self._tokenizer = tokenizer
self._allowed_items = (model_config.multimodal_config.limit_per_prompt self._allowed_items = (model_config.multimodal_config.limit_per_prompt
if model_config.multimodal_config else {}) if model_config.multimodal_config else {})
self._consumed_items = {k: 0 for k in self._allowed_items}
self._items: List[_T] = [] self._items_by_modality = defaultdict[str, list[_T]](list)
@property @property
def model_config(self) -> ModelConfig: def model_config(self) -> ModelConfig:
return self._model_config return self._model_config
@property
def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path
@staticmethod @staticmethod
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str: def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
...@@ -435,38 +436,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -435,38 +436,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
else: else:
raise TypeError(f"Unknown modality: {modality}") raise TypeError(f"Unknown modality: {modality}")
@staticmethod
def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
mm_lists: Mapping[str, List[object]] = defaultdict(list)
# Merge all the multi-modal items
for single_mm_data in items:
for mm_key, mm_item in single_mm_data.items():
if isinstance(mm_item, list):
mm_lists[mm_key].extend(mm_item)
else:
mm_lists[mm_key].append(mm_item)
# Unpack any single item lists for models that don't expect multiple.
return {
mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
for mm_key, mm_list in mm_lists.items()
}
def add(self, modality: ModalityStr, item: _T) -> Optional[str]: def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
""" """
Add a multi-modal item to the current prompt and returns the Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any. placeholder string to use, if any.
""" """
allowed_count = self._allowed_items.get(modality, 1) allowed_count = self._allowed_items.get(modality, 1)
current_count = self._consumed_items.get(modality, 0) + 1 current_count = len(self._items_by_modality[modality]) + 1
if current_count > allowed_count: if current_count > allowed_count:
raise ValueError( raise ValueError(
f"At most {allowed_count} {modality}(s) may be provided in " f"At most {allowed_count} {modality}(s) may be provided in "
"one request.") "one request.")
self._consumed_items[modality] = current_count self._items_by_modality[modality].append(item)
self._items.append(item)
return self._placeholder_str(modality, current_count) return self._placeholder_str(modality, current_count)
...@@ -475,22 +457,26 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -475,22 +457,26 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
raise NotImplementedError raise NotImplementedError
class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]): class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
def all_mm_data(self) -> Optional[MultiModalDataDict]: def all_mm_data(self) -> Optional[MultiModalDataDict]:
return self._combine(self._items) if self._items else None if self._items_by_modality:
return dict(self._items_by_modality)
return None
def create_parser(self) -> "BaseMultiModalContentParser": def create_parser(self) -> "BaseMultiModalContentParser":
return MultiModalContentParser(self) return MultiModalContentParser(self)
class AsyncMultiModalItemTracker( class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
async def all_mm_data(self) -> Optional[MultiModalDataDict]: async def all_mm_data(self) -> Optional[MultiModalDataDict]:
if self._items: if self._items_by_modality:
items = await asyncio.gather(*self._items) return {
return self._combine(items) modality: await asyncio.gather(*items)
for modality, items in self._items_by_modality.items()
}
return None return None
...@@ -522,7 +508,7 @@ class BaseMultiModalContentParser(ABC): ...@@ -522,7 +508,7 @@ class BaseMultiModalContentParser(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def parse_input_audio(self, input_audio: Dict[str, str]) -> None: def parse_input_audio(self, input_audio: InputAudio) -> None:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
...@@ -537,31 +523,31 @@ class MultiModalContentParser(BaseMultiModalContentParser): ...@@ -537,31 +523,31 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self._tracker = tracker self._tracker = tracker
self._connector = MediaConnector(
allowed_local_media_path=tracker.allowed_local_media_path,
)
def parse_image(self, image_url: str) -> None: def parse_image(self, image_url: str) -> None:
image = get_and_parse_image(image_url, image = self._connector.fetch_image(image_url)
allowed_local_media_path=self._tracker.
_model_config.allowed_local_media_path)
placeholder = self._tracker.add("image", image) placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None: def parse_audio(self, audio_url: str) -> None:
audio = get_and_parse_audio(audio_url) audio = self._connector.fetch_audio(audio_url)
placeholder = self._tracker.add("audio", audio) placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
def parse_input_audio(self, input_audio: Dict[str, str]) -> None: def parse_input_audio(self, input_audio: InputAudio) -> None:
input_audio_data = input_audio.get("data","") audio_data = input_audio.get("data", "")
input_audio_format = input_audio.get("format","") audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}" audio_url = f"data:audio/{audio_format};base64,{audio_data}"
audio = get_and_parse_audio(audio_url)
placeholder = self._tracker.add("audio", audio) return self.parse_audio(audio_url)
self._add_placeholder(placeholder)
def parse_video(self, video_url: str) -> None: def parse_video(self, video_url: str) -> None:
video = get_and_parse_video(video_url) video = self._connector.fetch_video(video_url)
placeholder = self._tracker.add("video", video) placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
...@@ -573,33 +559,31 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): ...@@ -573,33 +559,31 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
super().__init__() super().__init__()
self._tracker = tracker self._tracker = tracker
self._connector = MediaConnector(
allowed_local_media_path=tracker.allowed_local_media_path,
)
def parse_image(self, image_url: str) -> None: def parse_image(self, image_url: str) -> None:
image_coro = async_get_and_parse_image( image_coro = self._connector.fetch_image_async(image_url)
image_url,
allowed_local_media_path=self._tracker._model_config.
allowed_local_media_path)
placeholder = self._tracker.add("image", image_coro) placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None: def parse_audio(self, audio_url: str) -> None:
audio_coro = async_get_and_parse_audio(audio_url) audio_coro = self._connector.fetch_audio_async(audio_url)
placeholder = self._tracker.add("audio", audio_coro) placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
def parse_input_audio(self, input_audio: Dict[str, str]) -> None: def parse_input_audio(self, input_audio: InputAudio) -> None:
input_audio_data = input_audio.get("data","") audio_data = input_audio.get("data", "")
input_audio_format = input_audio.get("format","") audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}" audio_url = f"data:audio/{audio_format};base64,{audio_data}"
audio_coro = async_get_and_parse_audio(audio_url)
placeholder = self._tracker.add("audio", audio_coro) return self.parse_audio(audio_url)
self._add_placeholder(placeholder)
def parse_video(self, video_url: str) -> None: def parse_video(self, video_url: str) -> None:
video = async_get_and_parse_video(video_url) video = self._connector.fetch_video_async(video_url)
placeholder = self._tracker.add("video", video) placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
...@@ -695,10 +679,13 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam) ...@@ -695,10 +679,13 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam) _VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio]
# Define a mapping from part types to their corresponding parsing functions. # Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: Dict[str, MM_PARSER_MAP: Dict[
Callable[[ChatCompletionContentPartParam], str,
Union[str, Dict[str,str]]]] = { Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
"text": "text":
lambda part: _TextParser(part).get("text", ""), lambda part: _TextParser(part).get("text", ""),
"image_url": "image_url":
...@@ -715,8 +702,7 @@ MM_PARSER_MAP: Dict[str, ...@@ -715,8 +702,7 @@ MM_PARSER_MAP: Dict[str,
def _parse_chat_message_content_mm_part( def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam) -> Tuple[str, part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]:
Union[str, Dict[str, str]]]:
""" """
Parses a given multi-modal content part based on its type. Parses a given multi-modal content part based on its type.
...@@ -783,7 +769,7 @@ def _parse_chat_message_content_parts( ...@@ -783,7 +769,7 @@ def _parse_chat_message_content_parts(
*, *,
wrap_dicts: bool, wrap_dicts: bool,
) -> List[ConversationMessage]: ) -> List[ConversationMessage]:
content: List[Union[str, Dict[str, str]]] = [] content = list[_ContentPart]()
mm_parser = mm_tracker.create_parser() mm_parser = mm_tracker.create_parser()
...@@ -814,7 +800,7 @@ def _parse_chat_message_content_part( ...@@ -814,7 +800,7 @@ def _parse_chat_message_content_part(
mm_parser: BaseMultiModalContentParser, mm_parser: BaseMultiModalContentParser,
*, *,
wrap_dicts: bool, wrap_dicts: bool,
) -> Optional[Union[str, Dict[str, str]]]: ) -> Optional[_ContentPart]:
"""Parses a single part of a conversation. If wrap_dicts is True, """Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be structured dictionary pieces for texts and images will be
wrapped in dictionaries, i.e., {"type": "text", "text", ...} and wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
...@@ -823,8 +809,7 @@ def _parse_chat_message_content_part( ...@@ -823,8 +809,7 @@ def _parse_chat_message_content_part(
with multimodal placeholders. with multimodal placeholders.
""" """
if isinstance(part, str): # Handle plain text parts if isinstance(part, str): # Handle plain text parts
text = _TextParser(part) return part
return text
# Handle structured dictionary parts # Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part) part_type, content = _parse_chat_message_content_mm_part(part)
...@@ -855,7 +840,7 @@ def _parse_chat_message_content_part( ...@@ -855,7 +840,7 @@ def _parse_chat_message_content_part(
return {'type': 'audio'} if wrap_dicts else None return {'type': 'audio'} if wrap_dicts else None
if part_type == "input_audio": if part_type == "input_audio":
dict_content = cast(Dict[str, str], content) dict_content = cast(InputAudio, content)
mm_parser.parse_input_audio(dict_content) mm_parser.parse_input_audio(dict_content)
return {'type': 'audio'} if wrap_dicts else None return {'type': 'audio'} if wrap_dicts else None
......
import base64
from io import BytesIO
from pathlib import Path
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
from vllm.inputs.registry import InputContext from vllm.inputs.registry import InputContext
from vllm.utils import PlaceholderModule from vllm.utils import PlaceholderModule
from .base import MultiModalPlugin from .base import MediaIO, MultiModalPlugin
from .inputs import AudioItem, MultiModalData, MultiModalKwargs from .inputs import AudioItem, MultiModalData, MultiModalKwargs
try: try:
...@@ -12,6 +16,11 @@ try: ...@@ -12,6 +16,11 @@ try:
except ImportError: except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment] librosa = PlaceholderModule("librosa") # type: ignore[assignment]
try:
import soundfile
except ImportError:
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
class AudioPlugin(MultiModalPlugin): class AudioPlugin(MultiModalPlugin):
"""Plugin for audio data.""" """Plugin for audio data."""
...@@ -39,3 +48,28 @@ def resample_audio( ...@@ -39,3 +48,28 @@ def resample_audio(
target_sr: float, target_sr: float,
) -> npt.NDArray[np.floating]: ) -> npt.NDArray[np.floating]:
return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr) return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
    """Media I/O for audio handled as ``(waveform, sample_rate)`` pairs.

    Decoding is delegated to ``librosa.load`` (always called with
    ``sr=None``); encoding writes a WAV container through ``soundfile``.
    """

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
        """Decode raw audio bytes by wrapping them in an in-memory stream."""
        stream = BytesIO(data)
        return librosa.load(stream, sr=None)

    def load_base64(
        self,
        media_type: str,
        data: str,
    ) -> tuple[npt.NDArray, float]:
        """Decode a base64 payload and load it like raw bytes.

        ``media_type`` is accepted for interface compatibility but is not
        consulted here.
        """
        raw = base64.b64decode(data)
        return self.load_bytes(raw)

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
        """Load audio directly from ``filepath`` via ``librosa.load``."""
        return librosa.load(filepath, sr=None)

    def encode_base64(self, media: tuple[npt.NDArray, float]) -> str:
        """Serialize ``(waveform, sample_rate)`` to WAV and base64-encode it."""
        waveform, sample_rate = media

        with BytesIO() as wav_buffer:
            soundfile.write(wav_buffer, waveform, sample_rate, format="WAV")
            wav_bytes = wav_buffer.getvalue()

        return base64.b64encode(wav_bytes).decode('utf-8')
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple,
Optional, Sequence, Tuple, Type, TypeVar, Union) Optional, Sequence, Tuple, Type, TypeVar, Union)
from torch import nn from torch import nn
...@@ -118,7 +119,7 @@ class MultiModalPlugin(ABC): ...@@ -118,7 +119,7 @@ class MultiModalPlugin(ABC):
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
data: MultiModalData[Any], data: MultiModalData[Any],
mm_processor_kwargs: Optional[Dict[str, Any]], mm_processor_kwargs: Optional[dict[str, Any]],
) -> MultiModalKwargs: ) -> MultiModalKwargs:
""" """
Transform the data into a dictionary of model inputs using the Transform the data into a dictionary of model inputs using the
...@@ -254,10 +255,10 @@ class MultiModalPlaceholderMap: ...@@ -254,10 +255,10 @@ class MultiModalPlaceholderMap:
""" """
class IndexMap(NamedTuple): class IndexMap(NamedTuple):
src: List[int] src: list[int]
dest: List[int] dest: list[int]
src_ranges: List[range] src_ranges: list[range]
""" """
The indices of the multi-modal embeddings that will replace the The indices of the multi-modal embeddings that will replace the
corresponding placeholder embeddings pointed to by ``dest_ranges``. corresponding placeholder embeddings pointed to by ``dest_ranges``.
...@@ -268,7 +269,7 @@ class MultiModalPlaceholderMap: ...@@ -268,7 +269,7 @@ class MultiModalPlaceholderMap:
The total number of flattened multi-modal embeddings. The total number of flattened multi-modal embeddings.
""" """
dest_ranges: List[range] dest_ranges: list[range]
""" """
The indices of the placeholder embeddings that will be replaced by the The indices of the placeholder embeddings that will be replaced by the
multimodal embeddings. multimodal embeddings.
...@@ -288,7 +289,7 @@ class MultiModalPlaceholderMap: ...@@ -288,7 +289,7 @@ class MultiModalPlaceholderMap:
@classmethod @classmethod
def from_seq_group( def from_seq_group(
cls, seq_group: "SequenceGroupMetadata", positions: range cls, seq_group: "SequenceGroupMetadata", positions: range
) -> Tuple[Optional[MultiModalDataDict], Dict[str, ) -> Tuple[Optional[MultiModalDataDict], dict[str,
"MultiModalPlaceholderMap"]]: "MultiModalPlaceholderMap"]]:
""" """
Returns the multi-modal items that intersect with the portion of a Returns the multi-modal items that intersect with the portion of a
...@@ -376,9 +377,9 @@ class MultiModalPlaceholderMap: ...@@ -376,9 +377,9 @@ class MultiModalPlaceholderMap:
def append_items_from_seq_group( def append_items_from_seq_group(
self, self,
positions: range, positions: range,
multi_modal_items: List[_T], multi_modal_items: list[_T],
multi_modal_placeholders: Sequence[PlaceholderRange], multi_modal_placeholders: Sequence[PlaceholderRange],
) -> List[_T]: ) -> list[_T]:
""" """
Adds the multi-modal items that intersect ```positions`` to this Adds the multi-modal items that intersect ```positions`` to this
placeholder map and returns the intersecting items. placeholder map and returns the intersecting items.
...@@ -454,3 +455,22 @@ class MultiModalPlaceholderMap: ...@@ -454,3 +455,22 @@ class MultiModalPlaceholderMap:
return MultiModalPlaceholderMap.IndexMap(src=src_indices, return MultiModalPlaceholderMap.IndexMap(src=src_indices,
dest=dest_indices) dest=dest_indices)
class MediaIO(ABC, Generic[_T]):
    """Abstract loader interface for one kind of media.

    Concrete subclasses turn raw bytes, base64 text, or a local file into
    a media object of type ``_T``.
    """

    @abstractmethod
    def load_bytes(self, data: bytes) -> _T:
        """Build the media object from its raw byte content."""
        raise NotImplementedError

    @abstractmethod
    def load_base64(self, media_type: str, data: str) -> _T:
        """
        Build the media object from base64-encoded ``data``.

        ``media_type`` is the declared media type of the payload.

        List of media types:
        https://www.iana.org/assignments/media-types/media-types.xhtml
        """
        raise NotImplementedError

    @abstractmethod
    def load_file(self, filepath: Path) -> _T:
        """Build the media object from the file located at ``filepath``."""
        raise NotImplementedError
import base64
from functools import lru_cache from functools import lru_cache
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional
import torch import torch
...@@ -9,7 +12,7 @@ from vllm.logger import init_logger ...@@ -9,7 +12,7 @@ from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_image_processor from vllm.transformers_utils.processor import get_image_processor
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .base import MultiModalPlugin from .base import MediaIO, MultiModalPlugin
from .inputs import ImageItem, MultiModalData, MultiModalKwargs from .inputs import ImageItem, MultiModalData, MultiModalKwargs
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -96,3 +99,39 @@ def rescale_image_size(image: Image.Image, ...@@ -96,3 +99,39 @@ def rescale_image_size(image: Image.Image,
if transpose >= 0: if transpose >= 0:
image = image.transpose(Image.Transpose(transpose)) image = image.transpose(Image.Transpose(transpose))
return image return image
class ImageMediaIO(MediaIO[Image.Image]):
    """Media I/O for PIL images.

    Every image loaded or encoded through this class is first converted to
    ``image_mode`` (``"RGB"`` unless overridden).
    """

    def __init__(self, *, image_mode: str = "RGB") -> None:
        super().__init__()

        # Mode passed to ``Image.convert`` on every load and encode.
        self.image_mode = image_mode

    def load_bytes(self, data: bytes) -> Image.Image:
        """Open an image from raw bytes and convert it to ``image_mode``."""
        decoded = Image.open(BytesIO(data))
        # Read the pixel data now, while the in-memory stream is in scope.
        decoded.load()
        return decoded.convert(self.image_mode)

    def load_base64(self, media_type: str, data: str) -> Image.Image:
        """Decode a base64 payload and load it like raw bytes.

        ``media_type`` is not consulted here.
        """
        raw = base64.b64decode(data)
        return self.load_bytes(raw)

    def load_file(self, filepath: Path) -> Image.Image:
        """Open the image at ``filepath`` and convert it to ``image_mode``."""
        opened = Image.open(filepath)
        opened.load()
        return opened.convert(self.image_mode)

    def encode_base64(
        self,
        media: Image.Image,
        *,
        image_format: str = "JPEG",
    ) -> str:
        """Save ``media`` as ``image_format`` and return it base64-encoded.

        The image is converted to ``image_mode`` before it is saved.
        """
        with BytesIO() as buffer:
            converted = media.convert(self.image_mode)
            converted.save(buffer, image_format)
            payload = buffer.getvalue()

        return base64.b64encode(payload).decode('utf-8')
This diff is collapsed.
from functools import lru_cache import base64
from functools import lru_cache, partial
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional
import cv2 import cv2
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
from PIL import Image
from vllm.inputs.registry import InputContext from vllm.inputs.registry import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_video_processor from vllm.transformers_utils.processor import get_video_processor
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import is_list_of from vllm.utils import PlaceholderModule, is_list_of
from .base import MultiModalData from .base import MediaIO, MultiModalData
from .image import ImagePlugin from .image import ImageMediaIO, ImagePlugin
from .inputs import MultiModalKwargs, VideoItem from .inputs import MultiModalKwargs, VideoItem
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
try:
import decord
except ImportError:
decord = PlaceholderModule("decord") # type: ignore[assignment]
logger = init_logger(__name__) logger = init_logger(__name__)
cached_get_video_processor = lru_cache(get_video_processor) cached_get_video_processor = lru_cache(get_video_processor)
...@@ -107,3 +116,73 @@ def sample_frames_from_video(frames: npt.NDArray, ...@@ -107,3 +116,73 @@ def sample_frames_from_video(frames: npt.NDArray,
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int) frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
sampled_frames = frames[frame_indices, ...] sampled_frames = frames[frame_indices, ...]
return sampled_frames return sampled_frames
class VideoMediaIO(MediaIO[npt.NDArray]):
    """Media I/O for videos handled as NumPy arrays of frames.

    Container formats are decoded with ``decord``; the ``video/jpeg``
    base64 form (comma-separated JPEG frames) is delegated frame by frame
    to the supplied ``ImageMediaIO``.
    """

    def __init__(
        self,
        image_io: ImageMediaIO,
        *,
        num_frames: int = 32,
    ) -> None:
        super().__init__()

        # Used to decode/encode individual frames of JPEG-sequence videos.
        self.image_io = image_io
        # Upper bound on the number of frames taken in ``load_bytes``.
        self.num_frames = num_frames

    def load_bytes(self, data: bytes) -> npt.NDArray:
        """Decode video bytes, keeping at most ``num_frames`` frames.

        Longer videos are sampled at evenly spaced indices from the first
        to the last frame; shorter ones keep every frame.
        """
        reader = decord.VideoReader(BytesIO(data), num_threads=1)
        available = len(reader)
        wanted = self.num_frames

        if available <= wanted:
            indices = list(range(0, available))
        else:
            indices = np.linspace(0, available - 1, wanted,
                                  dtype=int).tolist()

        return reader.get_batch(indices).asnumpy()

    def load_base64(self, media_type: str, data: str) -> npt.NDArray:
        """Decode a base64 video payload.

        For ``video/jpeg`` (compared case-insensitively) ``data`` is
        treated as comma-separated base64 JPEG frames that are decoded
        one by one and stacked. Any other media type is base64-decoded
        and handed to ``load_bytes``.
        """
        if media_type.lower() != "video/jpeg":
            return self.load_bytes(base64.b64decode(data))

        decode_frame = self.image_io.load_base64
        frames = [
            np.array(decode_frame("image/jpeg", frame_data))
            for frame_data in data.split(",")
        ]
        return np.stack(frames)

    def load_file(self, filepath: Path) -> npt.NDArray:
        """Read the whole file at ``filepath`` and decode it as video bytes."""
        with filepath.open("rb") as stream:
            raw = stream.read()

        return self.load_bytes(raw)

    def encode_base64(
        self,
        media: npt.NDArray,
        *,
        video_format: str = "JPEG",
    ) -> str:
        """Encode each frame as base64 JPEG and join the frames with commas.

        Raises:
            NotImplementedError: if ``video_format`` is anything but
                ``"JPEG"``.
        """
        if video_format != "JPEG":
            msg = "Only JPEG format is supported for now."
            raise NotImplementedError(msg)

        encode_frame = self.image_io.encode_base64
        return ",".join(
            encode_frame(Image.fromarray(frame), image_format=video_format)
            for frame in media)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment