Unverified Commit c5eea3c8 authored by Yue Zhang's avatar Yue Zhang Committed by GitHub
Browse files

[Frontend] Support simpler image input format (#9478)

parent 85dc92fc
......@@ -388,3 +388,29 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
"text": "What about these two?"
}]
}], phi3v_model_config, phi3v_tokenizer)
def test_parse_chat_messages_multiple_images_uncommon_input(
    phi3v_model_config,
    phi3v_tokenizer,
    image_url,
):
    """Check parsing of a bare-string text part mixed with type-less images.

    The user message content holds one plain string followed by two dicts
    that carry only an ``image_url`` key (no ``type`` field). The parsed
    conversation must place one placeholder per image ahead of the text,
    and the multimodal data must contain two images.
    """
    user_message = {
        "role": "user",
        "content": [
            "What's in these images?",
            {"image_url": image_url},
            {"image_url": image_url},
        ],
    }
    expected_conversation = [{
        "role": "user",
        "content": "<|image_1|>\n<|image_2|>\nWhat's in these images?",
    }]

    conversation, mm_data = parse_chat_messages(
        [user_message],
        phi3v_model_config,
        phi3v_tokenizer,
    )

    assert conversation == expected_conversation
    _assert_mm_data_is_image_input(mm_data, 2)
......@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
from collections import defaultdict
from functools import lru_cache, partial
from pathlib import Path
from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal,
Mapping, Optional, Tuple, TypeVar, Union, cast)
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
Literal, Mapping, Optional, Tuple, TypeVar, Union, cast)
# yapf conflicts with isort for this block
# yapf: disable
......@@ -59,10 +59,35 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
"""The type of the content part."""
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain image_url.
This is supported by OpenAI API, although it is not documented.
Example:
{
"image_url": "https://example.com/image.jpg"
}
"""
image_url: Required[str]
class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain audio_url.
Example:
{
"audio_url": "https://example.com/audio.mp3"
}
"""
audio_url: Required[str]
# Every shape a single content part of a chat message may take: the typed
# OpenAI/custom part dicts, the type-less "simple" image/audio dicts that
# carry only a URL, and a bare string (plain text).
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
    ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentPartParam,
    CustomChatCompletionContentSimpleImageParam,
    CustomChatCompletionContentSimpleAudioParam, str]
class CustomChatCompletionMessageParam(TypedDict, total=False):
......@@ -387,6 +412,71 @@ _AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
# partial(cast, T)(x) is typing.cast(T, x): it returns x unchanged at runtime
# and only tells the type checker to treat x as a refusal content part.
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
# Set of model identifiers. Presumably the models for which multimodal content
# is kept in structured form instead of being flattened to text -- the lookup
# is not visible in this excerpt, TODO confirm against the call site.
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}
def _mm_text_content(part: ChatCompletionContentPartParam) -> str:
    """Return the ``text`` field of a text part ("" when absent)."""
    return _TextParser(part).get("text", "")


def _mm_image_url_content(part: ChatCompletionContentPartParam) -> str:
    """Return ``image_url.url`` of an image part ("" when absent)."""
    image_url = _ImageParser(part).get("image_url", {})
    return image_url.get("url", "")


def _mm_audio_url_content(part: ChatCompletionContentPartParam) -> str:
    """Return ``audio_url.url`` of an audio part ("" when absent)."""
    audio_url = _AudioParser(part).get("audio_url", {})
    return audio_url.get("url", "")


def _mm_refusal_content(part: ChatCompletionContentPartParam) -> str:
    """Return the ``refusal`` field of a refusal part ("" when absent)."""
    return _RefusalParser(part).get("refusal", "")


# Dispatch table keyed by a content part's 'type' value; each entry extracts
# that part's payload (text, URL, or refusal message) as a string.
MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
    "text": _mm_text_content,
    "image_url": _mm_image_url_content,
    "audio_url": _mm_audio_url_content,
    "refusal": _mm_refusal_content,
}
def _parse_chat_message_content_mm_part(
        part: ChatCompletionContentPartParam) -> Tuple[str, str]:
    """
    Parses a given multi modal content part based on its type.

    Args:
        part: A dict containing the content part, with a potential 'type'
            field.

    Returns:
        A tuple (part_type, content) where:
        - part_type: Type of the part (e.g., 'text', 'image_url').
        - content: Parsed content (e.g., text, image URL). For a string
          'type' that has no registered parser, a fixed placeholder string
          is returned instead.

    Raises:
        ValueError: If the 'type' field is missing and no direct URL is
            found, or if the 'type' field is present but not a string.
    """
    assert isinstance(
        part, dict)  # This is needed to avoid mypy errors: part.get() from str
    part_type = part.get("type", None)

    if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
        content = MM_PARSER_MAP[part_type](part)

        # Special case for 'image_url.detail'. The field lives inside the
        # nested 'image_url' object, not on the part itself, and an omitted
        # value means "auto" -- so only warn when a different value was
        # explicitly requested.
        if part_type == "image_url":
            image_url = part.get("image_url")
            detail = (image_url.get("detail", "auto") if isinstance(
                image_url, dict) else "auto")
            if detail != "auto":
                logger.warning("'image_url.detail' is currently not supported "
                               "and will be ignored.")

        return part_type, content

    # Handle missing 'type' but provided direct URL fields.
    if part_type is None:
        if part.get("image_url") is not None:
            image_params = cast(CustomChatCompletionContentSimpleImageParam,
                                part)
            return "image_url", image_params.get("image_url", "")
        if part.get("audio_url") is not None:
            audio_params = cast(CustomChatCompletionContentSimpleAudioParam,
                                part)
            return "audio_url", audio_params.get("audio_url", "")

        # Raise an error if no 'type' or direct URL is found.
        raise ValueError("Missing 'type' field in multimodal part.")

    if not isinstance(part_type, str):
        raise ValueError("Invalid 'type' field in multimodal part.")

    # A string type without a registered parser: hand the type back with a
    # non-empty placeholder so the caller can decide how to reject it.
    return part_type, "unknown part_type content"
# Recognised content-part types. When a part of one of these types parses to
# empty content, _parse_chat_message_content_parts logs a warning and skips it.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
"audio_url")
def _parse_chat_message_content_parts(
role: str,
......@@ -402,29 +492,28 @@ def _parse_chat_message_content_parts(
has_image = False
for part in parts:
part_type = part["type"]
if part_type == "text":
text = _TextParser(part)["text"]
if isinstance(part, str): # Handle plain text parts
text = _TextParser(part)
texts.append(text)
elif part_type == "image_url":
image_url = _ImageParser(part)["image_url"]
if image_url.get("detail", "auto") != "auto":
logger.warning(
"'image_url.detail' is currently not supported and "
"will be ignored.")
mm_parser.parse_image(image_url["url"])
has_image = True
elif part_type == "audio_url":
audio_url = _AudioParser(part)["audio_url"]
mm_parser.parse_audio(audio_url["url"])
elif part_type == "refusal":
text = _RefusalParser(part)["refusal"]
texts.append(text)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
else: # Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)
# if part_type is text/refusal/image_url/audio_url but
# content is empty, log a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
logger.warning("Skipping multimodal part "
"with empty / unparsable content.")
continue
if part_type in ("text", "refusal"):
texts.append(content)
elif part_type == "image_url":
mm_parser.parse_image(content)
has_image = True
elif part_type == "audio_url":
mm_parser.parse_audio(content)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
if keep_multimodal_content:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment