Unverified commit 873f384a, authored by Yuhao Yao, committed by GitHub
Browse files

[feat] Add detail in image_data (#8596)

parent b01eeb80
...@@ -30,8 +30,10 @@ import re ...@@ -30,8 +30,10 @@ import re
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import Callable, Dict, List, Optional, Tuple, Union from typing import Callable, Dict, List, Optional, Tuple, Union
from typing_extensions import Literal
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
from sglang.srt.utils import read_system_prompt_from_file from sglang.srt.utils import ImageData, read_system_prompt_from_file
class SeparatorStyle(IntEnum): class SeparatorStyle(IntEnum):
...@@ -91,7 +93,7 @@ class Conversation: ...@@ -91,7 +93,7 @@ class Conversation:
video_token: str = "<video>" video_token: str = "<video>"
audio_token: str = "<audio>" audio_token: str = "<audio>"
image_data: Optional[List[str]] = None image_data: Optional[List[ImageData]] = None
video_data: Optional[List[str]] = None video_data: Optional[List[str]] = None
modalities: Optional[List[str]] = None modalities: Optional[List[str]] = None
stop_token_ids: Optional[int] = None stop_token_ids: Optional[int] = None
...@@ -381,9 +383,9 @@ class Conversation: ...@@ -381,9 +383,9 @@ class Conversation:
"""Append a new message.""" """Append a new message."""
self.messages.append([role, message]) self.messages.append([role, message])
def append_image(self, image: str): def append_image(self, image: str, detail: Literal["auto", "low", "high"]):
"""Append a new image.""" """Append a new image."""
self.image_data.append(image) self.image_data.append(ImageData(url=image, detail=detail))
def append_video(self, video: str): def append_video(self, video: str):
"""Append a new video.""" """Append a new video."""
...@@ -627,7 +629,9 @@ def generate_chat_conv( ...@@ -627,7 +629,9 @@ def generate_chat_conv(
real_content = image_token + real_content real_content = image_token + real_content
else: else:
real_content += image_token real_content += image_token
conv.append_image(content.image_url.url) conv.append_image(
content.image_url.url, content.image_url.detail
)
elif content.type == "video_url": elif content.type == "video_url":
real_content += video_token real_content += video_token
conv.append_video(content.video_url.url) conv.append_video(content.video_url.url)
......
...@@ -9,6 +9,8 @@ import logging ...@@ -9,6 +9,8 @@ import logging
import jinja2 import jinja2
import transformers.utils.chat_template_utils as hf_chat_utils import transformers.utils.chat_template_utils as hf_chat_utils
from sglang.srt.utils import ImageData
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ============================================================================ # ============================================================================
...@@ -140,7 +142,12 @@ def process_content_for_template_format( ...@@ -140,7 +142,12 @@ def process_content_for_template_format(
chunk_type = chunk.get("type") chunk_type = chunk.get("type")
if chunk_type == "image_url": if chunk_type == "image_url":
image_data.append(chunk["image_url"]["url"]) image_data.append(
ImageData(
url=chunk["image_url"]["url"],
detail=chunk["image_url"].get("detail", "auto"),
)
)
if chunk.get("modalities"): if chunk.get("modalities"):
modalities.append(chunk.get("modalities")) modalities.append(chunk.get("modalities"))
# Normalize to simple 'image' type for template compatibility # Normalize to simple 'image' type for template compatibility
......
...@@ -26,6 +26,7 @@ from sglang.srt.lora.lora_registry import LoRARef ...@@ -26,6 +26,7 @@ from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.multimodal.mm_utils import has_valid_data from sglang.srt.multimodal.mm_utils import has_valid_data
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.utils import ImageData
# Handle serialization of Image for pydantic # Handle serialization of Image for pydantic
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -45,7 +46,7 @@ class SessionParams: ...@@ -45,7 +46,7 @@ class SessionParams:
# Type definitions for multimodal input data # Type definitions for multimodal input data
# Individual data item types for each modality # Individual data item types for each modality
ImageDataInputItem = Union[Image, str, Dict] ImageDataInputItem = Union[Image, str, ImageData, Dict]
AudioDataInputItem = Union[str, Dict] AudioDataInputItem = Union[str, Dict]
VideoDataInputItem = Union[str, Dict] VideoDataInputItem = Union[str, Dict]
# Union type for any multimodal data item # Union type for any multimodal data item
......
...@@ -44,6 +44,7 @@ import traceback ...@@ -44,6 +44,7 @@ import traceback
import warnings import warnings
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
from importlib.util import find_spec from importlib.util import find_spec
...@@ -84,6 +85,7 @@ from torch.library import Library ...@@ -84,6 +85,7 @@ from torch.library import Library
from torch.profiler import ProfilerActivity, profile, record_function from torch.profiler import ProfilerActivity, profile, record_function
from torch.utils._contextlib import _DecoratorContextManager from torch.utils._contextlib import _DecoratorContextManager
from triton.runtime.cache import FileCacheManager from triton.runtime.cache import FileCacheManager
from typing_extensions import Literal
from sglang.srt.metrics.func_timer import enable_func_timer from sglang.srt.metrics.func_timer import enable_func_timer
...@@ -736,9 +738,18 @@ def load_audio( ...@@ -736,9 +738,18 @@ def load_audio(
return audio return audio
@dataclass
class ImageData:
    """An image reference paired with the requested detail level.

    ``load_image`` accepts an ``ImageData`` and unwraps it to ``url`` before
    loading, so ``url`` carries whatever string form ``load_image`` accepts.
    """

    # Image source string; handed to load_image as-is.
    url: str
    # Requested detail level; callers fall back to "auto" when none is given.
    # NOTE(review): presumably mirrors the OpenAI `image_url.detail` field -
    # confirm against the request protocol definition.
    detail: Optional[Literal["auto", "low", "high"]] = "auto"
def load_image( def load_image(
image_file: Union[Image.Image, str, bytes], image_file: Union[Image.Image, str, ImageData, bytes],
) -> tuple[Image.Image, tuple[int, int]]: ) -> tuple[Image.Image, tuple[int, int]]:
if isinstance(image_file, ImageData):
image_file = image_file.url
image = image_size = None image = image_size = None
if isinstance(image_file, Image.Image): if isinstance(image_file, Image.Image):
image = image_file image = image_file
...@@ -762,7 +773,7 @@ def load_image( ...@@ -762,7 +773,7 @@ def load_image(
elif isinstance(image_file, str): elif isinstance(image_file, str):
image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True))) image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
else: else:
raise ValueError(f"Invalid image: {image}") raise ValueError(f"Invalid image: {image_file}")
return image, image_size return image, image_size
......
...@@ -85,7 +85,7 @@ class TestTemplateContentFormatDetection(CustomTestCase): ...@@ -85,7 +85,7 @@ class TestTemplateContentFormatDetection(CustomTestCase):
# Check that image_data was extracted # Check that image_data was extracted
self.assertEqual(len(image_data), 1) self.assertEqual(len(image_data), 1)
self.assertEqual(image_data[0], "http://example.com/image.jpg") self.assertEqual(image_data[0].url, "http://example.com/image.jpg")
# Check that content was normalized # Check that content was normalized
expected_content = [ expected_content = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment