Unverified Commit 01dfb5e9 authored by Chenheli Hua's avatar Chenheli Hua Committed by GitHub
Browse files

[Frontend] User-provided uuids for medias in chat. (RFC #22044) (#23449)


Signed-off-by: default avatarRoger Wang <hey@rogerw.io>
Signed-off-by: default avatarChenheli Hua <huachenheli@outlook.com>
Signed-off-by: default avatarRoger Wang <hey@rogerw.me>
Signed-off-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: default avatarRoger Wang <hey@rogerw.io>
Co-authored-by: default avatarRoger Wang <hey@rogerw.me>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent 03dd652c
...@@ -215,19 +215,19 @@ When loading RGBA images (images with transparency), vLLM converts them to RGB f ...@@ -215,19 +215,19 @@ When loading RGBA images (images with transparency), vLLM converts them to RGB f
```python ```python
from vllm import LLM from vllm import LLM
# Default white background (no configuration needed) # Default white background (no configuration needed)
llm = LLM(model="llava-hf/llava-1.5-7b-hf") llm = LLM(model="llava-hf/llava-1.5-7b-hf")
# Custom black background for dark theme # Custom black background for dark theme
llm = LLM( llm = LLM(
model="llava-hf/llava-1.5-7b-hf", model="llava-hf/llava-1.5-7b-hf",
media_io_kwargs={"image": {"rgba_background_color": [0, 0, 0]}} media_io_kwargs={"image": {"rgba_background_color": [0, 0, 0]}}
) )
# Custom brand color background (e.g., blue) # Custom brand color background (e.g., blue)
llm = LLM( llm = LLM(
model="llava-hf/llava-1.5-7b-hf", model="llava-hf/llava-1.5-7b-hf",
media_io_kwargs={"image": {"rgba_background_color": [0, 0, 255]}} media_io_kwargs={"image": {"rgba_background_color": [0, 0, 255]}}
) )
``` ```
...@@ -388,7 +388,7 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd ...@@ -388,7 +388,7 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd
## Online Serving ## Online Serving
Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions API](https://platform.openai.com/docs/api-reference/chat). Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions API](https://platform.openai.com/docs/api-reference/chat). Each media input also accepts an optional, user-provided UUID that uniquely identifies the media item; the UUID is used to cache the media results across requests.
!!! important !!! important
A chat template is **required** to use Chat Completions API. A chat template is **required** to use Chat Completions API.
...@@ -438,7 +438,13 @@ Then, you can use the OpenAI client as follows: ...@@ -438,7 +438,13 @@ Then, you can use the OpenAI client as follows:
# NOTE: The prompt formatting with the image token `<image>` is not needed # NOTE: The prompt formatting with the image token `<image>` is not needed
# since the prompt will be processed automatically by the API server. # since the prompt will be processed automatically by the API server.
{"type": "text", "text": "What’s in this image?"}, {"type": "text", "text": "What’s in this image?"},
{"type": "image_url", "image_url": {"url": image_url}}, {
"type": "image_url",
"image_url": {
"url": image_url
},
"uuid": image_url # Optional
},
], ],
}], }],
) )
...@@ -454,8 +460,20 @@ Then, you can use the OpenAI client as follows: ...@@ -454,8 +460,20 @@ Then, you can use the OpenAI client as follows:
"role": "user", "role": "user",
"content": [ "content": [
{"type": "text", "text": "What are the animals in these images?"}, {"type": "text", "text": "What are the animals in these images?"},
{"type": "image_url", "image_url": {"url": image_url_duck}}, {
{"type": "image_url", "image_url": {"url": image_url_lion}}, "type": "image_url",
"image_url": {
"url": image_url_duck
},
"uuid": image_url_duck # Optional
},
{
"type": "image_url",
"image_url": {
"url": image_url_lion
},
"uuid": image_url_lion # Optional
},
], ],
}], }],
) )
...@@ -522,6 +540,7 @@ Then, you can use the OpenAI client as follows: ...@@ -522,6 +540,7 @@ Then, you can use the OpenAI client as follows:
"video_url": { "video_url": {
"url": video_url "url": video_url
}, },
"uuid": video_url # Optional
}, },
], ],
}], }],
...@@ -613,6 +632,7 @@ Then, you can use the OpenAI client as follows: ...@@ -613,6 +632,7 @@ Then, you can use the OpenAI client as follows:
"data": audio_base64, "data": audio_base64,
"format": "wav" "format": "wav"
}, },
"uuid": audio_url # Optional
}, },
], ],
}], }],
...@@ -642,6 +662,7 @@ Alternatively, you can pass `audio_url`, which is the audio counterpart of `imag ...@@ -642,6 +662,7 @@ Alternatively, you can pass `audio_url`, which is the audio counterpart of `imag
"audio_url": { "audio_url": {
"url": audio_url "url": audio_url
}, },
"uuid": audio_url # Optional
}, },
], ],
}], }],
...@@ -695,7 +716,8 @@ The following example demonstrates how to pass image embeddings to the OpenAI se ...@@ -695,7 +716,8 @@ The following example demonstrates how to pass image embeddings to the OpenAI se
model = "llava-hf/llava-1.5-7b-hf" model = "llava-hf/llava-1.5-7b-hf"
embeds = { embeds = {
"type": "image_embeds", "type": "image_embeds",
"image_embeds": f"{base64_image_embedding}" "image_embeds": f"{base64_image_embedding}",
"uuid": image_url # Optional
} }
# Pass additional parameters (available to Qwen2-VL and MiniCPM-V) # Pass additional parameters (available to Qwen2-VL and MiniCPM-V)
...@@ -706,6 +728,7 @@ The following example demonstrates how to pass image embeddings to the OpenAI se ...@@ -706,6 +728,7 @@ The following example demonstrates how to pass image embeddings to the OpenAI se
"image_embeds": f"{base64_image_embedding}" , # Required "image_embeds": f"{base64_image_embedding}" , # Required
"image_grid_thw": f"{base64_image_grid_thw}" # Required by Qwen/Qwen2-VL-2B-Instruct "image_grid_thw": f"{base64_image_grid_thw}" # Required by Qwen/Qwen2-VL-2B-Instruct
}, },
"uuid": image_url # Optional
} }
model = "openbmb/MiniCPM-V-2_6" model = "openbmb/MiniCPM-V-2_6"
embeds = { embeds = {
...@@ -714,6 +737,7 @@ The following example demonstrates how to pass image embeddings to the OpenAI se ...@@ -714,6 +737,7 @@ The following example demonstrates how to pass image embeddings to the OpenAI se
"image_embeds": f"{base64_image_embedding}" , # Required "image_embeds": f"{base64_image_embedding}" , # Required
"image_sizes": f"{base64_image_sizes}" # Required by openbmb/MiniCPM-V-2_6 "image_sizes": f"{base64_image_sizes}" # Required by openbmb/MiniCPM-V-2_6
}, },
"uuid": image_url # Optional
} }
chat_completion = client.chat.completions.create( chat_completion = client.chat.completions.create(
messages=[ messages=[
......
...@@ -436,3 +436,132 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, ...@@ -436,3 +436,132 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
) )
message = chat_completion.choices[0].message message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0 assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize(
    "image_urls",
    [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))],
    indirect=True)
