Unverified Commit 11577ced authored by Mick, committed by GitHub

refactor: bug fixes and refactor for vlm (#4661)

parent ca75741e
"""
Bench the sglang-hosted VLM with the MMMU benchmark.
Usage:
python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl
The eval output will be logged.
"""
import argparse
import time
import openai
from data_utils import save_json
......@@ -37,6 +38,7 @@ def eval_mmmu(args):
# We have to use an OpenAI-compatible server, since SglImage doesn't support image data
client = openai.Client(api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1")
start = time.time()
for i, sample in enumerate(tqdm(samples)):
prompt = sample["final_input_prompt"]
prefix = prompt.split("<")[0]
......@@ -73,6 +75,8 @@ def eval_mmmu(args):
response = response.choices[0].message.content
process_result(response, sample, answer_dict, out_samples)
print(f"Benchmark time: {time.time() - start}")
args.output_path = "./val_sglang.json"
save_json(args.output_path, out_samples)
eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
......
......@@ -9,8 +9,6 @@ import PIL
import torch
from PIL.Image import Image
from transformers import (
AutoImageProcessor,
AutoProcessor,
BaseImageProcessor,
BatchFeature,
LlamaConfig,
......@@ -20,6 +18,7 @@ from transformers import (
)
from transformers.image_utils import to_numpy_array
from sglang.srt.configs.utils import register_image_processor, register_processor
from sglang.srt.mm_utils import expand2square
......@@ -625,5 +624,5 @@ class VLMImageProcessorConfig(PretrainedConfig):
super().__init__(**kwargs)
AutoProcessor.register(MultiModalityConfig, VLChatProcessor, exist_ok=True)
AutoImageProcessor.register(VLMImageProcessorConfig, None, VLMImageProcessor, None)
register_processor(MultiModalityConfig, VLChatProcessor)
register_image_processor(MultiModalityConfig, VLMImageProcessor)
......@@ -460,6 +460,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
multimodal_model_archs = [
"DeepseekVL2ForCausalLM",
"LlavaLlamaForCausalLM",
"LlavaQwenForCausalLM",
"LlavaMistralForCausalLM",
......@@ -472,7 +473,6 @@ multimodal_model_archs = [
"Qwen2_5_VLForConditionalGeneration",
"MiniCPMV",
"MultiModalityCausalLM",
"DeepseekVL2ForCausalLM",
]
......
from typing import Type
from transformers import (
AutoImageProcessor,
AutoProcessor,
BaseImageProcessor,
PretrainedConfig,
ProcessorMixin,
)
def register_image_processor(
config: Type[PretrainedConfig], image_processor: Type[BaseImageProcessor]
):
"""
Register a customized HF image processor, replacing the HF implementation.
"""
AutoImageProcessor.register(config, None, image_processor, None, exist_ok=True)
def register_processor(config: Type[PretrainedConfig], processor: Type[ProcessorMixin]):
"""
Register a customized HF processor, replacing the HF implementation.
"""
AutoProcessor.register(config, processor, exist_ok=True)
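For context, a minimal usage sketch of the two helpers above, assuming a hypothetical custom model; MyVLMConfig, MyVLMProcessor, and MyVLMImageProcessor are placeholder names, not classes from this PR.

# Hypothetical sketch: wiring a custom multimodal model through the helpers above.
# MyVLMConfig / MyVLMProcessor / MyVLMImageProcessor are placeholder names.
from transformers import BaseImageProcessor, PretrainedConfig, ProcessorMixin

from sglang.srt.configs.utils import register_image_processor, register_processor


class MyVLMConfig(PretrainedConfig):
    model_type = "my_vlm"


class MyVLMProcessor(ProcessorMixin):
    """Placeholder processor; a real one bundles the tokenizer and image processor."""


class MyVLMImageProcessor(BaseImageProcessor):
    """Placeholder image processor."""


# Replace any HF-registered implementations for this config type.
register_processor(MyVLMConfig, MyVLMProcessor)
register_image_processor(MyVLMConfig, MyVLMImageProcessor)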
......@@ -653,7 +653,7 @@ register_conv_template(
Conversation(
name="gemma-it",
system_message="You are a helpful assistant.",
system_template="<bos><start_of_turn>user{system_message}\n\n",
system_template="<start_of_turn>user{system_message}\n\n",
roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
sep="<end_of_turn>\n",
sep_style=SeparatorStyle.GEMMA3,
......
......@@ -143,9 +143,14 @@ class VisionAttention(nn.Module):
if position_embeddings is not None:
cos, sin = position_embeddings
original_shape = q.shape
q, k = q.view(s, head, -1), k.view(s, head, -1)
# [total_tokens, head, head_size]
q = q.view(-1, head, self.head_size)
k = k.view(-1, head, self.head_size)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
q, k = q.reshape(original_shape), k.reshape(original_shape)
q = q.view(original_shape)
k = k.view(original_shape)
if self.use_qkv_parallel:
pass
......
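A tiny shape sketch (sizes invented) of the reshape-around-rotary pattern in the hunk above: q and k are viewed as [total_tokens, head, head_size] before the rotary embedding is applied, then viewed back to the original flattened layout.

# Shape sketch only; s, head, and head_size are made-up illustration values.
import torch

s, head, head_size = 4, 2, 8
q = torch.randn(s, head * head_size)   # flattened per-token projection
original_shape = q.shape

q = q.view(-1, head, head_size)        # [total_tokens, head, head_size]
# ... apply_rotary_pos_emb(q, k, cos, sin) operates on this layout ...
q = q.view(original_shape)             # back to the flattened layout
assert q.shape == original_shape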
# TODO: also move pad_input_ids into this module
import importlib
import inspect
import logging
import pkgutil
from functools import lru_cache
from typing import Union
from torch import Tensor
from transformers import IMAGE_PROCESSOR_MAPPING
from sglang.srt.managers.image_processors.base_image_processor import (
......@@ -18,9 +21,7 @@ logger = logging.getLogger(__name__)
IMAGE_PROCESSOR_MAPPING = {}
def get_image_processor(
hf_config, server_args: ServerArgs, processor
) -> BaseImageProcessor:
def get_image_processor(hf_config, server_args, processor) -> BaseImageProcessor:
for model_cls, processor_cls in IMAGE_PROCESSOR_MAPPING.items():
if model_cls.__name__ in hf_config.architectures:
return processor_cls(hf_config, server_args, processor)
......@@ -42,13 +43,18 @@ def import_image_processors():
try:
module = importlib.import_module(name)
except Exception as e:
logger.warning(f"Ignore import error when loading {name}: " f"{e}")
logger.warning(f" Ignore import error when loading {name}: " f"{e}")
continue
if hasattr(module, "ImageProcessorMapping"):
entry = module.ImageProcessorMapping
if isinstance(entry, dict):
for processor_name, cls in entry.items():
IMAGE_PROCESSOR_MAPPING[processor_name] = cls
all_members = inspect.getmembers(module, inspect.isclass)
classes = [
member
for name, member in all_members
if member.__module__ == module.__name__
]
for cls in classes:
if issubclass(cls, BaseImageProcessor):
for arch in getattr(cls, "models"):
IMAGE_PROCESSOR_MAPPING[arch] = cls
# also register processors
......
......@@ -4,14 +4,14 @@ import dataclasses
import multiprocessing as mp
import os
from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, Union
import PIL
import transformers
from decord import VideoReader, cpu
from openai import BadRequestError
from PIL import Image
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import load_image
from sglang.utils import logger
......@@ -31,8 +31,16 @@ class BaseImageProcessorOutput:
# input_text, with each frame of video/image represented as an image_token
input_text: str
def normalize(self):
for field_name in ["data_hashes", "image_sizes", "all_frames"]:
field = getattr(self, field_name, None)
if field is not None and isinstance(field, list) and len(field) == 0:
setattr(self, field_name, None)
class BaseImageProcessor(ABC):
models = []
def __init__(self, hf_config, server_args, _processor):
self.hf_config = hf_config
self._processor = _processor
......@@ -40,6 +48,9 @@ class BaseImageProcessor(ABC):
# FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330
# Initialize global processor first
init_global_processor(self, server_args)
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
......@@ -113,7 +124,7 @@ class BaseImageProcessor(ABC):
self,
input_ids: list[int],
image_data,
image_token: str,
image_token: Union[int, str],
max_req_input_len: int,
return_text: Optional[bool] = True,
discard_alpha_channel: bool = True,
......@@ -122,9 +133,16 @@ class BaseImageProcessor(ABC):
Each frame of video/image will be replaced by a single image token
Args:
image_token: The token (an id or a string) that represents the image placeholder.
discard_alpha_channel: if True, discards the alpha channel in the returned images
"""
if isinstance(image_token, int):
image_token_str = self._processor.tokenizer.convert_ids_to_tokens(
image_token
)
else:
image_token_str = image_token
if isinstance(input_ids, list) and return_text:
assert len(input_ids) and isinstance(input_ids[0], int)
......@@ -190,13 +208,11 @@ class BaseImageProcessor(ABC):
new_text += text_part
except Exception as e:
import openai
logger.error(f"An exception occurred while loading images: {e}")
raise BadRequestError(
f"An exception occurred while loading images: {e}"
)
continue
return BaseImageProcessorOutput(
image_hashes=hashes,
......@@ -204,6 +220,8 @@ class BaseImageProcessor(ABC):
all_frames=images,
input_text=new_text,
)
out.normalize()
return out
class DummyImageProcessor(BaseImageProcessor):
......@@ -214,9 +232,7 @@ class DummyImageProcessor(BaseImageProcessor):
return None
def init_global_processor(
sglang_image_processor: BaseImageProcessor, server_args: ServerArgs
):
def init_global_processor(sglang_image_processor: BaseImageProcessor, server_args):
"""Init the global processor for multi-modal models."""
global global_processor
transformers.logging.set_verbosity_error()
......
......@@ -16,13 +16,9 @@
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import asyncio
import math
from typing import List, Union
import torch
from PIL import Image, ImageOps
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
......@@ -32,18 +28,24 @@ from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
class DeepseekVL2ImageProcessor(BaseImageProcessor):
models = [DeepseekVL2ForCausalLM]
def __init__(self, hf_config, server_args, _processor):
# with contextlib.suppress(ValueError):
# AutoProcessor.register("DeepseekVLV2Processor", DeepseekVLV2Processor)
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<image>"
@staticmethod
def _process_images_task(image, input_text, max_req_input_len):
return get_global_processor().__call__(
processor = get_global_processor()
res = processor.__call__(
conversations=input_text, images=image, max_req_input_len=max_req_input_len
)
image_token_id = processor.image_token_id
res["im_token_id"] = image_token_id
return res
async def _process_images(self, image_data, input_text, max_req_input_len):
if self.executor is not None:
loop = asyncio.get_event_loop()
......@@ -70,18 +72,15 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
if not isinstance(image_data, list):
image_data = [image_data]
images, image_hashes, image_sizes = [], [], []
images, image_sizes = [], []
image_token = self.IMAGE_TOKEN
base_output = self.load_images(
input_ids, image_data, image_token, max_req_input_len
)
base_output.all_frames = [img.convert("RGB") for img in base_output.all_frames]
res = await self._process_images(
base_output.all_frames, base_output.input_text, max_req_input_len
)
pixel_values = res["images"]
input_ids = res["input_ids"]
images_seq_mask = res["images_seq_mask"]
images_spatial_crop = res["images_spatial_crop"]
batched_images_spatial_crop = []
......@@ -89,16 +88,12 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
return {
"input_ids": input_ids.tolist(),
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"input_ids": res["input_ids"].tolist(),
"pixel_values": res["images"],
"im_token_id": res["im_token_id"],
"image_hashes": base_output.image_hashes,
"image_sizes": image_sizes,
"image_seq_mask": images_seq_mask,
"images_emb_mask": images_seq_mask,
"image_spatial_crop": batched_images_spatial_crop,
"modalities": request_obj.modalities or ["image"],
}
ImageProcessorMapping = {
DeepseekVL2ForCausalLM: DeepseekVL2ImageProcessor,
}
......@@ -17,14 +17,15 @@ logger = logging.get_logger(__name__)
class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
models = [Gemma3ForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<start_of_image>"
self.IM_START_TOKEN_ID = hf_config.boi_token_index
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
@staticmethod
def _process_images_task(images, input_text, _hf_config):
async def _process_single_image(self, images, input_text) -> dict:
if isinstance(images, list) and len(images) == 0:
images = None
processor = get_global_processor()
......@@ -46,19 +47,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
"pixel_values": pixel_values,
}
async def _process_images(self, images, input_text) -> dict:
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
Gemma3SGLangImageProcessor._process_images_task,
images,
input_text,
self.hf_config,
)
else:
return self._process_images_task(images, input_text, self.hf_config)
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
......@@ -82,7 +70,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
discard_alpha_channel=True,
)
ret = await self._process_images(
ret = await self._process_single_image(
input_text=base_output.input_text, images=base_output.all_frames
)
......@@ -93,8 +81,3 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
}
ImageProcessorMapping = {
Gemma3ForConditionalGeneration: Gemma3SGLangImageProcessor,
}
......@@ -11,6 +11,8 @@ from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
class JanusProProcessor(SGLangBaseImageProcessor):
models = [MultiModalityCausalLM]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
......@@ -77,6 +79,3 @@ class JanusProProcessor(SGLangBaseImageProcessor):
"im_end_id": res["im_end_id"],
"im_token_id": res["im_token_id"],
}
ImageProcessorMapping = {MultiModalityCausalLM: JanusProProcessor}
......@@ -15,6 +15,8 @@ from sglang.utils import get_exception_traceback
class LlavaImageProcessor(BaseImageProcessor):
models = [LlavaVidForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
......@@ -143,10 +145,3 @@ class LlavaImageProcessor(BaseImageProcessor):
"image_sizes": image_sizes,
"modalities": request_obj.modalities or ["image"],
}
ImageProcessorMapping = {
LlavaVidForCausalLM: LlavaImageProcessor,
LlavaQwenForCausalLM: LlavaImageProcessor,
LlavaMistralForCausalLM: LlavaImageProcessor,
}
import asyncio
from typing import List, Union
import torch
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
get_global_processor,
......@@ -9,6 +11,8 @@ from sglang.srt.models.minicpmv import MiniCPMV
class MiniCPMVImageProcessor(BaseImageProcessor):
models = [MiniCPMV]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "(<image>./</image>)"
......@@ -69,21 +73,57 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
# Collect special token ids
tokenizer = self._processor.tokenizer
im_start_id = tokenizer.im_start_id
im_token_id = tokenizer.unk_token_id
im_end_id = tokenizer.im_end_id
# default to None so the return dict below never references unbound names
slice_start_id, slice_end_id = None, None
if tokenizer.slice_start_id:
slice_start_id = tokenizer.slice_start_id
slice_end_id = tokenizer.slice_end_id
pixel_values = res["pixel_values"]
tgt_sizes = res["tgt_sizes"]
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
)
if not isinstance(tgt_sizes, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
)
if len(pixel_values) != len(tgt_sizes):
raise ValueError(
"Inconsistent batch lengths, found: "
f"{len(pixel_values)} vs. {len(tgt_sizes)}"
)
# tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
# tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
# per image
if len(pixel_b) != len(tgt_b):
raise ValueError(
"Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
)
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
# per patch
pixel_values_flat += [pixel_n]
tgt_sizes_flat += [tgt_n]
pixel_values = pixel_values_flat
tgt_sizes = torch.stack(tgt_sizes_flat)
return {
"input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": res["pixel_values"],
"tgt_sizes": res["tgt_sizes"],
"pixel_values": pixel_values,
"tgt_sizes": tgt_sizes,
"image_hashes": base_output.image_hashes,
"modalities": request_obj.modalities or ["image"],
"im_start_id": im_start_id,
"im_token_id": im_token_id,
"im_end_id": im_end_id,
"slice_start_id": slice_start_id,
"slice_end_id": slice_end_id,
}
ImageProcessorMapping = {MiniCPMV: MiniCPMVImageProcessor}
......@@ -10,6 +10,8 @@ from sglang.srt.utils import load_image
class MllamaImageProcessor(BaseImageProcessor):
models = [MllamaForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
......@@ -55,6 +57,3 @@ class MllamaImageProcessor(BaseImageProcessor):
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
return image_inputs
ImageProcessorMapping = {MllamaForConditionalGeneration: MllamaImageProcessor}
......@@ -2,6 +2,7 @@ import asyncio
import math
from typing import List, Union
import torch
from PIL import Image
from sglang.srt.managers.image_processor import BaseImageProcessor
......@@ -14,6 +15,8 @@ from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
# Compatible with Qwen2VL and Qwen2_5VL
class Qwen2_5VLImageProcessor(BaseImageProcessor):
models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
......@@ -43,7 +46,7 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
"video_grid_thws": getattr(result, "video_grid_thws", None),
}
async def _process_images(self, images, input_text) -> dict:
async def _process_single_image(self, images, input_text) -> dict:
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
......@@ -138,23 +141,23 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
images = [resize_image(image) for image in base_output.all_frames]
ret = await self._process_images(images, base_output.input_text)
ret = await self._process_single_image(
images=images, input_text=base_output.input_text
)
image_grid_thws = torch.concat([ret["image_grid_thw"]])
video_grid_thws = None
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"image_hashes": base_output.image_hashes,
"modalities": request_obj.modalities or ["image"],
"image_grid_thws": ret["image_grid_thw"],
"video_grid_thws": ret["video_grid_thws"],
"image_grid_thws": image_grid_thws,
"video_grid_thws": video_grid_thws,
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
"im_token_id": self.image_token_id,
"video_token_id": self.video_token_id,
"second_per_grid_ts": ret["second_per_grid_ts"],
}
ImageProcessorMapping = {
Qwen2VLForConditionalGeneration: Qwen2_5VLImageProcessor,
Qwen2_5_VLForConditionalGeneration: Qwen2_5VLImageProcessor,
}
"""
Multimodality utils
"""
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple
from sglang.srt.managers.schedule_batch import ImageInputs
import torch
from torch import nn
from sglang.srt.managers.schedule_batch import (
ImageInputs,
global_server_args_dict,
logger,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.utils import logger
......@@ -115,7 +127,6 @@ class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern)
input_ids_with_image = []
for image_cnt, _ in enumerate(image_grid_thws):
print(f"image_cnt {image_cnt}")
num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
if image_cnt == 0:
non_image_tokens = input_ids[: image_indices[image_cnt]]
......@@ -132,3 +143,161 @@ class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern)
input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
return input_ids_with_image
class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
"""In this pattern, data tokens should be represented as image tokens (e.g. <image><image>....<image>)"""
def __init__(self, image_token_id: torch.Tensor) -> None:
self.image_token_id = image_token_id
def pad_input_tokens(self, input_ids: List[int], image_inputs) -> List[int]:
"""
Replace the image data tokens in input_ids with pad_values, cycling the pad values if there are more image tokens than values.
"""
pad_values = image_inputs.pad_values
assert len(pad_values) != 0
input_ids_tensor = torch.tensor(input_ids)
mask = torch.isin(input_ids_tensor, self.image_token_id)
num_image_tokens = mask.sum().item()
repeated_pad_values = torch.tensor(pad_values).repeat(
num_image_tokens // len(pad_values) + 1
)[:num_image_tokens]
input_ids_tensor[mask] = repeated_pad_values
return input_ids_tensor.tolist()
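To make the replacement concrete, a hedged toy run of the pattern defined above; the token id 151655 and the pad value are invented, and SimpleNamespace stands in for an ImageInputs object that only needs a pad_values attribute here.

# Toy illustration (not part of this file) of MultiModalityDataPaddingPatternImageTokens.
from types import SimpleNamespace

import torch

pattern = MultiModalityDataPaddingPatternImageTokens(
    image_token_id=torch.tensor([151655])
)
image_inputs = SimpleNamespace(pad_values=[987654321])  # image hash % (1 << 30)
padded = pattern.pad_input_tokens([1, 151655, 151655, 151655, 2], image_inputs)
# padded == [1, 987654321, 987654321, 987654321, 2]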
def embed_image_inputs(
image_input: ImageInputs,
input_ids: torch.Tensor,
input_embedding: nn.Embedding,
image_embedding_func,
placeholder_token_ids: List[int] = None,
) -> Optional[torch.Tensor]:
"""
Calculate the image embeddings if necessary, then scatter the result with
the help of a boolean mask denoting the embed locations
Returns:
final embedding: Optional[torch.Tensor]
"""
if image_input is None:
return None
placeholder_token_ids = placeholder_token_ids or image_input.pad_values
# boolean masking the special tokens
special_image_mask = torch.isin(
input_ids,
torch.tensor(placeholder_token_ids, device=input_ids.device),
).unsqueeze(-1)
num_image_tokens_in_input_ids = special_image_mask.sum()
if num_image_tokens_in_input_ids == 0:
# unexpected
inputs_embeds = input_embedding(input_ids)
else:
image_embedding = image_embedding_func(image_input)
if image_embedding.dim() == 2:
num_image_tokens_in_embedding = image_embedding.shape[0]
else:
num_image_tokens_in_embedding = (
image_embedding.shape[0] * image_embedding.shape[1]
)
if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
num_image = num_image_tokens_in_input_ids // image_embedding.shape[1]
image_embedding = image_embedding[:num_image, :]
logger.warning(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
"tokens from image embeddings."
)
# TODO: chunked prefill will split the special tokens from input_ids across several passes, breaking the embedding.
# A fix may be to cache the unfinished image embedding for future reuse and to determine the tokens
# to embed with extend_start_loc and extend_seq_lens.
if num_image_tokens_in_input_ids > num_image_tokens_in_embedding:
chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
if chunked_prefill_size != -1:
logger.warning(
"You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked_prefill"
)
vocab_size = input_embedding.num_embeddings
# Important: clamp after getting original image regions
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# Their values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=vocab_size - 1)
inputs_embeds = input_embedding(input_ids)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
inputs_embeds.device
)
inputs_embeds = inputs_embeds.masked_scatter(
special_image_mask,
image_embedding.to(inputs_embeds.device, inputs_embeds.dtype),
)
return inputs_embeds
def embed_image_embedding(
inputs_embeds: torch.Tensor,
image_embedding: torch.Tensor,
image_bounds: torch.Tensor,
) -> torch.Tensor:
"""
scatter image_embedding into inputs_embeds according to image_bounds
"""
if len(image_bounds) > 0:
image_indices = torch.stack(
[
torch.arange(start, end, dtype=torch.long)
for start, end in image_bounds.tolist()
]
).to(inputs_embeds.device)
inputs_embeds.scatter_(
0,
image_indices.view(-1, 1).repeat(1, inputs_embeds.shape[-1]),
image_embedding.view(-1, image_embedding.shape[-1]),
)
return inputs_embeds
def general_mm_embed_routine(
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
embed_tokens: nn.Embedding,
image_embedding_func: Callable[[ImageInputs], torch.Tensor],
placeholder_token_ids: List[int] = None,
):
"""
A general wrapper that computes the final input embeddings for multimodal models built on a causal language model.
"""
if (
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
):
inputs_embeds = embed_tokens(input_ids)
else:
image = forward_batch.merge_image_inputs()
inputs_embeds = embed_image_inputs(
image_input=image,
input_ids=input_ids,
input_embedding=embed_tokens,
image_embedding_func=image_embedding_func,
placeholder_token_ids=placeholder_token_ids,
)
# once used, image_inputs is useless
# just being defensive here
forward_batch.image_inputs = None
return inputs_embeds
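For orientation, a hedged sketch of how a multimodal model's forward pass might call general_mm_embed_routine; self.language_model, self.get_image_feature, and the keyword arguments of the downstream call are illustrative assumptions, not APIs confirmed by this PR.

# Hypothetical forward() of a VLM wrapper class (a sketch, not a real model).
def forward(self, input_ids, positions, forward_batch: ForwardBatch):
    # Compute text embeddings and scatter image embeddings into them.
    inputs_embeds = general_mm_embed_routine(
        input_ids=input_ids,
        positions=positions,
        forward_batch=forward_batch,
        embed_tokens=self.language_model.get_input_embeddings(),  # assumed accessor
        image_embedding_func=self.get_image_feature,  # assumed per-model callback
    )
    # Hand the pre-computed embeddings to the underlying causal LM (assumed signature).
    return self.language_model(
        input_ids=None,
        positions=positions,
        forward_batch=forward_batch,
        input_embeds=inputs_embeds,
    )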
......@@ -77,6 +77,7 @@ global_server_args_dict = {
"enable_flashmla": ServerArgs.enable_flashmla,
"disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
}
logger = logging.getLogger(__name__)
......@@ -160,7 +161,8 @@ class ImageInputs:
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
# QWen2-VL related
image_grid_thws: List[Tuple[int, int, int]] = None
# [num_of_images, t, h, w]
image_grid_thws: torch.Tensor = None
mrope_position_delta: Optional[torch.Tensor] = None
# Qwen2-VL video related
video_token_id: Optional[int] = None
......@@ -168,7 +170,7 @@ class ImageInputs:
second_per_grid_ts: Optional[List[torch.Tensor]] = None
# deepseek vl2 related
image_seq_mask: Optional[List[torch.Tensor]] = None
images_emb_mask: Optional[List[torch.Tensor]] = None
image_spatial_crop: Optional[List[torch.Tensor]] = None
# The id of the single-image placeholder token
......@@ -182,9 +184,6 @@ class ImageInputs:
slice_end_id: Optional[int] = None
tgt_sizes: Optional[list] = None
# denotes the number of valid image tokens in each image
images_emb_mask: Optional[torch.BoolTensor] = None
@staticmethod
def from_dict(obj: dict):
ret = ImageInputs(
......@@ -204,7 +203,7 @@ class ImageInputs:
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
"image_seq_mask",
"images_emb_mask",
"image_spatial_crop",
"im_token_id",
"im_start_id",
......@@ -212,20 +211,58 @@ class ImageInputs:
"slice_start_id",
"slice_end_id",
"tgt_sizes",
"images_emb_mask",
]
for arg in optional_args:
if arg in obj:
setattr(ret, arg, obj[arg])
# validate
assert (
isinstance(ret.pixel_values, torch.Tensor)
or isinstance(ret.pixel_values, np.ndarray)
or isinstance(ret.pixel_values, list)
)
return ret
def merge(self, other):
def merge(self, other: ImageInputs):
"""
merge image inputs when requests are being merged
"""
assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
if isinstance(self.pixel_values, list):
# in some rare cases, pixel values are a list of patches with different shapes,
# e.g. MiniCPM
self.pixel_values += other.pixel_values
else:
assert (
self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
# args would be stacked along first dim
# usually these are already tensors
stack_args = [
# TODO: merge with image_grid_thws, basically the same thing
"tgt_sizes",
"image_spatial_crop",
]
for arg in stack_args:
if getattr(self, arg, None) is None:
setattr(self, arg, getattr(other, arg, None))
elif getattr(other, arg, None) is not None:
# self and other both not None
setattr(
self,
arg,
torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
)
if self.image_grid_thws is None:
self.image_grid_thws = other.image_grid_thws
elif other.image_grid_thws is not None:
self.image_grid_thws = torch.concat(
[self.image_grid_thws, other.image_grid_thws]
)
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
# Please note that if the `input_ids` is later used in the model forward,
......@@ -233,7 +270,7 @@ class ImageInputs:
# errors in cuda kernels. See also llava.py for example.
self.image_hashes += other.image_hashes
self.pad_values = [x % (1 << 30) for x in self.image_hashes]
# args that need to be merged
optional_args = [
"image_sizes",
"image_offsets",
......@@ -241,13 +278,13 @@ class ImageInputs:
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
"image_seq_mask",
"image_spatial_crop",
"images_emb_mask",
]
for arg in optional_args:
if getattr(self, arg, None) is not None:
setattr(self, arg, getattr(self, arg) + getattr(other, arg))
self_arg = getattr(self, arg, None)
if self_arg is not None:
setattr(self, arg, self_arg + getattr(other, arg))
# other args would be kept intact
class Req:
......
......@@ -179,7 +179,7 @@ class TokenizerManager:
)
# We want to parallelize the image pre-processing so we create an executor for it
# We creat image_processor for any skip_tokenizer_init to make sure we still encode
# We create image_processor for any skip_tokenizer_init to make sure we still encode
# images even with skip_tokenizer_init=False.
self.image_processor = get_image_processor(
self.model_config.hf_config, server_args, _processor
......
......@@ -332,7 +332,7 @@ class ForwardBatch:
return ret
def get_merged_image_inputs(self) -> Optional[ImageInputs]:
def merge_image_inputs(self) -> Optional[ImageInputs]:
"""
Merge all image inputs in the batch into a single ImageInputs object.
......@@ -358,6 +358,16 @@ class ForwardBatch:
return merged
def contains_image_inputs(self) -> bool:
"""Whether any request in this batch carries non-empty image inputs."""
if self.image_inputs is None:
return False
return any(
image_input.pixel_values is not None and len(image_input.pixel_values) > 0
for image_input in self.image_inputs
if image_input is not None
)
def _compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):
......
......@@ -273,7 +273,7 @@ class ModelRunner:
if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
# TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
logger.info(
"Automatically turn off --chunked-prefill-size and disable radix cache for deekseek-vl2."
"Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
)
server_args.chunked_prefill_size = -1
server_args.disable_radix_cache = True
......