Unverified Commit f19a9204 authored by Yury Sulsky, committed by GitHub

Support precomputed multimodal features for Qwen-VL and Gemma3 models. (#6136)


Co-authored-by: Yury Sulsky <ysulsky@tesla.com>
parent c23a7072
{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {},
"source": [
"# Querying Qwen-VL"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"import nest_asyncio\n",
"\n",
"nest_asyncio.apply() # Run this first.\n",
"\n",
"model_path = \"Qwen/Qwen2.5-VL-3B-Instruct\"\n",
"chat_template = \"qwen2-vl\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2",
"metadata": {},
"outputs": [],
"source": [
"# Lets create a prompt.\n",
"\n",
"from io import BytesIO\n",
"import requests\n",
"from PIL import Image\n",
"\n",
"from sglang.srt.openai_api.protocol import ChatCompletionRequest\n",
"from sglang.srt.conversation import chat_templates\n",
"\n",
"image = Image.open(\n",
" BytesIO(\n",
" requests.get(\n",
" \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\"\n",
" ).content\n",
" )\n",
")\n",
"\n",
"conv = chat_templates[chat_template].copy()\n",
"conv.append_message(conv.roles[0], f\"What's shown here: {conv.image_token}?\")\n",
"conv.append_message(conv.roles[1], \"\")\n",
"conv.image_data = [image]\n",
"\n",
"print(conv.get_prompt())\n",
"image"
]
},
{
"cell_type": "markdown",
"id": "3",
"metadata": {},
"source": [
"## Query via the offline Engine API"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"metadata": {},
"outputs": [],
"source": [
"from sglang import Engine\n",
"\n",
"llm = Engine(\n",
" model_path=model_path, chat_template=chat_template, mem_fraction_static=0.8\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"out = llm.generate(prompt=conv.get_prompt(), image_data=[image])\n",
"print(out[\"text\"])"
]
},
{
"cell_type": "markdown",
"id": "6",
"metadata": {},
"source": [
"## Query via the offline Engine API, but send precomputed embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {},
"outputs": [],
"source": [
"# Compute the image embeddings using Huggingface.\n",
"\n",
"from transformers import AutoProcessor\n",
"from transformers import Qwen2_5_VLForConditionalGeneration\n",
"\n",
"processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\n",
"vision = (\n",
" Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path).eval().visual.cuda()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"metadata": {},
"outputs": [],
"source": [
"processed_prompt = processor(\n",
" images=[image], text=conv.get_prompt(), return_tensors=\"pt\"\n",
")\n",
"input_ids = processed_prompt[\"input_ids\"][0].detach().cpu().tolist()\n",
"precomputed_features = vision(\n",
" processed_prompt[\"pixel_values\"].cuda(), processed_prompt[\"image_grid_thw\"].cuda()\n",
")\n",
"\n",
"mm_item = dict(\n",
" modality=\"IMAGE\",\n",
" image_grid_thws=processed_prompt[\"image_grid_thw\"],\n",
" precomputed_features=precomputed_features,\n",
")\n",
"out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
"print(out[\"text\"])"
]
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"custom_cell_magics": "kql",
"encoding": "# -*- coding: utf-8 -*-"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
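The notebook above covers Qwen2.5-VL; the diffs below add the same precomputed-features path to Gemma3. What follows is a hedged sketch of the Gemma3 flow, condensed from the TestGemmaUnderstandsImage test added at the end of this change. It assumes a "gemma-it" template is registered in sglang.srt.conversation.chat_templates with <start_of_image> as its image token, and it reuses the `image` loaded in the notebook above; treat it as illustrative rather than a reference implementation.

import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

from sglang import Engine
from sglang.srt.conversation import chat_templates

model_path = "google/gemma-3-4b-it"

# Hugging Face side: the vision tower plus projector produce the precomputed features.
processor = AutoProcessor.from_pretrained(model_path, use_fast=True)
hf_model = Gemma3ForConditionalGeneration.from_pretrained(
    model_path, torch_dtype=torch.bfloat16
).eval()
vision_tower = hf_model.vision_tower.cuda()
mm_projector = hf_model.multi_modal_projector.cuda()

# Build the prompt the same way the notebook does, but with the Gemma template.
conv = chat_templates["gemma-it"].copy()
conv.append_message(conv.roles[0], f"What's shown here: {conv.image_token}?")
conv.append_message(conv.roles[1], "")

processed = processor(images=[image], text=conv.get_prompt(), return_tensors="pt")
pixel_values = processed["pixel_values"].to(device="cuda", dtype=torch.bfloat16)
with torch.inference_mode():
    features = mm_projector(vision_tower(pixel_values=pixel_values).last_hidden_state)

# SGLang side: pass the expanded input_ids together with the precomputed features.
llm = Engine(
    model_path=model_path,
    chat_template="gemma-it",
    enable_multimodal=True,
    mem_fraction_static=0.5,
)
out = llm.generate(
    input_ids=processed["input_ids"][0].tolist(),
    image_data=[dict(modality="IMAGE", precomputed_features=features)],
)
print(out["text"])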
......@@ -47,6 +47,7 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
GenerateReqInput,
GetWeightsByNameReqInput,
ImageDataItem,
InitWeightsUpdateGroupReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
......@@ -150,9 +151,9 @@ class Engine(EngineBase):
# See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[
Union[
List[List[Union[Image, str]]],
List[Union[Image, str]],
Union[Image, str],
List[List[ImageDataItem]],
List[ImageDataItem],
ImageDataItem,
]
] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
......@@ -221,9 +222,9 @@ class Engine(EngineBase):
# See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[
Union[
List[List[Union[Image, str]]],
List[Union[Image, str]],
Union[Image, str],
List[List[ImageDataItem]],
List[ImageDataItem],
ImageDataItem,
]
] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
......
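With the widened signatures above, image_data accepts a single ImageDataItem, a flat list for one prompt, or a nested list for a batch of prompts; ImageDataItem itself is Union[Image, str, Dict] (see the io_struct.py hunk below). A shape-only sketch, with llm, prompt, and the image variables as placeholders:

# One prompt, one image; a string may be a file path, URL, or base64 data
# (see python/sglang/srt/utils.py:load_image), and PIL images are accepted directly.
llm.generate(prompt=prompt, image_data="https://example.com/cat.png")

# One prompt, several images.
llm.generate(prompt=prompt, image_data=["cat.png", pil_image])

# A batch of prompts, with one image list per prompt.
llm.generate(prompt=[prompt_a, prompt_b], image_data=[["cat.png"], [pil_image]])

# The new dict form (modality / image_grid_thws / precomputed_features) is also a
# valid ImageDataItem; the notebook above pairs it with input_ids instead of prompt.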
......@@ -40,6 +40,10 @@ class SessionParams:
replace: Optional[bool] = None
AudioDataItem = Union[str, Dict]
ImageDataItem = Union[Image, str, Dict]
@dataclass
class GenerateReqInput:
# The input prompt. It can be a single prompt or a batch of prompts.
......@@ -55,10 +59,10 @@ class GenerateReqInput:
# - List of lists of images (multiple images per request)
# See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem]
] = None
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
audio_data: Optional[Union[List[str], str]] = None
audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
# The sampling_params. See descriptions below.
sampling_params: Optional[Union[List[Dict], Dict]] = None
# The request id.
......
......@@ -368,13 +368,13 @@ def general_mm_embed_routine(
input_ids: torch.Tensor,
forward_batch: ForwardBatch,
language_model: nn.Module,
image_data_embedding_func: Callable[
[List[MultimodalDataItem]], torch.Tensor
image_data_embedding_func: Optional[
Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
audio_data_embedding_func: Callable[
[List[MultimodalDataItem]], torch.Tensor
audio_data_embedding_func: Optional[
Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: dict[Modality, List[int]] = None,
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
**kwargs,
) -> torch.Tensor:
"""
......@@ -389,7 +389,6 @@ def general_mm_embed_routine(
forwarded hidden states
"""
assert hasattr(language_model, "get_input_embeddings")
embed_tokens = language_model.get_input_embeddings()
if (
......
......@@ -3,16 +3,16 @@ import concurrent.futures
import dataclasses
import multiprocessing as mp
import os
import re
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import List, Optional, Union
import numpy as np
import PIL
import torch
from PIL import Image
from transformers import BaseImageProcessorFast
from sglang.srt.managers.schedule_batch import Modality
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.utils import encode_video, load_audio, load_image
......@@ -22,13 +22,13 @@ class BaseMultiModalProcessorOutput:
input_text: str
# frames loaded from image and video, in given order
images: Optional[list[PIL.Image]] = None
images: Optional[list[Union[Image.Image, MultimodalDataItem]]] = None
# audios
audios: Optional[list[np.ndarray]] = None
audios: Optional[list[Union[np.ndarray, MultimodalDataItem]]] = None
def normalize(self):
for field_name in ["image_sizes", "images", "audios"]:
for field_name in ["images", "audios"]:
field = getattr(self, field_name, None)
if field is not None and isinstance(field, list) and len(field) == 0:
setattr(self, field_name, None)
......@@ -40,12 +40,32 @@ class MultimodalSpecialTokens:
video_token: Optional[str] = None
audio_token: Optional[str] = None
def collect(self) -> list[str]:
return [
token
for token in [self.image_token, self.video_token, self.audio_token]
if token
image_token_regex: Optional[re.Pattern] = None
video_token_regex: Optional[re.Pattern] = None
audio_token_regex: Optional[re.Pattern] = None
def __post_init__(self):
if self.image_token_regex is None and self.image_token is not None:
self.image_token_regex = re.compile(re.escape(self.image_token))
if self.video_token_regex is None and self.video_token is not None:
self.video_token_regex = re.compile(re.escape(self.video_token))
if self.audio_token_regex is None and self.audio_token is not None:
self.audio_token_regex = re.compile(re.escape(self.audio_token))
def collect(self) -> re.Pattern:
tokens = [
self.image_token_regex,
self.video_token_regex,
self.audio_token_regex,
]
patterns = []
flags = 0
for t in tokens:
if t is not None:
patterns.append(t.pattern)
flags |= t.flags
combined = "(" + "|".join(f"(?:{p})" for p in patterns) + ")"
return re.compile(combined, flags)
class BaseMultimodalProcessor(ABC):
......@@ -136,6 +156,10 @@ class BaseMultimodalProcessor(ABC):
data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
):
"""Static method that can be pickled for multiprocessing"""
if isinstance(data, dict):
return MultimodalDataItem.from_dict(data)
if isinstance(data, MultimodalDataItem):
return data
try:
if is_audio:
return load_audio(data)
......@@ -175,7 +199,10 @@ class BaseMultimodalProcessor(ABC):
image_index, audio_index = 0, 0
for text_part in text_parts:
if text_part == multimodal_tokens.image_token:
if (
multimodal_tokens.image_token_regex
and multimodal_tokens.image_token_regex.match(text_part)
):
data = image_data[image_index]
is_video = isinstance(data, str) and data.startswith("video:")
estimated_frames = estimated_frames_list[image_index]
......@@ -192,7 +219,10 @@ class BaseMultimodalProcessor(ABC):
)
task_info.append((Modality.IMAGE, data, frame_count_limit))
image_index += 1
elif text_part == multimodal_tokens.audio_token:
elif (
multimodal_tokens.audio_token_regex
and multimodal_tokens.audio_token_regex.match(text_part)
):
data = audio_data[audio_index]
futures.append(
self.io_executor.submit(
......@@ -228,17 +258,22 @@ class BaseMultimodalProcessor(ABC):
discard_alpha_channel: if True, discards the alpha channel in the returned images
"""
if not return_text:
raise NotImplementedError()
if image_data is None:
image_data = []
if isinstance(multimodal_tokens.image_token, int):
multimodal_tokens.image_token = (
self._processor.tokenizer.convert_ids_to_tokens(
multimodal_tokens.image_token
multimodal_tokens.image_token = re.compile(
re.escape(
self._processor.tokenizer.convert_ids_to_tokens(
multimodal_tokens.image_token
)
)
)
else:
multimodal_tokens.image_token = multimodal_tokens.image_token
multimodal_tokens_pattern = multimodal_tokens.collect()
if isinstance(prompt, list) and return_text:
assert len(prompt) and isinstance(prompt[0], int)
......@@ -247,16 +282,8 @@ class BaseMultimodalProcessor(ABC):
prompt = prompt
assert isinstance(prompt, str)
if return_text:
import re
pattern = (
"("
+ "|".join(re.escape(sep) for sep in multimodal_tokens.collect())
+ ")"
)
# split text into list of normal text and special tokens
text_parts = re.split(pattern, prompt)
# split text into list of normal text and special tokens
text_parts = re.split(multimodal_tokens_pattern, prompt)
futures, task_info = self.submit_data_loading_tasks(
text_parts=text_parts,
......@@ -266,26 +293,40 @@ class BaseMultimodalProcessor(ABC):
discard_alpha_channel=discard_alpha_channel,
)
# Process results
image_sizes, images, audios = [], [], []
images, audios = [], []
new_text = ""
task_ptr = 0
for text_part in text_parts:
if text_part in multimodal_tokens.collect():
if multimodal_tokens_pattern.match(text_part):
task_type, data, frame_limit = task_info[task_ptr]
result = futures[task_ptr].result()
task_ptr += 1
if task_type == Modality.IMAGE:
# If data is already processed it will be a
# dictionary. In this case we want to keep the
# expanded tokens in text_part. Otherwise, we will
# call the processor code, so keep only a single image
# token.
mm_tokens = (
text_part
if isinstance(data, dict)
else multimodal_tokens.image_token
)
frames = [result] if not isinstance(result, list) else result
if frames:
image_sizes += frames[0].size * len(frames)
images += frames
new_text += multimodal_tokens.image_token * len(frames)
new_text += mm_tokens * len(frames)
elif task_type == Modality.AUDIO:
# audio
mm_tokens = (
text_part
if isinstance(data, dict)
else multimodal_tokens.audio_token
)
audios.append(result)
new_text += multimodal_tokens.audio_token
new_text += mm_tokens
# TODO: handle video
else:
new_text += text_part
......@@ -297,3 +338,16 @@ class BaseMultimodalProcessor(ABC):
)
out.normalize()
return out
def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]):
"""Returns true if all images are preprocessed, false if all are not, and error otherwise."""
if not mm_inputs:
return True
ret = any(isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs)
if ret and not all(
isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs
):
raise ValueError(
"Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
)
return ret
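To make the new token handling concrete, here is a small self-contained sketch of what MultimodalSpecialTokens.collect() produces and how load_mm_data splits a prompt with it, using the Qwen2.5-VL image-token regex defined later in this diff (the prompt string is made up for illustration):

import re

# collect() ORs the per-modality patterns inside one capturing group, so re.split
# keeps the matched token runs as their own parts.
image_token_regex = re.compile(r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>")
combined = re.compile("(" + f"(?:{image_token_regex.pattern})" + ")", image_token_regex.flags)

prompt = "Describe <|vision_start|><|image_pad|><|image_pad|><|vision_end|> briefly."
print(re.split(combined, prompt))
# ['Describe ', '<|vision_start|><|image_pad|><|image_pad|><|vision_end|>', ' briefly.']
# Because an already-expanded token run survives the split intact, a request that
# carries precomputed features keeps its expanded tokens in the rebuilt prompt.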
from typing import List, Union
import re
from typing import Dict, List, Union
from sglang.srt.managers.multimodal_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
......@@ -18,13 +19,18 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
# The single, pre-expanded image token.
self.IMAGE_TOKEN = "<start_of_image>"
# The regex that matches expanded image tokens.
self.IMAGE_TOKEN_REGEX = re.compile(
r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
)
self.IM_START_TOKEN_ID = hf_config.boi_token_index
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
image_data: List[Union[str, bytes, Dict]],
input_text,
request_obj,
max_req_input_len,
......@@ -37,22 +43,35 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
image_token_regex = self.IMAGE_TOKEN_REGEX
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
multimodal_tokens=MultimodalSpecialTokens(
image_token=image_token, image_token_regex=image_token_regex
),
max_req_input_len=max_req_input_len,
discard_alpha_channel=True,
)
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
ret = self.process_mm_data(
input_text=base_output.input_text, images=base_output.images
input_text=base_output.input_text,
images=None if images_are_preprocessed else base_output.images,
)
items = []
for i, image in enumerate(base_output.images):
if images_are_preprocessed:
pixel_values = image.pixel_values
precomputed_features = image.precomputed_features
else:
pixel_values = ret["pixel_values"][i]
precomputed_features = None
item = MultimodalDataItem(
pixel_values=ret["pixel_values"][i],
pixel_values=pixel_values,
precomputed_features=precomputed_features,
modality=Modality.IMAGE,
)
items += [item]
......
import asyncio
import math
from typing import List, Union
import re
from typing import Dict, List, Union
import torch
from PIL import Image
......@@ -23,7 +24,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
# The single, pre-expanded image token.
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
# The regex that matches expanded image tokens.
self.IMAGE_TOKEN_REGEX = re.compile(
r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
)
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
self.image_token_id = hf_config.image_token_id
......@@ -38,7 +44,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
image_data: List[Union[str, bytes, Dict]],
input_text,
request_obj,
max_req_input_len,
......@@ -48,11 +54,13 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
if isinstance(image_data, str):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
multimodal_tokens=MultimodalSpecialTokens(
image_token=self.IMAGE_TOKEN,
image_token_regex=self.IMAGE_TOKEN_REGEX,
),
max_req_input_len=max_req_input_len,
)
......@@ -117,26 +125,56 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
async def resize_image_async(image):
return resize_image(image)
if base_output.images:
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
if base_output.images and not images_are_preprocessed:
resize_tasks = [resize_image_async(image) for image in base_output.images]
base_output.images = await asyncio.gather(*resize_tasks)
ret = self.process_mm_data(
input_text=base_output.input_text,
images=base_output.images,
images=None if images_are_preprocessed else base_output.images,
)
input_ids = ret["input_ids"].flatten().tolist()
image_grid_thw = None
video_grid_thw = None # TODO
items = []
input_ids = ret["input_ids"].flatten().tolist()
if "pixel_values" in ret:
if base_output.images:
if images_are_preprocessed:
image_grid_thw = torch.concat(
[
torch.as_tensor(item.image_grid_thws)
for item in base_output.images
]
)
all_pixel_values = [
item.pixel_values
for item in base_output.images
if item.pixel_values is not None
]
all_precomputed_features = [
item.precomputed_features
for item in base_output.images
if item.precomputed_features is not None
]
pixel_values = (
torch.concat(all_pixel_values) if all_pixel_values else None
)
precomputed_features = (
torch.concat(all_precomputed_features)
if all_precomputed_features
else None
)
else:
image_grid_thw = ret["image_grid_thw"]
pixel_values = ret["pixel_values"]
precomputed_features = None
items += [
MultimodalDataItem(
pixel_values=ret["pixel_values"],
image_grid_thws=torch.concat([ret["image_grid_thw"]]),
# TODO
video_grid_thws=None,
second_per_grid_ts=ret.get("second_per_grid_ts", None),
pixel_values=pixel_values,
image_grid_thws=image_grid_thw,
video_grid_thws=video_grid_thw,
precomputed_features=precomputed_features,
modality=Modality.IMAGE,
)
]
......@@ -151,8 +189,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
self.hf_config.vision_config, "tokens_per_second", None
),
input_ids=torch.tensor(input_ids).unsqueeze(0),
image_grid_thw=ret.get("image_grid_thw", None),
video_grid_thw=ret.get("video_grid_thw", None),
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=ret.get("second_per_grid_ts", None),
)
mrope_positions = mrope_positions.squeeze(1)
......
......@@ -177,10 +177,10 @@ class MultimodalDataItem:
image_offsets: Optional[list] = None
# the real data, pixel_values or audio_features
# data: Union[List[torch.Tensor], List[np.array]]
pixel_values: Union[torch.Tensor, np.array] = None
image_grid_thws: Union[torch.Tensor, np.array] = None
video_grid_thws: Union[torch.Tensor, np.array] = None
# data: Union[List[torch.Tensor], List[np.ndarray]]
pixel_values: Union[torch.Tensor, np.ndarray] = None
image_grid_thws: Union[torch.Tensor, np.ndarray] = None
video_grid_thws: Union[torch.Tensor, np.ndarray] = None
image_emb_mask: Optional[torch.Tensor] = None
image_spatial_crop: Optional[torch.Tensor] = None
......@@ -189,9 +189,11 @@ class MultimodalDataItem:
# [num_images, (n, w, h)]
tgt_size: Tuple[int, int] = None
audio_features: Union[torch.Tensor, np.array] = None
audio_features: Union[torch.Tensor, np.ndarray] = None
audio_feature_lens: Optional[List[torch.Tensor]] = None
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
@staticmethod
def is_empty_list(l):
if l is None:
......@@ -249,7 +251,9 @@ class MultimodalDataItem:
return tensor_hash([f])
return data_hash(f)
if self.is_audio():
if self.precomputed_features is not None:
self.hash = hash_feature(self.precomputed_features)
elif self.is_audio():
self.hash = hash_feature(self.audio_features)
else:
self.hash = hash_feature(self.pixel_values)
......@@ -258,19 +262,24 @@ class MultimodalDataItem:
self.pad_value = self.hash % (1 << 30)
def is_audio(self):
return (
self.modality == Modality.AUDIO
) and not MultimodalDataItem.is_empty_list(self.audio_features)
return (self.modality == Modality.AUDIO) and (
self.precomputed_features is not None
or not MultimodalDataItem.is_empty_list(self.audio_features)
)
def is_image(self):
return (
self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
) and not MultimodalDataItem.is_empty_list(self.pixel_values)
) and (
self.precomputed_features is not None
or not MultimodalDataItem.is_empty_list(self.pixel_values)
)
def is_video(self):
return (
self.modality == Modality.VIDEO
) and not MultimodalDataItem.is_empty_list(self.pixel_values)
return (self.modality == Modality.VIDEO) and (
self.precomputed_features is not None
or not MultimodalDataItem.is_empty_list(self.pixel_values)
)
def is_valid(self) -> bool:
return self.is_image() or self.is_video() or self.is_audio()
......@@ -279,6 +288,16 @@ class MultimodalDataItem:
...
# TODO
@staticmethod
def from_dict(obj: dict):
kwargs = dict(obj)
modality = kwargs.pop("modality")
if isinstance(modality, str):
modality = Modality[modality]
ret = MultimodalDataItem(modality=modality, **kwargs)
ret.validate()
return ret
@dataclasses.dataclass
class MultimodalInputs:
......
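For reference, a dict item passed in image_data is turned into a MultimodalDataItem through from_dict above, with the modality string resolved via the Modality enum. A hedged sketch with made-up tensor shapes:

import torch

from sglang.srt.managers.schedule_batch import MultimodalDataItem

item = MultimodalDataItem.from_dict(
    {
        "modality": "IMAGE",                              # resolved as Modality["IMAGE"]
        "image_grid_thws": torch.tensor([[1, 34, 46]]),   # illustrative (t, h, w) grid
        "precomputed_features": torch.zeros(391, 2048),   # illustrative feature tensor
    }
)
assert item.is_image()  # precomputed_features alone suffices; no pixel_values required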
......@@ -54,7 +54,7 @@ class SessionReqNode:
prefix += " -- " + self.childs[0].req.rid
ret = self.childs[0]._str_helper(prefix)
for child in self.childs[1:]:
prefix = " " * len(origin_prefix) + " \- " + child.req.rid
prefix = " " * len(origin_prefix) + r" \- " + child.req.rid
ret += child._str_helper(prefix)
return ret
......
......@@ -278,6 +278,12 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
if any(item.precomputed_features is not None for item in items):
if not all(item.precomputed_features is not None for item in items):
raise NotImplementedError(
"MM inputs where only some items are precomputed."
)
return torch.concat([item.precomputed_features for item in items])
pixel_values = torch.stack(
flatten_nested_list([item.pixel_values for item in items]), dim=0
)
......
......@@ -497,6 +497,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
if any(item.precomputed_features is not None for item in items):
if not all(item.precomputed_features is not None for item in items):
raise NotImplementedError(
"MM inputs where only some items are precomputed."
)
return torch.concat([item.precomputed_features for item in items])
# in qwen-vl, last dim is the same
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
self.visual.dtype
......
......@@ -486,6 +486,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
if any(item.precomputed_features is not None for item in items):
if not all(item.precomputed_features is not None for item in items):
raise NotImplementedError(
"MM inputs where only some items are precomputed."
)
return torch.concat([item.precomputed_features for item in items])
# in qwen-vl, last dim is the same
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
self.visual.dtype
......
......@@ -54,21 +54,17 @@ class TestSkipTokenizerInit(CustomTestCase):
):
input_ids = self.get_input_ids(prompt_text)
request = self.get_request_json(
input_ids=input_ids,
return_logprob=return_logprob,
top_logprobs_num=top_logprobs_num,
max_new_tokens=max_new_tokens,
stream=False,
n=n,
)
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": [self.tokenizer.eos_token_id],
},
"stream": False,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
},
json=request,
)
ret = response.json()
print(json.dumps(ret, indent=2))
......@@ -87,9 +83,12 @@ class TestSkipTokenizerInit(CustomTestCase):
self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids))
if return_logprob:
num_input_logprobs = len(input_ids) - request["logprob_start_len"]
if num_input_logprobs > len(input_ids):
num_input_logprobs -= len(input_ids)
self.assertEqual(
len(item["meta_info"]["input_token_logprobs"]),
len(input_ids),
num_input_logprobs,
f'{len(item["meta_info"]["input_token_logprobs"])} mismatch with {len(input_ids)}',
)
self.assertEqual(
......@@ -113,19 +112,14 @@ class TestSkipTokenizerInit(CustomTestCase):
requests.post(self.base_url + "/flush_cache")
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": self.eos_token_id,
},
"stream": False,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
},
json=self.get_request_json(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
return_logprob=return_logprob,
top_logprobs_num=top_logprobs_num,
stream=False,
n=n,
),
)
ret = response.json()
print(json.dumps(ret))
......@@ -137,19 +131,13 @@ class TestSkipTokenizerInit(CustomTestCase):
requests.post(self.base_url + "/flush_cache")
response_stream = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": self.eos_token_id,
},
"stream": True,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
},
json=self.get_request_json(
input_ids=input_ids,
return_logprob=return_logprob,
top_logprobs_num=top_logprobs_num,
stream=True,
n=n,
),
)
response_stream_json = []
......@@ -188,6 +176,29 @@ class TestSkipTokenizerInit(CustomTestCase):
].tolist()
return input_ids
def get_request_json(
self,
input_ids,
max_new_tokens=32,
return_logprob=False,
top_logprobs_num=0,
stream=False,
n=1,
):
return {
"input_ids": input_ids,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": self.eos_token_id,
},
"stream": stream,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
}
class TestSkipTokenizerInitVLM(TestSkipTokenizerInit):
@classmethod
......@@ -218,6 +229,14 @@ class TestSkipTokenizerInitVLM(TestSkipTokenizerInit):
return inputs.input_ids[0].tolist()
def get_request_json(self, *args, **kwargs):
ret = super().get_request_json(*args, **kwargs)
ret["image_data"] = [self.image_url]
ret["logprob_start_len"] = (
-1
) # Do not try to calculate logprobs of image embeddings.
return ret
def test_simple_decode_stream(self):
# TODO mick
pass
......
......@@ -3,15 +3,22 @@
import unittest
from io import BytesIO
from typing import List
from typing import List, Optional
import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from transformers import (
AutoModel,
AutoProcessor,
AutoTokenizer,
Gemma3ForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
)
from sglang import Engine
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.managers.mm_utils import embed_mm_inputs
......@@ -100,7 +107,7 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
np.testing.assert_allclose(hf_np, sg_np)
def get_processor_output(self):
def get_completion_request(self) -> ChatCompletionRequest:
json_str = f"""
{{
"model": "{self.model_path}",
......@@ -124,10 +131,12 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
}}
"""
req = ChatCompletionRequest.model_validate_json(json_str)
return ChatCompletionRequest.model_validate_json(json_str)
def get_processor_output(self, req: Optional[ChatCompletionRequest] = None):
if req is None:
req = self.get_completion_request()
conv = generate_chat_conv(req, template_name=self.chat_template)
text = conv.get_prompt()
# Process inputs using processor
......@@ -239,5 +248,129 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
self.compare_outputs(sglang_output, hf_output)
class TestQwenVLUnderstandsImage(VisionLLMLogitsBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
cls.chat_template = "qwen2-vl"
cls.processor = AutoProcessor.from_pretrained(
cls.model_path, trust_remote_code=True, use_fast=True
)
cls.visual = (
Qwen2_5_VLForConditionalGeneration.from_pretrained(
cls.model_path, torch_dtype=torch.bfloat16
)
.eval()
.visual.to(cls.device)
)
def setUp(self):
self.engine = Engine(
model_path=self.model_path,
chat_template=self.chat_template,
device=self.device.type,
mem_fraction_static=0.8,
)
def tearDown(self):
self.engine.shutdown()
async def test_qwen_vl_understands_image(self):
req = self.get_completion_request()
conv = generate_chat_conv(req, template_name=self.chat_template)
text = conv.get_prompt()
output = await self.engine.async_generate(
prompt=text,
image_data=[self.main_image],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
async def test_qwen_vl_understands_precomputed_features(self):
req = self.get_completion_request()
processor_output = self.get_processor_output(req=req)
with torch.inference_mode():
precomputed_features = self.visual(
processor_output["pixel_values"], processor_output["image_grid_thw"]
)
output = await self.engine.async_generate(
input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
image_data=[
dict(
modality="IMAGE",
image_grid_thws=processor_output["image_grid_thw"],
precomputed_features=precomputed_features,
)
],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
class TestGemmaUnderstandsImage(VisionLLMLogitsBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model_path = "google/gemma-3-4b-it"
cls.chat_template = "gemma-it"
cls.processor = AutoProcessor.from_pretrained(
cls.model_path, trust_remote_code=True, use_fast=True
)
model = Gemma3ForConditionalGeneration.from_pretrained(
cls.model_path, torch_dtype=torch.bfloat16
)
cls.vision_tower = model.vision_tower.eval().to(cls.device)
cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
@classmethod
def visual(cls, pixel_values):
vision_outputs = cls.vision_tower(pixel_values=pixel_values).last_hidden_state
image_features = cls.mm_projector(vision_outputs)
return image_features
def setUp(self):
self.engine = Engine(
model_path=self.model_path,
chat_template=self.chat_template,
device=self.device.type,
mem_fraction_static=0.5,
enable_multimodal=True,
)
def tearDown(self):
self.engine.shutdown()
async def test_gemma_understands_image(self):
req = self.get_completion_request()
conv = generate_chat_conv(req, template_name=self.chat_template)
text = conv.get_prompt()
output = await self.engine.async_generate(
prompt=text,
image_data=[self.main_image],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
async def test_gemma_understands_precomputed_features(self):
req = self.get_completion_request()
processor_output = self.get_processor_output(req=req)
with torch.inference_mode():
precomputed_features = self.visual(processor_output["pixel_values"])
output = await self.engine.async_generate(
input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
image_data=[
dict(
modality="IMAGE",
precomputed_features=precomputed_features,
)
],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
if __name__ == "__main__":
unittest.main()