async def test_completions_with_image(
    client: openai.AsyncOpenAI,
    model_name: str,
    image_urls: list[str],
):
    """Send each image URL in its own chat request and expect a text reply.

    No ``uuid`` is attached to the image part.
    """
    # NOTE(review): the diff view dropped indentation; the assertions are
    # assumed to run once per image, inside the loop -- confirm.
    for image_url in image_urls:
        system_message = {
            "role": "system",
            "content": "You are a helpful assistant.",
        }
        text_part = {"type": "text", "text": "Describe this image."}
        image_part = {"type": "image_url", "image_url": {"url": image_url}}
        user_message = {"role": "user", "content": [text_part, image_part]}

        chat_completion = await client.chat.completions.create(
            messages=[system_message, user_message],
            model=model_name,
        )

        reply = chat_completion.choices[0].message.content
        assert reply is not None
        assert isinstance(reply, str)
        assert len(reply) > 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize(
    "image_urls",
    [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))],
    indirect=True)
async def test_completions_with_image_with_uuid(
    client: openai.AsyncOpenAI,
    model_name: str,
    image_urls: list[str],
):
    """Send each image URL with a top-level ``uuid`` and expect a text reply.

    The image URL itself is reused as the ``uuid`` value of the image part.
    """
    # NOTE(review): the diff view dropped indentation; the assertions are
    # assumed to run once per image, inside the loop -- confirm.
    for image_url in image_urls:
        system_message = {
            "role": "system",
            "content": "You are a helpful assistant.",
        }
        text_part = {"type": "text", "text": "Describe this image."}
        image_part = {
            "type": "image_url",
            "image_url": {"url": image_url},
            "uuid": image_url,
        }
        user_message = {"role": "user", "content": [text_part, image_part]}

        chat_completion = await client.chat.completions.create(
            messages=[system_message, user_message],
            model=model_name,
        )

        reply = chat_completion.choices[0].message.content
        assert reply is not None
        assert isinstance(reply, str)
        assert len(reply) > 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize(
    "image_urls",
    [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))],
    indirect=True)
async def test_completions_with_image_with_incorrect_uuid_format(
    client: openai.AsyncOpenAI,
    model_name: str,
    image_urls: list[str],
):
    """Send image parts carrying misnamed uuid keys and expect a text reply.

    Neither extra key is spelled ``uuid``: one sits inside the ``image_url``
    object and one sits next to it on the content part.
    """
    # NOTE(review): the diff view dropped indentation; the assertions are
    # assumed to run once per image, inside the loop -- confirm.
    for image_url in image_urls:
        system_message = {
            "role": "system",
            "content": "You are a helpful assistant.",
        }
        text_part = {"type": "text", "text": "Describe this image."}
        image_part = {
            "type": "image_url",
            "image_url": {
                "url": image_url,
                "incorrect_uuid_key": image_url,
            },
            "also_incorrect_uuid_key": image_url,
        }
        user_message = {"role": "user", "content": [text_part, image_part]}

        chat_completion = await client.chat.completions.create(
            messages=[system_message, user_message],
            model=model_name,
        )

        reply = chat_completion.choices[0].message.content
        assert reply is not None
        assert isinstance(reply, str)
        assert len(reply) > 0
This diff is collapsed.
...@@ -41,7 +41,8 @@ from typing_extensions import Required, TypeAlias, TypedDict ...@@ -41,7 +41,8 @@ from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalUUIDDict)
from vllm.multimodal.utils import MediaConnector from vllm.multimodal.utils import MediaConnector
# yapf: disable # yapf: disable
from vllm.transformers_utils.chat_templates import ( from vllm.transformers_utils.chat_templates import (
...@@ -72,6 +73,11 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False): ...@@ -72,6 +73,11 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
type: Required[Literal["audio_url"]] type: Required[Literal["audio_url"]]
"""The type of the content part.""" """The type of the content part."""
uuid: Optional[str]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
...@@ -83,6 +89,11 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): ...@@ -83,6 +89,11 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
""" """
type: Required[Literal["image_embeds"]] type: Required[Literal["image_embeds"]]
"""The type of the content part.""" """The type of the content part."""
uuid: Optional[str]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class VideoURL(TypedDict, total=False): class VideoURL(TypedDict, total=False):
...@@ -97,6 +108,11 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False): ...@@ -97,6 +108,11 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
type: Required[Literal["video_url"]] type: Required[Literal["video_url"]]
"""The type of the content part.""" """The type of the content part."""
uuid: Optional[str]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class PILImage(BaseModel): class PILImage(BaseModel):
...@@ -118,6 +134,11 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False): ...@@ -118,6 +134,11 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
""" """
image_pil: Required[PILImage] image_pil: Required[PILImage]
uuid: Optional[str]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
...@@ -131,6 +152,11 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): ...@@ -131,6 +152,11 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
""" """
image_url: Required[str] image_url: Required[str]
uuid: Optional[str]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
...@@ -155,6 +181,11 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): ...@@ -155,6 +181,11 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
""" """
video_url: Required[str] video_url: Required[str]
uuid: Optional[str]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class CustomThinkCompletionContentParam(TypedDict, total=False): class CustomThinkCompletionContentParam(TypedDict, total=False):
...@@ -567,6 +598,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -567,6 +598,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self._tokenizer = tokenizer self._tokenizer = tokenizer
self._items_by_modality = defaultdict[str, list[_T]](list) self._items_by_modality = defaultdict[str, list[_T]](list)
self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list)
@property @property
def model_config(self) -> ModelConfig: def model_config(self) -> ModelConfig:
...@@ -591,10 +623,15 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -591,10 +623,15 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def mm_processor(self): def mm_processor(self):
return self.mm_registry.create_processor(self.model_config) return self.mm_registry.create_processor(self.model_config)
def add(self, modality: ModalityStr, item: _T) -> Optional[str]: def add(
self, modality: ModalityStr, item: _T, uuid: Optional[str] = None
) -> Optional[str]:
""" """
Add a multi-modal item to the current prompt and returns the Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any. placeholder string to use, if any.
An optional uuid can be added which serves as a unique identifier of the
media.
""" """
input_modality = modality.replace("_embeds", "") input_modality = modality.replace("_embeds", "")
num_items = len(self._items_by_modality[modality]) + 1 num_items = len(self._items_by_modality[modality]) + 1
...@@ -602,9 +639,35 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -602,9 +639,35 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self.mm_processor.validate_num_items(input_modality, num_items) self.mm_processor.validate_num_items(input_modality, num_items)
self._items_by_modality[modality].append(item) self._items_by_modality[modality].append(item)
self._uuids_by_modality[modality].append(uuid)
return self.model_cls.get_placeholder_str(modality, num_items) return self.model_cls.get_placeholder_str(modality, num_items)
def all_mm_uuids(self) -> Optional[MultiModalUUIDDict]:
    """Collect the tracked per-item UUIDs, grouped by modality.

    Returns:
        ``None`` when no multi-modal item has been tracked; otherwise a
        dict that may hold the keys ``"image"``, ``"audio"`` and
        ``"video"``, each mapped to the list of UUIDs recorded for that
        modality. UUIDs recorded under ``"image_embeds"`` are reported
        under ``"image"``.

    Raises:
        ValueError: If both ``"image"`` and ``"image_embeds"`` entries were
            tracked, or if more than one ``"image_embeds"`` entry exists.
    """
    if not self._items_by_modality:
        return None

    # Work on a plain-dict snapshot so membership tests cannot create keys.
    tracked = dict(self._uuids_by_modality)
    has_embeds = "image_embeds" in tracked

    if "image" in tracked and has_embeds:
        raise ValueError(
            "Mixing raw image and embedding inputs is not allowed"
        )

    collected = {}
    if has_embeds:
        embeds_uuids = tracked["image_embeds"]
        if len(embeds_uuids) > 1:
            raise ValueError(
                "Only one message can have {'type': 'image_embeds'}"
            )
        collected["image"] = embeds_uuids

    # Remaining modalities are copied through under their own name.
    for modality in ("image", "audio", "video"):
        if modality in tracked:
            collected[modality] = tracked[modality]
    return collected
