Unverified Commit c5eea3c8 authored by Yue Zhang's avatar Yue Zhang Committed by GitHub
Browse files

[Frontend] Support simpler image input format (#9478)

parent 85dc92fc
...@@ -388,3 +388,29 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages( ...@@ -388,3 +388,29 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
"text": "What about these two?" "text": "What about these two?"
}] }]
}], phi3v_model_config, phi3v_tokenizer) }], phi3v_model_config, phi3v_tokenizer)
def test_parse_chat_messages_multiple_images_uncommon_input(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [
"What's in these images?", {
"image_url": image_url
}, {
"image_url": image_url
}
]
}], phi3v_model_config, phi3v_tokenizer)
assert conversation == [{
"role":
"user",
"content":
"<|image_1|>\n<|image_2|>\nWhat's in these images?"
}]
_assert_mm_data_is_image_input(mm_data, 2)
...@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod ...@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from functools import lru_cache, partial from functools import lru_cache, partial
from pathlib import Path from pathlib import Path
from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal, from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
Mapping, Optional, Tuple, TypeVar, Union, cast) Literal, Mapping, Optional, Tuple, TypeVar, Union, cast)
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
...@@ -59,10 +59,35 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False): ...@@ -59,10 +59,35 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
"""The type of the content part.""" """The type of the content part."""
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain image_url.
This is supported by OpenAI API, although it is not documented.
Example:
{
"image_url": "https://example.com/image.jpg"
}
"""
image_url: Required[str]
class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain audio_url.
Example:
{
"audio_url": "https://example.com/audio.mp3"
}
"""
audio_url: Required[str]
ChatCompletionContentPartParam: TypeAlias = Union[ ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
ChatCompletionContentPartRefusalParam, ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentPartParam] CustomChatCompletionContentPartParam,
CustomChatCompletionContentSimpleImageParam,
CustomChatCompletionContentSimpleAudioParam, str]
class CustomChatCompletionMessageParam(TypedDict, total=False): class CustomChatCompletionMessageParam(TypedDict, total=False):
...@@ -387,6 +412,71 @@ _AudioParser = partial(cast, ChatCompletionContentPartAudioParam) ...@@ -387,6 +412,71 @@ _AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'} MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}
# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
"text":
lambda part: _TextParser(part).get("text", ""),
"image_url":
lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
"audio_url":
lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
"refusal":
lambda part: _RefusalParser(part).get("refusal", ""),
}
def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam) -> Tuple[str, str]:
"""
Parses a given multi modal content part based on its type.
Args:
part: A dict containing the content part, with a potential 'type' field.
Returns:
A tuple (part_type, content) where:
- part_type: Type of the part (e.g., 'text', 'image_url').
- content: Parsed content (e.g., text, image URL).
Raises:
ValueError: If the 'type' field is missing and no direct URL is found.
"""
assert isinstance(
part, dict) # This is needed to avoid mypy errors: part.get() from str
part_type = part.get("type", None)
if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
content = MM_PARSER_MAP[part_type](part)
# Special case for 'image_url.detail'
if part_type == "image_url" and part.get("detail") != "auto":
logger.warning("'image_url.detail' is currently not supported "
"and will be ignored.")
return part_type, content
# Handle missing 'type' but provided direct URL fields.
if part_type is None:
if part.get("image_url") is not None:
image_params = cast(CustomChatCompletionContentSimpleImageParam,
part)
return "image_url", image_params.get("image_url", "")
if part.get("audio_url") is not None:
audio_params = cast(CustomChatCompletionContentSimpleAudioParam,
part)
return "audio_url", audio_params.get("audio_url", "")
# Raise an error if no 'type' or direct URL is found.
raise ValueError("Missing 'type' field in multimodal part.")
if not isinstance(part_type, str):
raise ValueError("Invalid 'type' field in multimodal part.")
return part_type, "unknown part_type content"
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
"audio_url")
def _parse_chat_message_content_parts( def _parse_chat_message_content_parts(
role: str, role: str,
...@@ -402,29 +492,28 @@ def _parse_chat_message_content_parts( ...@@ -402,29 +492,28 @@ def _parse_chat_message_content_parts(
has_image = False has_image = False
for part in parts: for part in parts:
part_type = part["type"] if isinstance(part, str): # Handle plain text parts
if part_type == "text": text = _TextParser(part)
text = _TextParser(part)["text"]
texts.append(text) texts.append(text)
elif part_type == "image_url": else: # Handle structured dictionary parts
image_url = _ImageParser(part)["image_url"] part_type, content = _parse_chat_message_content_mm_part(part)
if image_url.get("detail", "auto") != "auto": # if part_type is text/refusal/image_url/audio_url but
logger.warning( # content is empty, logg a warning and skip
"'image_url.detail' is currently not supported and " if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
"will be ignored.") logger.warning("Skipping multimodal part "
"with empty / unparsable content.")
mm_parser.parse_image(image_url["url"]) continue
has_image = True
elif part_type == "audio_url": if part_type in ("text", "refusal"):
audio_url = _AudioParser(part)["audio_url"] texts.append(content)
elif part_type == "image_url":
mm_parser.parse_audio(audio_url["url"]) mm_parser.parse_image(content)
elif part_type == "refusal": has_image = True
text = _RefusalParser(part)["refusal"] elif part_type == "audio_url":
texts.append(text) mm_parser.parse_audio(content)
else: else:
raise NotImplementedError(f"Unknown part type: {part_type}") raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts) text_prompt = "\n".join(texts)
if keep_multimodal_content: if keep_multimodal_content:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment