Unverified Commit f19a9204 authored by Yury Sulsky, committed by GitHub

Support precomputed multimodal features for Qwen-VL and Gemma3 models. (#6136)


Co-authored-by: Yury Sulsky <ysulsky@tesla.com>
parent c23a7072
{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {},
"source": [
"# Querying Qwen-VL"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"import nest_asyncio\n",
"\n",
"nest_asyncio.apply() # Run this first.\n",
"\n",
"model_path = \"Qwen/Qwen2.5-VL-3B-Instruct\"\n",
"chat_template = \"qwen2-vl\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2",
"metadata": {},
"outputs": [],
"source": [
"# Lets create a prompt.\n",
"\n",
"from io import BytesIO\n",
"import requests\n",
"from PIL import Image\n",
"\n",
"from sglang.srt.openai_api.protocol import ChatCompletionRequest\n",
"from sglang.srt.conversation import chat_templates\n",
"\n",
"image = Image.open(\n",
" BytesIO(\n",
" requests.get(\n",
" \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\"\n",
" ).content\n",
" )\n",
")\n",
"\n",
"conv = chat_templates[chat_template].copy()\n",
"conv.append_message(conv.roles[0], f\"What's shown here: {conv.image_token}?\")\n",
"conv.append_message(conv.roles[1], \"\")\n",
"conv.image_data = [image]\n",
"\n",
"print(conv.get_prompt())\n",
"image"
]
},
{
"cell_type": "markdown",
"id": "3",
"metadata": {},
"source": [
"## Query via the offline Engine API"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"metadata": {},
"outputs": [],
"source": [
"from sglang import Engine\n",
"\n",
"llm = Engine(\n",
" model_path=model_path, chat_template=chat_template, mem_fraction_static=0.8\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"out = llm.generate(prompt=conv.get_prompt(), image_data=[image])\n",
"print(out[\"text\"])"
]
},
{
"cell_type": "markdown",
"id": "6",
"metadata": {},
"source": [
"## Query via the offline Engine API, but send precomputed embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {},
"outputs": [],
"source": [
"# Compute the image embeddings using Huggingface.\n",
"\n",
"from transformers import AutoProcessor\n",
"from transformers import Qwen2_5_VLForConditionalGeneration\n",
"\n",
"processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\n",
"vision = (\n",
" Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path).eval().visual.cuda()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"metadata": {},
"outputs": [],
"source": [
"processed_prompt = processor(\n",
" images=[image], text=conv.get_prompt(), return_tensors=\"pt\"\n",
")\n",
"input_ids = processed_prompt[\"input_ids\"][0].detach().cpu().tolist()\n",
"precomputed_features = vision(\n",
" processed_prompt[\"pixel_values\"].cuda(), processed_prompt[\"image_grid_thw\"].cuda()\n",
")\n",
"\n",
"mm_item = dict(\n",
" modality=\"IMAGE\",\n",
" image_grid_thws=processed_prompt[\"image_grid_thw\"],\n",
" precomputed_features=precomputed_features,\n",
")\n",
"out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
"print(out[\"text\"])"
]
}
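,
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {},
"outputs": [],
"source": [
"# Optional cleanup (a suggested addition, not part of the original flow):\n",
"# shut down the engine to release GPU memory, using the same Engine.shutdown()\n",
"# call the SGLang tests use for teardown.\n",
"llm.shutdown()"
]
}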
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"custom_cell_magics": "kql",
"encoding": "# -*- coding: utf-8 -*-"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
...@@ -47,6 +47,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -47,6 +47,7 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput, EmbeddingReqInput,
GenerateReqInput, GenerateReqInput,
GetWeightsByNameReqInput, GetWeightsByNameReqInput,
ImageDataItem,
InitWeightsUpdateGroupReqInput, InitWeightsUpdateGroupReqInput,
ReleaseMemoryOccupationReqInput, ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqInput,
...@@ -150,9 +151,9 @@ class Engine(EngineBase): ...@@ -150,9 +151,9 @@ class Engine(EngineBase):
# See also python/sglang/srt/utils.py:load_image for more details. # See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[ image_data: Optional[
Union[ Union[
List[List[Union[Image, str]]], List[List[ImageDataItem]],
List[Union[Image, str]], List[ImageDataItem],
Union[Image, str], ImageDataItem,
] ]
] = None, ] = None,
return_logprob: Optional[Union[List[bool], bool]] = False, return_logprob: Optional[Union[List[bool], bool]] = False,
...@@ -221,9 +222,9 @@ class Engine(EngineBase): ...@@ -221,9 +222,9 @@ class Engine(EngineBase):
# See also python/sglang/srt/utils.py:load_image for more details. # See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[ image_data: Optional[
Union[ Union[
List[List[Union[Image, str]]], List[List[ImageDataItem]],
List[Union[Image, str]], List[ImageDataItem],
Union[Image, str], ImageDataItem,
] ]
] = None, ] = None,
return_logprob: Optional[Union[List[bool], bool]] = False, return_logprob: Optional[Union[List[bool], bool]] = False,
......
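The widened `image_data` typing above means `Engine.generate` now also accepts a plain dict per image, which is how precomputed features reach the runtime. A minimal sketch of the accepted shapes; `llm`, `prompt`, `input_ids`, `pil_image`, `grid_thw`, and `features` are assumptions that must already exist:

```python
# Sketch only: all variables below are assumed to exist. The dict keys mirror
# MultimodalDataItem fields (see the io_struct / schedule_batch changes below).

# 1. A PIL image or a URL string per placeholder (unchanged behaviour).
out = llm.generate(prompt=prompt, image_data=[pil_image])
out = llm.generate(prompt=prompt, image_data=["https://example.com/cat.png"])

# 2. A dict carrying precomputed vision features (new in this change).
out = llm.generate(
    input_ids=input_ids,  # token ids with the image tokens already expanded
    image_data=[
        {
            "modality": "IMAGE",
            "image_grid_thws": grid_thw,       # required for Qwen-VL; Gemma3 omits it
            "precomputed_features": features,  # output of the model's vision encoder
        }
    ],
)
print(out["text"])
```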
...@@ -40,6 +40,10 @@ class SessionParams: ...@@ -40,6 +40,10 @@ class SessionParams:
replace: Optional[bool] = None replace: Optional[bool] = None
AudioDataItem = Union[str, Dict]
ImageDataItem = Union[Image, str, Dict]
@dataclass @dataclass
class GenerateReqInput: class GenerateReqInput:
# The input prompt. It can be a single prompt or a batch of prompts. # The input prompt. It can be a single prompt or a batch of prompts.
...@@ -55,10 +59,10 @@ class GenerateReqInput: ...@@ -55,10 +59,10 @@ class GenerateReqInput:
# - List of lists of images (multiple images per request) # - List of lists of images (multiple images per request)
# See also python/sglang/srt/utils.py:load_image for more details. # See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[ image_data: Optional[
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]] Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem]
] = None ] = None
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string. # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
audio_data: Optional[Union[List[str], str]] = None audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
# The sampling_params. See descriptions below. # The sampling_params. See descriptions below.
sampling_params: Optional[Union[List[Dict], Dict]] = None sampling_params: Optional[Union[List[Dict], Dict]] = None
# The request id. # The request id.
......
...@@ -368,13 +368,13 @@ def general_mm_embed_routine( ...@@ -368,13 +368,13 @@ def general_mm_embed_routine(
input_ids: torch.Tensor, input_ids: torch.Tensor,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
language_model: nn.Module, language_model: nn.Module,
image_data_embedding_func: Callable[ image_data_embedding_func: Optional[
[List[MultimodalDataItem]], torch.Tensor Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None, ] = None,
audio_data_embedding_func: Callable[ audio_data_embedding_func: Optional[
[List[MultimodalDataItem]], torch.Tensor Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None, ] = None,
placeholder_tokens: dict[Modality, List[int]] = None, placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -389,7 +389,6 @@ def general_mm_embed_routine( ...@@ -389,7 +389,6 @@ def general_mm_embed_routine(
forwarded hidden states forwarded hidden states
""" """
assert hasattr(language_model, "get_input_embeddings") assert hasattr(language_model, "get_input_embeddings")
embed_tokens = language_model.get_input_embeddings() embed_tokens = language_model.get_input_embeddings()
if ( if (
......
...@@ -3,16 +3,16 @@ import concurrent.futures ...@@ -3,16 +3,16 @@ import concurrent.futures
import dataclasses import dataclasses
import multiprocessing as mp import multiprocessing as mp
import os import os
import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional from typing import List, Optional, Union
import numpy as np import numpy as np
import PIL
import torch import torch
from PIL import Image from PIL import Image
from transformers import BaseImageProcessorFast from transformers import BaseImageProcessorFast
from sglang.srt.managers.schedule_batch import Modality from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.utils import encode_video, load_audio, load_image from sglang.srt.utils import encode_video, load_audio, load_image
...@@ -22,13 +22,13 @@ class BaseMultiModalProcessorOutput: ...@@ -22,13 +22,13 @@ class BaseMultiModalProcessorOutput:
input_text: str input_text: str
# frames loaded from image and video, in given order # frames loaded from image and video, in given order
images: Optional[list[PIL.Image]] = None images: Optional[list[Union[Image.Image, MultimodalDataItem]]] = None
# audios # audios
audios: Optional[list[np.ndarray]] = None audios: Optional[list[Union[np.ndarray, MultimodalDataItem]]] = None
def normalize(self): def normalize(self):
for field_name in ["image_sizes", "images", "audios"]: for field_name in ["images", "audios"]:
field = getattr(self, field_name, None) field = getattr(self, field_name, None)
if field is not None and isinstance(field, list) and len(field) == 0: if field is not None and isinstance(field, list) and len(field) == 0:
setattr(self, field_name, None) setattr(self, field_name, None)
...@@ -40,12 +40,32 @@ class MultimodalSpecialTokens: ...@@ -40,12 +40,32 @@ class MultimodalSpecialTokens:
video_token: Optional[str] = None video_token: Optional[str] = None
audio_token: Optional[str] = None audio_token: Optional[str] = None
def collect(self) -> list[str]: image_token_regex: Optional[re.Pattern] = None
return [ video_token_regex: Optional[re.Pattern] = None
token audio_token_regex: Optional[re.Pattern] = None
for token in [self.image_token, self.video_token, self.audio_token]
if token def __post_init__(self):
if self.image_token_regex is None and self.image_token is not None:
self.image_token_regex = re.compile(re.escape(self.image_token))
if self.video_token_regex is None and self.video_token is not None:
self.video_token_regex = re.compile(re.escape(self.video_token))
if self.audio_token_regex is None and self.audio_token is not None:
self.audio_token_regex = re.compile(re.escape(self.audio_token))
def collect(self) -> re.Pattern:
tokens = [
self.image_token_regex,
self.video_token_regex,
self.audio_token_regex,
] ]
patterns = []
flags = 0
for t in tokens:
if t is not None:
patterns.append(t.pattern)
flags |= t.flags
combined = "(" + "|".join(f"(?:{p})" for p in patterns) + ")"
return re.compile(combined, flags)
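`collect()` now returns a single compiled pattern that ORs the per-modality regexes inside one capturing group, so `re.split` keeps the matched placeholders and they can be paired with the corresponding `image_data`/`audio_data` entries. A small standalone illustration (the `<audio>` token here is hypothetical):

```python
import re

# Per-token regexes, built the way __post_init__ does when only plain tokens are given.
image_token_regex = re.compile(re.escape("<start_of_image>"))
audio_token_regex = re.compile(re.escape("<audio>"))  # hypothetical audio token

# Combine them the way collect() does: one capturing group of non-capturing alternatives.
patterns = [image_token_regex.pattern, audio_token_regex.pattern]
combined = re.compile("(" + "|".join(f"(?:{p})" for p in patterns) + ")")

parts = re.split(combined, "Describe <start_of_image> then transcribe <audio>.")
print(parts)
# ['Describe ', '<start_of_image>', ' then transcribe ', '<audio>', '.']
```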
class BaseMultimodalProcessor(ABC): class BaseMultimodalProcessor(ABC):
...@@ -136,6 +156,10 @@ class BaseMultimodalProcessor(ABC): ...@@ -136,6 +156,10 @@ class BaseMultimodalProcessor(ABC):
data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
): ):
"""Static method that can be pickled for multiprocessing""" """Static method that can be pickled for multiprocessing"""
if isinstance(data, dict):
return MultimodalDataItem.from_dict(data)
if isinstance(data, MultimodalDataItem):
return data
try: try:
if is_audio: if is_audio:
return load_audio(data) return load_audio(data)
...@@ -175,7 +199,10 @@ class BaseMultimodalProcessor(ABC): ...@@ -175,7 +199,10 @@ class BaseMultimodalProcessor(ABC):
image_index, audio_index = 0, 0 image_index, audio_index = 0, 0
for text_part in text_parts: for text_part in text_parts:
if text_part == multimodal_tokens.image_token: if (
multimodal_tokens.image_token_regex
and multimodal_tokens.image_token_regex.match(text_part)
):
data = image_data[image_index] data = image_data[image_index]
is_video = isinstance(data, str) and data.startswith("video:") is_video = isinstance(data, str) and data.startswith("video:")
estimated_frames = estimated_frames_list[image_index] estimated_frames = estimated_frames_list[image_index]
...@@ -192,7 +219,10 @@ class BaseMultimodalProcessor(ABC): ...@@ -192,7 +219,10 @@ class BaseMultimodalProcessor(ABC):
) )
task_info.append((Modality.IMAGE, data, frame_count_limit)) task_info.append((Modality.IMAGE, data, frame_count_limit))
image_index += 1 image_index += 1
elif text_part == multimodal_tokens.audio_token: elif (
multimodal_tokens.audio_token_regex
and multimodal_tokens.audio_token_regex.match(text_part)
):
data = audio_data[audio_index] data = audio_data[audio_index]
futures.append( futures.append(
self.io_executor.submit( self.io_executor.submit(
...@@ -228,17 +258,22 @@ class BaseMultimodalProcessor(ABC): ...@@ -228,17 +258,22 @@ class BaseMultimodalProcessor(ABC):
discard_alpha_channel: if True, discards the alpha channel in the returned images discard_alpha_channel: if True, discards the alpha channel in the returned images
""" """
if not return_text:
raise NotImplementedError()
if image_data is None: if image_data is None:
image_data = [] image_data = []
if isinstance(multimodal_tokens.image_token, int): if isinstance(multimodal_tokens.image_token, int):
multimodal_tokens.image_token = ( multimodal_tokens.image_token = re.compile(
re.escape(
self._processor.tokenizer.convert_ids_to_tokens( self._processor.tokenizer.convert_ids_to_tokens(
multimodal_tokens.image_token multimodal_tokens.image_token
) )
) )
)
else: else:
multimodal_tokens.image_token = multimodal_tokens.image_token multimodal_tokens.image_token = multimodal_tokens.image_token
multimodal_tokens_pattern = multimodal_tokens.collect()
if isinstance(prompt, list) and return_text: if isinstance(prompt, list) and return_text:
assert len(prompt) and isinstance(prompt[0], int) assert len(prompt) and isinstance(prompt[0], int)
...@@ -247,16 +282,8 @@ class BaseMultimodalProcessor(ABC): ...@@ -247,16 +282,8 @@ class BaseMultimodalProcessor(ABC):
prompt = prompt prompt = prompt
assert isinstance(prompt, str) assert isinstance(prompt, str)
if return_text:
import re
pattern = (
"("
+ "|".join(re.escape(sep) for sep in multimodal_tokens.collect())
+ ")"
)
# split text into list of normal text and special tokens # split text into list of normal text and special tokens
text_parts = re.split(pattern, prompt) text_parts = re.split(multimodal_tokens_pattern, prompt)
futures, task_info = self.submit_data_loading_tasks( futures, task_info = self.submit_data_loading_tasks(
text_parts=text_parts, text_parts=text_parts,
...@@ -266,26 +293,40 @@ class BaseMultimodalProcessor(ABC): ...@@ -266,26 +293,40 @@ class BaseMultimodalProcessor(ABC):
discard_alpha_channel=discard_alpha_channel, discard_alpha_channel=discard_alpha_channel,
) )
# Process results # Process results
image_sizes, images, audios = [], [], [] images, audios = [], []
new_text = "" new_text = ""
task_ptr = 0 task_ptr = 0
for text_part in text_parts: for text_part in text_parts:
if text_part in multimodal_tokens.collect(): if multimodal_tokens_pattern.match(text_part):
task_type, data, frame_limit = task_info[task_ptr] task_type, data, frame_limit = task_info[task_ptr]
result = futures[task_ptr].result() result = futures[task_ptr].result()
task_ptr += 1 task_ptr += 1
if task_type == Modality.IMAGE: if task_type == Modality.IMAGE:
# If data is already processed it will be a
# dictionary. In this case we want to keep the
# expanded tokens in text_part. Otherwise, we will
# call the processor code, so keep only a single image
# token.
mm_tokens = (
text_part
if isinstance(data, dict)
else multimodal_tokens.image_token
)
frames = [result] if not isinstance(result, list) else result frames = [result] if not isinstance(result, list) else result
if frames: if frames:
image_sizes += frames[0].size * len(frames)
images += frames images += frames
new_text += multimodal_tokens.image_token * len(frames) new_text += mm_tokens * len(frames)
elif task_type == Modality.AUDIO: elif task_type == Modality.AUDIO:
# audio # audio
mm_tokens = (
text_part
if isinstance(data, dict)
else multimodal_tokens.audio_token
)
audios.append(result) audios.append(result)
new_text += multimodal_tokens.audio_token new_text += mm_tokens
# TODO: handle video # TODO: handle video
else: else:
new_text += text_part new_text += text_part
...@@ -297,3 +338,16 @@ class BaseMultimodalProcessor(ABC): ...@@ -297,3 +338,16 @@ class BaseMultimodalProcessor(ABC):
) )
out.normalize() out.normalize()
return out return out
def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]):
"""Returns true if all images are preprocessed, false if all are not, and error otherwise."""
if not mm_inputs:
return True
ret = any(isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs)
if ret and not all(
isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs
):
raise ValueError(
"Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
)
return ret
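The new `mm_inputs_are_preprocessed` helper enforces an all-or-nothing rule on a request's multimodal inputs. A short sketch of the contract, where `proc` stands in for any `BaseMultimodalProcessor` subclass instance (assumed to exist):

```python
from PIL import Image

from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

raw_images = [Image.new("RGB", (8, 8)), Image.new("RGB", (8, 8))]
preprocessed = [MultimodalDataItem(modality=Modality.IMAGE)]

proc.mm_inputs_are_preprocessed([])                          # True: empty counts as preprocessed
proc.mm_inputs_are_preprocessed(raw_images)                  # False: no MultimodalDataItem present
proc.mm_inputs_are_preprocessed(preprocessed)                # True: every item is preprocessed
proc.mm_inputs_are_preprocessed(raw_images + preprocessed)   # raises ValueError: mixed inputs
```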
from typing import List, Union import re
from typing import Dict, List, Union
from sglang.srt.managers.multimodal_processor import ( from sglang.srt.managers.multimodal_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor, BaseMultimodalProcessor as SGLangBaseProcessor,
...@@ -18,13 +19,18 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor): ...@@ -18,13 +19,18 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
def __init__(self, hf_config, server_args, _processor): def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor) super().__init__(hf_config, server_args, _processor)
# The single, pre-expanded image token.
self.IMAGE_TOKEN = "<start_of_image>" self.IMAGE_TOKEN = "<start_of_image>"
# The regex that matches expanded image tokens.
self.IMAGE_TOKEN_REGEX = re.compile(
r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
)
self.IM_START_TOKEN_ID = hf_config.boi_token_index self.IM_START_TOKEN_ID = hf_config.boi_token_index
self.IM_END_TOKEN_ID = hf_config.eoi_token_index self.IM_END_TOKEN_ID = hf_config.eoi_token_index
async def process_mm_data_async( async def process_mm_data_async(
self, self,
image_data: List[Union[str, bytes]], image_data: List[Union[str, bytes, Dict]],
input_text, input_text,
request_obj, request_obj,
max_req_input_len, max_req_input_len,
...@@ -37,22 +43,35 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor): ...@@ -37,22 +43,35 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
image_data = [image_data] image_data = [image_data]
image_token = self.IMAGE_TOKEN image_token = self.IMAGE_TOKEN
image_token_regex = self.IMAGE_TOKEN_REGEX
base_output = self.load_mm_data( base_output = self.load_mm_data(
prompt=input_text, prompt=input_text,
image_data=image_data, image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token), multimodal_tokens=MultimodalSpecialTokens(
image_token=image_token, image_token_regex=image_token_regex
),
max_req_input_len=max_req_input_len, max_req_input_len=max_req_input_len,
discard_alpha_channel=True, discard_alpha_channel=True,
) )
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
ret = self.process_mm_data( ret = self.process_mm_data(
input_text=base_output.input_text, images=base_output.images input_text=base_output.input_text,
images=None if images_are_preprocessed else base_output.images,
) )
items = [] items = []
for i, image in enumerate(base_output.images): for i, image in enumerate(base_output.images):
if images_are_preprocessed:
pixel_values = image.pixel_values
precomputed_features = image.precomputed_features
else:
pixel_values = ret["pixel_values"][i]
precomputed_features = None
item = MultimodalDataItem( item = MultimodalDataItem(
pixel_values=ret["pixel_values"][i], pixel_values=pixel_values,
precomputed_features=precomputed_features,
modality=Modality.IMAGE, modality=Modality.IMAGE,
) )
items += [item] items += [item]
......
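For Gemma3 the preprocessed path only needs `precomputed_features` (there is no grid tensor); the new end-to-end test computes them from the HF vision tower plus multimodal projector. A hedged sketch of the same flow, assuming `engine` is an SGLang `Engine` serving a Gemma3 checkpoint at `model_path`, and `prompt`/`image` are the chat-template prompt and PIL image:

```python
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

# Assumed: `engine`, `model_path`, `prompt`, and `image` already exist.
processor = AutoProcessor.from_pretrained(model_path, use_fast=True)
hf_model = Gemma3ForConditionalGeneration.from_pretrained(model_path).eval()
vision_tower = hf_model.vision_tower.cuda()
mm_projector = hf_model.multi_modal_projector.cuda()

processed = processor(images=[image], text=prompt, return_tensors="pt")
with torch.inference_mode():
    hidden = vision_tower(pixel_values=processed["pixel_values"].cuda()).last_hidden_state
    precomputed_features = mm_projector(hidden)

out = engine.generate(
    input_ids=processed["input_ids"][0].tolist(),
    image_data=[{"modality": "IMAGE", "precomputed_features": precomputed_features}],
)
print(out["text"])
```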
import asyncio import asyncio
import math import math
from typing import List, Union import re
from typing import Dict, List, Union
import torch import torch
from PIL import Image from PIL import Image
...@@ -23,7 +24,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): ...@@ -23,7 +24,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
def __init__(self, hf_config, server_args, _processor): def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor) super().__init__(hf_config, server_args, _processor)
# The single, pre-expanded image token.
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>" self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
# The regex that matches expanded image tokens.
self.IMAGE_TOKEN_REGEX = re.compile(
r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
)
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
self.image_token_id = hf_config.image_token_id self.image_token_id = hf_config.image_token_id
...@@ -38,7 +44,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): ...@@ -38,7 +44,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
async def process_mm_data_async( async def process_mm_data_async(
self, self,
image_data: List[Union[str, bytes]], image_data: List[Union[str, bytes, Dict]],
input_text, input_text,
request_obj, request_obj,
max_req_input_len, max_req_input_len,
...@@ -48,11 +54,13 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): ...@@ -48,11 +54,13 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
if isinstance(image_data, str): if isinstance(image_data, str):
image_data = [image_data] image_data = [image_data]
image_token = self.IMAGE_TOKEN
base_output = self.load_mm_data( base_output = self.load_mm_data(
prompt=input_text, prompt=input_text,
image_data=image_data, image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token), multimodal_tokens=MultimodalSpecialTokens(
image_token=self.IMAGE_TOKEN,
image_token_regex=self.IMAGE_TOKEN_REGEX,
),
max_req_input_len=max_req_input_len, max_req_input_len=max_req_input_len,
) )
...@@ -117,26 +125,56 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): ...@@ -117,26 +125,56 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
async def resize_image_async(image): async def resize_image_async(image):
return resize_image(image) return resize_image(image)
if base_output.images: images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
if base_output.images and not images_are_preprocessed:
resize_tasks = [resize_image_async(image) for image in base_output.images] resize_tasks = [resize_image_async(image) for image in base_output.images]
base_output.images = await asyncio.gather(*resize_tasks) base_output.images = await asyncio.gather(*resize_tasks)
ret = self.process_mm_data( ret = self.process_mm_data(
input_text=base_output.input_text, input_text=base_output.input_text,
images=base_output.images, images=None if images_are_preprocessed else base_output.images,
) )
input_ids = ret["input_ids"].flatten().tolist()
image_grid_thw = None
video_grid_thw = None # TODO
items = [] items = []
input_ids = ret["input_ids"].flatten().tolist() if base_output.images:
if "pixel_values" in ret: if images_are_preprocessed:
image_grid_thw = torch.concat(
[
torch.as_tensor(item.image_grid_thws)
for item in base_output.images
]
)
all_pixel_values = [
item.pixel_values
for item in base_output.images
if item.pixel_values is not None
]
all_precomputed_features = [
item.precomputed_features
for item in base_output.images
if item.precomputed_features is not None
]
pixel_values = (
torch.concat(all_pixel_values) if all_pixel_values else None
)
precomputed_features = (
torch.concat(all_precomputed_features)
if all_precomputed_features
else None
)
else:
image_grid_thw = ret["image_grid_thw"]
pixel_values = ret["pixel_values"]
precomputed_features = None
items += [ items += [
MultimodalDataItem( MultimodalDataItem(
pixel_values=ret["pixel_values"], pixel_values=pixel_values,
image_grid_thws=torch.concat([ret["image_grid_thw"]]), image_grid_thws=image_grid_thw,
# TODO video_grid_thws=video_grid_thw,
video_grid_thws=None, precomputed_features=precomputed_features,
second_per_grid_ts=ret.get("second_per_grid_ts", None),
modality=Modality.IMAGE, modality=Modality.IMAGE,
) )
] ]
...@@ -151,8 +189,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): ...@@ -151,8 +189,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
self.hf_config.vision_config, "tokens_per_second", None self.hf_config.vision_config, "tokens_per_second", None
), ),
input_ids=torch.tensor(input_ids).unsqueeze(0), input_ids=torch.tensor(input_ids).unsqueeze(0),
image_grid_thw=ret.get("image_grid_thw", None), image_grid_thw=image_grid_thw,
video_grid_thw=ret.get("video_grid_thw", None), video_grid_thw=video_grid_thw,
second_per_grid_ts=ret.get("second_per_grid_ts", None), second_per_grid_ts=ret.get("second_per_grid_ts", None),
) )
mrope_positions = mrope_positions.squeeze(1) mrope_positions = mrope_positions.squeeze(1)
......
...@@ -177,10 +177,10 @@ class MultimodalDataItem: ...@@ -177,10 +177,10 @@ class MultimodalDataItem:
image_offsets: Optional[list] = None image_offsets: Optional[list] = None
# the real data, pixel_values or audio_features # the real data, pixel_values or audio_features
# data: Union[List[torch.Tensor], List[np.array]] # data: Union[List[torch.Tensor], List[np.ndarray]]
pixel_values: Union[torch.Tensor, np.array] = None pixel_values: Union[torch.Tensor, np.ndarray] = None
image_grid_thws: Union[torch.Tensor, np.array] = None image_grid_thws: Union[torch.Tensor, np.ndarray] = None
video_grid_thws: Union[torch.Tensor, np.array] = None video_grid_thws: Union[torch.Tensor, np.ndarray] = None
image_emb_mask: Optional[torch.Tensor] = None image_emb_mask: Optional[torch.Tensor] = None
image_spatial_crop: Optional[torch.Tensor] = None image_spatial_crop: Optional[torch.Tensor] = None
...@@ -189,9 +189,11 @@ class MultimodalDataItem: ...@@ -189,9 +189,11 @@ class MultimodalDataItem:
# [num_images, (n, w, h)] # [num_images, (n, w, h)]
tgt_size: Tuple[int, int] = None tgt_size: Tuple[int, int] = None
audio_features: Union[torch.Tensor, np.array] = None audio_features: Union[torch.Tensor, np.ndarray] = None
audio_feature_lens: Optional[List[torch.Tensor]] = None audio_feature_lens: Optional[List[torch.Tensor]] = None
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
@staticmethod @staticmethod
def is_empty_list(l): def is_empty_list(l):
if l is None: if l is None:
...@@ -249,7 +251,9 @@ class MultimodalDataItem: ...@@ -249,7 +251,9 @@ class MultimodalDataItem:
return tensor_hash([f]) return tensor_hash([f])
return data_hash(f) return data_hash(f)
if self.is_audio(): if self.precomputed_features is not None:
self.hash = hash_feature(self.precomputed_features)
elif self.is_audio():
self.hash = hash_feature(self.audio_features) self.hash = hash_feature(self.audio_features)
else: else:
self.hash = hash_feature(self.pixel_values) self.hash = hash_feature(self.pixel_values)
...@@ -258,19 +262,24 @@ class MultimodalDataItem: ...@@ -258,19 +262,24 @@ class MultimodalDataItem:
self.pad_value = self.hash % (1 << 30) self.pad_value = self.hash % (1 << 30)
def is_audio(self): def is_audio(self):
return ( return (self.modality == Modality.AUDIO) and (
self.modality == Modality.AUDIO self.precomputed_features is not None
) and not MultimodalDataItem.is_empty_list(self.audio_features) or not MultimodalDataItem.is_empty_list(self.audio_features)
)
def is_image(self): def is_image(self):
return ( return (
self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
) and not MultimodalDataItem.is_empty_list(self.pixel_values) ) and (
self.precomputed_features is not None
or not MultimodalDataItem.is_empty_list(self.pixel_values)
)
def is_video(self): def is_video(self):
return ( return (self.modality == Modality.VIDEO) and (
self.modality == Modality.VIDEO self.precomputed_features is not None
) and not MultimodalDataItem.is_empty_list(self.pixel_values) or not MultimodalDataItem.is_empty_list(self.pixel_values)
)
def is_valid(self) -> bool: def is_valid(self) -> bool:
return self.is_image() or self.is_video() or self.is_audio() return self.is_image() or self.is_video() or self.is_audio()
...@@ -279,6 +288,16 @@ class MultimodalDataItem: ...@@ -279,6 +288,16 @@ class MultimodalDataItem:
... ...
# TODO # TODO
@staticmethod
def from_dict(obj: dict):
kwargs = dict(obj)
modality = kwargs.pop("modality")
if isinstance(modality, str):
modality = Modality[modality]
ret = MultimodalDataItem(modality=modality, **kwargs)
ret.validate()
return ret
@dataclasses.dataclass @dataclasses.dataclass
class MultimodalInputs: class MultimodalInputs:
......
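`MultimodalDataItem.from_dict` is what turns the user-facing dict into the dataclass: the `modality` string is mapped onto the `Modality` enum and the remaining keys must be valid field names. A short sketch with dummy tensors (shapes are illustrative only):

```python
import torch

from sglang.srt.managers.schedule_batch import MultimodalDataItem

# Dummy tensors for illustration; real values come from the model's vision encoder.
features = torch.randn(4, 16)
grid_thw = torch.tensor([[1, 4, 4]])

item = MultimodalDataItem.from_dict(
    {
        "modality": "IMAGE",                # mapped to Modality.IMAGE
        "image_grid_thws": grid_thw,        # used by Qwen-VL; Gemma3 items can omit it
        "precomputed_features": features,   # bypasses the vision encoder at inference
    }
)
assert item.is_image()  # precomputed_features now counts as valid image data
```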
...@@ -54,7 +54,7 @@ class SessionReqNode: ...@@ -54,7 +54,7 @@ class SessionReqNode:
prefix += " -- " + self.childs[0].req.rid prefix += " -- " + self.childs[0].req.rid
ret = self.childs[0]._str_helper(prefix) ret = self.childs[0]._str_helper(prefix)
for child in self.childs[1:]: for child in self.childs[1:]:
prefix = " " * len(origin_prefix) + " \- " + child.req.rid prefix = " " * len(origin_prefix) + r" \- " + child.req.rid
ret += child._str_helper(prefix) ret += child._str_helper(prefix)
return ret return ret
......
...@@ -278,6 +278,12 @@ class Gemma3ForConditionalGeneration(PreTrainedModel): ...@@ -278,6 +278,12 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
Returns: Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
""" """
if any(item.precomputed_features is not None for item in items):
if not all(item.precomputed_features is not None for item in items):
raise NotImplementedError(
"MM inputs where only some items are precomputed."
)
return torch.concat([item.precomputed_features for item in items])
pixel_values = torch.stack( pixel_values = torch.stack(
flatten_nested_list([item.pixel_values for item in items]), dim=0 flatten_nested_list([item.pixel_values for item in items]), dim=0
) )
......
...@@ -497,6 +497,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module): ...@@ -497,6 +497,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
return pattern.pad_input_tokens(input_ids, mm_inputs) return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
if any(item.precomputed_features is not None for item in items):
if not all(item.precomputed_features is not None for item in items):
raise NotImplementedError(
"MM inputs where only some items are precomputed."
)
return torch.concat([item.precomputed_features for item in items])
# in qwen-vl, last dim is the same # in qwen-vl, last dim is the same
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type( pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
self.visual.dtype self.visual.dtype
......
...@@ -486,6 +486,12 @@ class Qwen2VLForConditionalGeneration(nn.Module): ...@@ -486,6 +486,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
return pattern.pad_input_tokens(input_ids, mm_inputs) return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
if any(item.precomputed_features is not None for item in items):
if not all(item.precomputed_features is not None for item in items):
raise NotImplementedError(
"MM inputs where only some items are precomputed."
)
return torch.concat([item.precomputed_features for item in items])
# in qwen-vl, last dim is the same # in qwen-vl, last dim is the same
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type( pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
self.visual.dtype self.visual.dtype
......
...@@ -54,21 +54,17 @@ class TestSkipTokenizerInit(CustomTestCase): ...@@ -54,21 +54,17 @@ class TestSkipTokenizerInit(CustomTestCase):
): ):
input_ids = self.get_input_ids(prompt_text) input_ids = self.get_input_ids(prompt_text)
request = self.get_request_json(
input_ids=input_ids,
return_logprob=return_logprob,
top_logprobs_num=top_logprobs_num,
max_new_tokens=max_new_tokens,
stream=False,
n=n,
)
response = requests.post( response = requests.post(
self.base_url + "/generate", self.base_url + "/generate",
json={ json=request,
"input_ids": input_ids,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": [self.tokenizer.eos_token_id],
},
"stream": False,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
},
) )
ret = response.json() ret = response.json()
print(json.dumps(ret, indent=2)) print(json.dumps(ret, indent=2))
...@@ -87,9 +83,12 @@ class TestSkipTokenizerInit(CustomTestCase): ...@@ -87,9 +83,12 @@ class TestSkipTokenizerInit(CustomTestCase):
self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids)) self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids))
if return_logprob: if return_logprob:
num_input_logprobs = len(input_ids) - request["logprob_start_len"]
if num_input_logprobs > len(input_ids):
num_input_logprobs -= len(input_ids)
self.assertEqual( self.assertEqual(
len(item["meta_info"]["input_token_logprobs"]), len(item["meta_info"]["input_token_logprobs"]),
len(input_ids), num_input_logprobs,
f'{len(item["meta_info"]["input_token_logprobs"])} mismatch with {len(input_ids)}', f'{len(item["meta_info"]["input_token_logprobs"])} mismatch with {len(input_ids)}',
) )
self.assertEqual( self.assertEqual(
...@@ -113,19 +112,14 @@ class TestSkipTokenizerInit(CustomTestCase): ...@@ -113,19 +112,14 @@ class TestSkipTokenizerInit(CustomTestCase):
requests.post(self.base_url + "/flush_cache") requests.post(self.base_url + "/flush_cache")
response = requests.post( response = requests.post(
self.base_url + "/generate", self.base_url + "/generate",
json={ json=self.get_request_json(
"input_ids": input_ids, input_ids=input_ids,
"sampling_params": { max_new_tokens=max_new_tokens,
"temperature": 0 if n == 1 else 0.5, return_logprob=return_logprob,
"max_new_tokens": max_new_tokens, top_logprobs_num=top_logprobs_num,
"n": n, stream=False,
"stop_token_ids": self.eos_token_id, n=n,
}, ),
"stream": False,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
},
) )
ret = response.json() ret = response.json()
print(json.dumps(ret)) print(json.dumps(ret))
...@@ -137,19 +131,13 @@ class TestSkipTokenizerInit(CustomTestCase): ...@@ -137,19 +131,13 @@ class TestSkipTokenizerInit(CustomTestCase):
requests.post(self.base_url + "/flush_cache") requests.post(self.base_url + "/flush_cache")
response_stream = requests.post( response_stream = requests.post(
self.base_url + "/generate", self.base_url + "/generate",
json={ json=self.get_request_json(
"input_ids": input_ids, input_ids=input_ids,
"sampling_params": { return_logprob=return_logprob,
"temperature": 0 if n == 1 else 0.5, top_logprobs_num=top_logprobs_num,
"max_new_tokens": max_new_tokens, stream=True,
"n": n, n=n,
"stop_token_ids": self.eos_token_id, ),
},
"stream": True,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
},
) )
response_stream_json = [] response_stream_json = []
...@@ -188,6 +176,29 @@ class TestSkipTokenizerInit(CustomTestCase): ...@@ -188,6 +176,29 @@ class TestSkipTokenizerInit(CustomTestCase):
].tolist() ].tolist()
return input_ids return input_ids
def get_request_json(
self,
input_ids,
max_new_tokens=32,
return_logprob=False,
top_logprobs_num=0,
stream=False,
n=1,
):
return {
"input_ids": input_ids,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": self.eos_token_id,
},
"stream": stream,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
}
class TestSkipTokenizerInitVLM(TestSkipTokenizerInit): class TestSkipTokenizerInitVLM(TestSkipTokenizerInit):
@classmethod @classmethod
...@@ -218,6 +229,14 @@ class TestSkipTokenizerInitVLM(TestSkipTokenizerInit): ...@@ -218,6 +229,14 @@ class TestSkipTokenizerInitVLM(TestSkipTokenizerInit):
return inputs.input_ids[0].tolist() return inputs.input_ids[0].tolist()
def get_request_json(self, *args, **kwargs):
ret = super().get_request_json(*args, **kwargs)
ret["image_data"] = [self.image_url]
ret["logprob_start_len"] = (
-1
) # Do not try to calculate logprobs of image embeddings.
return ret
def test_simple_decode_stream(self): def test_simple_decode_stream(self):
# TODO mick # TODO mick
pass pass
......
...@@ -3,15 +3,22 @@ ...@@ -3,15 +3,22 @@
import unittest import unittest
from io import BytesIO from io import BytesIO
from typing import List from typing import List, Optional
import numpy as np import numpy as np
import requests import requests
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from PIL import Image from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer from transformers import (
AutoModel,
AutoProcessor,
AutoTokenizer,
Gemma3ForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
)
from sglang import Engine
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.conversation import generate_chat_conv from sglang.srt.conversation import generate_chat_conv
from sglang.srt.managers.mm_utils import embed_mm_inputs from sglang.srt.managers.mm_utils import embed_mm_inputs
...@@ -100,7 +107,7 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase): ...@@ -100,7 +107,7 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
np.testing.assert_allclose(hf_np, sg_np) np.testing.assert_allclose(hf_np, sg_np)
def get_processor_output(self): def get_completion_request(self) -> ChatCompletionRequest:
json_str = f""" json_str = f"""
{{ {{
"model": "{self.model_path}", "model": "{self.model_path}",
...@@ -124,10 +131,12 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase): ...@@ -124,10 +131,12 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
}} }}
""" """
req = ChatCompletionRequest.model_validate_json(json_str) return ChatCompletionRequest.model_validate_json(json_str)
def get_processor_output(self, req: Optional[ChatCompletionRequest] = None):
if req is None:
req = self.get_completion_request()
conv = generate_chat_conv(req, template_name=self.chat_template) conv = generate_chat_conv(req, template_name=self.chat_template)
text = conv.get_prompt() text = conv.get_prompt()
# Process inputs using processor # Process inputs using processor
...@@ -239,5 +248,129 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase): ...@@ -239,5 +248,129 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
self.compare_outputs(sglang_output, hf_output) self.compare_outputs(sglang_output, hf_output)
class TestQwenVLUnderstandsImage(VisionLLMLogitsBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
cls.chat_template = "qwen2-vl"
cls.processor = AutoProcessor.from_pretrained(
cls.model_path, trust_remote_code=True, use_fast=True
)
cls.visual = (
Qwen2_5_VLForConditionalGeneration.from_pretrained(
cls.model_path, torch_dtype=torch.bfloat16
)
.eval()
.visual.to(cls.device)
)
def setUp(self):
self.engine = Engine(
model_path=self.model_path,
chat_template=self.chat_template,
device=self.device.type,
mem_fraction_static=0.8,
)
def tearDown(self):
self.engine.shutdown()
async def test_qwen_vl_understands_image(self):
req = self.get_completion_request()
conv = generate_chat_conv(req, template_name=self.chat_template)
text = conv.get_prompt()
output = await self.engine.async_generate(
prompt=text,
image_data=[self.main_image],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
async def test_qwen_vl_understands_precomputed_features(self):
req = self.get_completion_request()
processor_output = self.get_processor_output(req=req)
with torch.inference_mode():
precomputed_features = self.visual(
processor_output["pixel_values"], processor_output["image_grid_thw"]
)
output = await self.engine.async_generate(
input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
image_data=[
dict(
modality="IMAGE",
image_grid_thws=processor_output["image_grid_thw"],
precomputed_features=precomputed_features,
)
],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
class TestGemmaUnderstandsImage(VisionLLMLogitsBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model_path = "google/gemma-3-4b-it"
cls.chat_template = "gemma-it"
cls.processor = AutoProcessor.from_pretrained(
cls.model_path, trust_remote_code=True, use_fast=True
)
model = Gemma3ForConditionalGeneration.from_pretrained(
cls.model_path, torch_dtype=torch.bfloat16
)
cls.vision_tower = model.vision_tower.eval().to(cls.device)
cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
@classmethod
def visual(cls, pixel_values):
vision_outputs = cls.vision_tower(pixel_values=pixel_values).last_hidden_state
image_features = cls.mm_projector(vision_outputs)
return image_features
def setUp(self):
self.engine = Engine(
model_path=self.model_path,
chat_template=self.chat_template,
device=self.device.type,
mem_fraction_static=0.5,
enable_multimodal=True,
)
def tearDown(self):
self.engine.shutdown()
async def test_gemma_understands_image(self):
req = self.get_completion_request()
conv = generate_chat_conv(req, template_name=self.chat_template)
text = conv.get_prompt()
output = await self.engine.async_generate(
prompt=text,
image_data=[self.main_image],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
async def test_gemma_understands_precomputed_features(self):
req = self.get_completion_request()
processor_output = self.get_processor_output(req=req)
with torch.inference_mode():
precomputed_features = self.visual(processor_output["pixel_values"])
output = await self.engine.async_generate(
input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
image_data=[
dict(
modality="IMAGE",
precomputed_features=precomputed_features,
)
],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()