@abstractmethod @abstractmethod
def create_parser(self) -> "BaseMultiModalContentParser": def create_parser(self) -> "BaseMultiModalContentParser":
raise NotImplementedError raise NotImplementedError
...@@ -697,29 +760,35 @@ class BaseMultiModalContentParser(ABC): ...@@ -697,29 +760,35 @@ class BaseMultiModalContentParser(ABC):
return dict(self._placeholder_storage) return dict(self._placeholder_storage)
@abstractmethod @abstractmethod
def parse_image(self, image_url: str) -> None: def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def parse_image_embeds( def parse_image_embeds(
self, image_embeds: Union[str, dict[str, str]] self,
image_embeds: Union[str, dict[str, str]],
uuid: Optional[str] = None,
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def parse_image_pil(self, image_pil: Image.Image) -> None: def parse_image_pil(
self, image_pil: Image.Image, uuid: Optional[str] = None
) -> None:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def parse_audio(self, audio_url: str) -> None: def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def parse_input_audio(self, input_audio: InputAudio) -> None: def parse_input_audio(
self, input_audio: InputAudio, uuid: Optional[str] = None
) -> None:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def parse_video(self, video_url: str) -> None: def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
raise NotImplementedError raise NotImplementedError
...@@ -734,49 +803,55 @@ class MultiModalContentParser(BaseMultiModalContentParser): ...@@ -734,49 +803,55 @@ class MultiModalContentParser(BaseMultiModalContentParser):
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
) )
def parse_image(self, image_url: str) -> None: def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
image = self._connector.fetch_image(image_url) image = self._connector.fetch_image(image_url)
placeholder = self._tracker.add("image", image) placeholder = self._tracker.add("image", image, uuid)
self._add_placeholder("image", placeholder) self._add_placeholder("image", placeholder)
def parse_image_embeds( def parse_image_embeds(
self, image_embeds: Union[str, dict[str, str]] self,
image_embeds: Union[str, dict[str, str]],
uuid: Optional[str] = None,
) -> None: ) -> None:
if isinstance(image_embeds, dict): if isinstance(image_embeds, dict):
embeds = { embeds = {
k: self._connector.fetch_image_embedding(v) k: self._connector.fetch_image_embedding(v)
for k, v in image_embeds.items() for k, v in image_embeds.items()
} }
placeholder = self._tracker.add("image_embeds", embeds) placeholder = self._tracker.add("image_embeds", embeds, uuid)
if isinstance(image_embeds, str): if isinstance(image_embeds, str):
embedding = self._connector.fetch_image_embedding(image_embeds) embedding = self._connector.fetch_image_embedding(image_embeds)
placeholder = self._tracker.add("image_embeds", embedding) placeholder = self._tracker.add("image_embeds", embedding, uuid)
self._add_placeholder("image", placeholder) self._add_placeholder("image", placeholder)
def parse_image_pil(self, image_pil: Image.Image) -> None: def parse_image_pil(
placeholder = self._tracker.add("image", image_pil) self, image_pil: Image.Image, uuid: Optional[str] = None
) -> None:
placeholder = self._tracker.add("image", image_pil, uuid)
self._add_placeholder("image", placeholder) self._add_placeholder("image", placeholder)
def parse_audio(self, audio_url: str) -> None: def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
audio = self._connector.fetch_audio(audio_url) audio = self._connector.fetch_audio(audio_url)
placeholder = self._tracker.add("audio", audio) placeholder = self._tracker.add("audio", audio, uuid)
self._add_placeholder("audio", placeholder) self._add_placeholder("audio", placeholder)
def parse_input_audio(self, input_audio: InputAudio) -> None: def parse_input_audio(
self, input_audio: InputAudio, uuid: Optional[str] = None
) -> None:
audio_data = input_audio.get("data", "") audio_data = input_audio.get("data", "")
audio_format = input_audio.get("format", "") audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{audio_format};base64,{audio_data}" audio_url = f"data:audio/{audio_format};base64,{audio_data}"
return self.parse_audio(audio_url) return self.parse_audio(audio_url, uuid)
def parse_video(self, video_url: str) -> None: def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
video = self._connector.fetch_video(video_url=video_url) video = self._connector.fetch_video(video_url=video_url)
placeholder = self._tracker.add("video", video) placeholder = self._tracker.add("video", video, uuid)
self._add_placeholder("video", placeholder) self._add_placeholder("video", placeholder)
...@@ -790,14 +865,16 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): ...@@ -790,14 +865,16 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
) )
def parse_image(self, image_url: str) -> None: def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
image_coro = self._connector.fetch_image_async(image_url) image_coro = self._connector.fetch_image_async(image_url)
placeholder = self._tracker.add("image", image_coro) placeholder = self._tracker.add("image", image_coro, uuid)
self._add_placeholder("image", placeholder) self._add_placeholder("image", placeholder)
def parse_image_embeds( def parse_image_embeds(
self, image_embeds: Union[str, dict[str, str]] self,
image_embeds: Union[str, dict[str, str]],
uuid: Optional[str] = None,
) -> None: ) -> None:
future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future() future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()
...@@ -812,33 +889,37 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): ...@@ -812,33 +889,37 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
embedding = self._connector.fetch_image_embedding(image_embeds) embedding = self._connector.fetch_image_embedding(image_embeds)
future.set_result(embedding) future.set_result(embedding)
placeholder = self._tracker.add("image_embeds", future) placeholder = self._tracker.add("image_embeds", future, uuid)
self._add_placeholder("image", placeholder) self._add_placeholder("image", placeholder)
def parse_image_pil(self, image_pil: Image.Image) -> None: def parse_image_pil(
self, image_pil: Image.Image, uuid: Optional[str] = None
) -> None:
future: asyncio.Future[Image.Image] = asyncio.Future() future: asyncio.Future[Image.Image] = asyncio.Future()
future.set_result(image_pil) future.set_result(image_pil)
placeholder = self._tracker.add("image", future) placeholder = self._tracker.add("image", future, uuid)
self._add_placeholder("image", placeholder) self._add_placeholder("image", placeholder)
def parse_audio(self, audio_url: str) -> None: def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
audio_coro = self._connector.fetch_audio_async(audio_url) audio_coro = self._connector.fetch_audio_async(audio_url)
placeholder = self._tracker.add("audio", audio_coro) placeholder = self._tracker.add("audio", audio_coro, uuid)
self._add_placeholder("audio", placeholder) self._add_placeholder("audio", placeholder)
def parse_input_audio(self, input_audio: InputAudio) -> None: def parse_input_audio(
self, input_audio: InputAudio, uuid: Optional[str] = None
) -> None:
audio_data = input_audio.get("data", "") audio_data = input_audio.get("data", "")
audio_format = input_audio.get("format", "") audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{audio_format};base64,{audio_data}" audio_url = f"data:audio/{audio_format};base64,{audio_data}"
return self.parse_audio(audio_url) return self.parse_audio(audio_url, uuid)
def parse_video(self, video_url: str) -> None: def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
video = self._connector.fetch_video_async(video_url=video_url) video = self._connector.fetch_video_async(video_url=video_url)
placeholder = self._tracker.add("video", video) placeholder = self._tracker.add("video", video, uuid)
self._add_placeholder("video", placeholder) self._add_placeholder("video", placeholder)
...@@ -1177,30 +1258,36 @@ def _parse_chat_message_content_part( ...@@ -1177,30 +1258,36 @@ def _parse_chat_message_content_part(
else: else:
return str_content return str_content
# For media items, use the uuid supplied by the user if there is one.
# Otherwise, pass None as the placeholder for a missing uuid.
uuid = part.get("uuid", None)
if uuid is not None:
uuid = str(uuid)
modality = None modality = None
if part_type == "image_pil": if part_type == "image_pil":
image_content = cast(Image.Image, content) image_content = cast(Image.Image, content)
mm_parser.parse_image_pil(image_content) mm_parser.parse_image_pil(image_content, uuid)
modality = "image" modality = "image"
elif part_type in ("image_url", "input_image"): elif part_type in ("image_url", "input_image"):
str_content = cast(str, content) str_content = cast(str, content)
mm_parser.parse_image(str_content) mm_parser.parse_image(str_content, uuid)
modality = "image" modality = "image"
elif part_type == "image_embeds": elif part_type == "image_embeds":
content = cast(Union[str, dict[str, str]], content) content = cast(Union[str, dict[str, str]], content)
mm_parser.parse_image_embeds(content) mm_parser.parse_image_embeds(content, uuid)
modality = "image" modality = "image"
elif part_type == "audio_url": elif part_type == "audio_url":
str_content = cast(str, content) str_content = cast(str, content)
mm_parser.parse_audio(str_content) mm_parser.parse_audio(str_content, uuid)
modality = "audio" modality = "audio"
elif part_type == "input_audio": elif part_type == "input_audio":
dict_content = cast(InputAudio, content) dict_content = cast(InputAudio, content)
mm_parser.parse_input_audio(dict_content) mm_parser.parse_input_audio(dict_content, uuid)
modality = "audio" modality = "audio"
elif part_type == "video_url": elif part_type == "video_url":
str_content = cast(str, content) str_content = cast(str, content)
mm_parser.parse_video(str_content) mm_parser.parse_video(str_content, uuid)
modality = "video" modality = "video"
else: else:
raise NotImplementedError(f"Unknown part type: {part_type}") raise NotImplementedError(f"Unknown part type: {part_type}")
...@@ -1288,7 +1375,11 @@ def parse_chat_messages( ...@@ -1288,7 +1375,11 @@ def parse_chat_messages(
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
content_format: _ChatTemplateContentFormat, content_format: _ChatTemplateContentFormat,
) -> tuple[list[ConversationMessage], Optional[MultiModalDataDict]]: ) -> tuple[
list[ConversationMessage],
Optional[MultiModalDataDict],
Optional[MultiModalUUIDDict],
]:
conversation: list[ConversationMessage] = [] conversation: list[ConversationMessage] = []
mm_tracker = MultiModalItemTracker(model_config, tokenizer) mm_tracker = MultiModalItemTracker(model_config, tokenizer)
...@@ -1308,7 +1399,7 @@ def parse_chat_messages( ...@@ -1308,7 +1399,7 @@ def parse_chat_messages(
_postprocess_messages(conversation) _postprocess_messages(conversation)
return conversation, mm_tracker.all_mm_data() return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
def parse_chat_messages_futures( def parse_chat_messages_futures(
...@@ -1316,7 +1407,11 @@ def parse_chat_messages_futures( ...@@ -1316,7 +1407,11 @@ def parse_chat_messages_futures(
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
content_format: _ChatTemplateContentFormat, content_format: _ChatTemplateContentFormat,
) -> tuple[list[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]: ) -> tuple[
list[ConversationMessage],
Awaitable[Optional[MultiModalDataDict]],
Optional[MultiModalUUIDDict],
]:
conversation: list[ConversationMessage] = [] conversation: list[ConversationMessage] = []
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
...@@ -1336,7 +1431,7 @@ def parse_chat_messages_futures( ...@@ -1336,7 +1431,7 @@ def parse_chat_messages_futures(
_postprocess_messages(conversation) _postprocess_messages(conversation)
return conversation, mm_tracker.all_mm_data() return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
def apply_hf_chat_template( def apply_hf_chat_template(
......
...@@ -796,7 +796,7 @@ class LLM: ...@@ -796,7 +796,7 @@ class LLM:
# NOTE: _parse_chat_message_content_parts() currently doesn't # NOTE: _parse_chat_message_content_parts() currently doesn't
# handle mm_processor_kwargs, since there is no implementation in # handle mm_processor_kwargs, since there is no implementation in
# the chat message parsing for it. # the chat message parsing for it.
conversation, mm_data = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
msgs, msgs,
model_config, model_config,
tokenizer, tokenizer,
...@@ -826,6 +826,9 @@ class LLM: ...@@ -826,6 +826,9 @@ class LLM:
if mm_data is not None: if mm_data is not None:
prompt["multi_modal_data"] = mm_data prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
if mm_processor_kwargs is not None: if mm_processor_kwargs is not None:
prompt["mm_processor_kwargs"] = mm_processor_kwargs prompt["mm_processor_kwargs"] = mm_processor_kwargs
......
...@@ -929,7 +929,7 @@ class OpenAIServing: ...@@ -929,7 +929,7 @@ class OpenAIServing:
tokenizer, tokenizer,
model_config=model_config, model_config=model_config,
) )
conversation, mm_data_future = parse_chat_messages_futures( conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
messages, messages,
model_config, model_config,
tokenizer, tokenizer,
...@@ -1006,6 +1006,10 @@ class OpenAIServing: ...@@ -1006,6 +1006,10 @@ class OpenAIServing:
prompt_token_ids=prompt_inputs["prompt_token_ids"]) prompt_token_ids=prompt_inputs["prompt_token_ids"])
if mm_data is not None: if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data engine_prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
engine_prompt["multi_modal_uuids"] = mm_uuids
if request.mm_processor_kwargs is not None: if request.mm_processor_kwargs is not None:
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
......
...@@ -276,13 +276,23 @@ class InputPreprocessor: ...@@ -276,13 +276,23 @@ class InputPreprocessor:
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
return mm_processor.apply( mm_input = mm_processor.apply(
prompt, prompt,
mm_data, mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs, hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_hash_overrides=mm_hash_overrides,
) )
mm_hashes = mm_input["mm_hashes"]
# Validate that all mm items have a string as their hash
if not contains_only_strings(mm_hashes):
raise ValueError(
f"mm_hashes must contain only strings, got: {mm_hashes}. "
"This is likely due to an incorrect custom implementation of "
"MultiModalProcessor.apply method.")
return mm_input
async def _process_multimodal_async( async def _process_multimodal_async(
self, self,
...@@ -310,13 +320,23 @@ class InputPreprocessor: ...@@ -310,13 +320,23 @@ class InputPreprocessor:
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
return mm_processor.apply( mm_input = mm_processor.apply(
prompt, prompt,
mm_data, mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs, hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_hash_overrides=mm_hash_overrides,
) )
mm_hashes = mm_input["mm_hashes"]
# Validate that all mm items have a string as their hash
if not contains_only_strings(mm_hashes):
raise ValueError(
f"mm_hashes must contain only strings, got: {mm_hashes}. "
"This is likely due to an incorrect custom implementation of "
"MultiModalProcessor.apply method.")
return mm_input
def _process_embeds( def _process_embeds(
self, self,
...@@ -953,3 +973,15 @@ class InputPreprocessor: ...@@ -953,3 +973,15 @@ class InputPreprocessor:
def clear_cache(self) -> None: def clear_cache(self) -> None:
if self.mm_processor_cache is not None: if self.mm_processor_cache is not None:
self.mm_processor_cache.clear_cache() self.mm_processor_cache.clear_cache()
# Helper function to validate that a nested dictionary contains
# only strings or list of strings as the leaf values.
def contains_only_strings(obj: object) -> bool:
    """Return whether ``obj`` is built solely from strings.

    Accepted shapes:
      * a ``str``;
      * a ``list`` whose elements are all ``str`` (an empty list passes);
      * a ``dict`` whose values each satisfy these same rules, checked
        recursively (an empty dict passes; keys are not inspected).

    Anything else is rejected, including ``None``, numbers, tuples, and
    lists that contain non-string elements such as nested lists.
    """
    if isinstance(obj, str):
        return True
    if isinstance(obj, list):
        # Lists are leaves: elements must be strings, not nested containers.
        return all(isinstance(item, str) for item in obj)
    if isinstance(obj, dict):
        return all(contains_only_strings(value) for value in obj.values())
    return False
...@@ -174,9 +174,10 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor): ...@@ -174,9 +174,10 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
mm_items = self._to_mm_items(mm_data) mm_items = self._to_mm_items(mm_data)
tokenization_kwargs = tokenization_kwargs or {} tokenization_kwargs = tokenization_kwargs or {}
mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else mm_hashes = self._hash_mm_items(mm_items,
self._hash_mm_items(mm_items, hf_processor_mm_kwargs, hf_processor_mm_kwargs,
tokenization_kwargs)) tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides)
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
mm_processed_data = BatchFeature(image_data) mm_processed_data = BatchFeature(image_data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment