Unverified commit ff2ce0b8, authored by Mick and committed by GitHub

refactor: move image processors to separate files (#4229)

parent 0f2a2e3c
@@ -11,11 +11,16 @@ import argparse
import random
import torch
from bench_sglang import EvalArgs, prepare_samples
from data_utils import save_json
from eval_utils import eval_result, get_sampling_params, parse_multi_choice_response
from eval_utils import (
EvalArgs,
eval_result,
get_sampling_params,
prepare_samples,
process_result,
)
from tqdm import tqdm
from transformers import AutoModelForImageTextToText, AutoProcessor
from transformers import AutoModelForImageTextToText, AutoProcessor, GenerationConfig
@torch.no_grad()
@@ -28,7 +33,6 @@ def eval_mmmu(args):
trust_remote_code=True,
)
model = model.eval().cuda()
model = torch.compile(model)
processor = AutoProcessor.from_pretrained(
args.model_path, torch_dtype="auto", device_map="auto"
@@ -38,6 +42,10 @@ def eval_mmmu(args):
out_samples = dict()
sampling_params = get_sampling_params(eval_args)
generation_config = GenerationConfig(
max_new_tokens=sampling_params["max_new_tokens"],
do_sample=False,
)
answer_dict = {}
for sample in tqdm(samples):
@@ -62,7 +70,6 @@ def eval_mmmu(args):
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = processor(
text=[text],
images=[image],
@@ -70,13 +77,16 @@ def eval_mmmu(args):
return_tensors="pt",
).to(model.device)
generated_ids = model.generate(**inputs, **sampling_params)
generated_ids = model.generate(
**inputs, generation_config=generation_config
)
response = processor.decode(
generated_ids[0],
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)[len(text) :]
print(f"response: {response}")
else: # multiple images actually
if sample["question_type"] == "multiple-choice":
all_choices = sample["all_choices"]
@@ -85,24 +95,11 @@ def eval_mmmu(args):
else:
response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"
if sample["question_type"] == "multiple-choice":
pred_ans = parse_multi_choice_response(
response, sample["all_choices"], sample["index2ans"]
)
else: # open question
pred_ans = response
out_samples[sample["id"]] = pred_ans
torch.cuda.empty_cache()
# set ground truth answer
answer_dict[sample["id"]] = {
"question_type": sample["question_type"],
"ground_truth": sample["answer"],
}
process_result(response, sample, answer_dict, out_samples)
args.output_path = f"{args.model_path}_val_hf.json"
save_json(args.output_path, out_samples)
eval_result(output_path=args.output_path, answer_dict=answer_dict)
eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
if __name__ == "__main__":
......
@@ -8,9 +8,9 @@
"""
import argparse
import base64
import dataclasses
import random
import re
from io import BytesIO
from data_utils import save_json
@@ -18,13 +18,14 @@ from eval_utils import (
EvalArgs,
eval_result,
get_sampling_params,
parse_multi_choice_response,
prepare_samples,
process_result,
)
from tqdm import tqdm
from sglang import Engine
from sglang.srt.conversation import chat_templates
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.server_args import ServerArgs
@@ -35,61 +36,76 @@ def eval_mmmu(args):
if server_args.chat_template is None:
raise ValueError("Chat template must be provided for this benchmark")
samples = prepare_samples(eval_args)
backend = Engine(**dataclasses.asdict(server_args))
out_samples = dict()
sampling_params = get_sampling_params(eval_args)
conv = chat_templates[server_args.chat_template].copy()
image_token = conv.image_token
samples = prepare_samples(eval_args)
answer_dict = {}
for sample in tqdm(samples):
prompt = sample["final_input_prompt"]
image = sample["image"]
bytes_io = BytesIO()
image.save(bytes_io, format="PNG")
png_bytes = bytes_io.getvalue()
prompt = re.sub(r"<[^>]*>", image_token, prompt)
buff = BytesIO()
image.save(buff, format="PNG")
base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
prefix = prompt.split("<")[0]
suffix = prompt.split(">")[1]
request_dict = {
"model": "",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": prefix,
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_str}"
},
},
{
"type": "text",
"text": suffix,
},
],
}
],
}
conv = generate_chat_conv(
ChatCompletionRequest(**request_dict),
template_name=server_args.chat_template,
)
prompt = conv.get_prompt()
if image is not None:
gen_out = backend.generate(
prompt=prompt, image_data=[png_bytes], sampling_params=sampling_params
prompt=prompt,
image_data=conv.image_data,
sampling_params=sampling_params,
)["text"]
response = gen_out
else: # multiple images actually
if sample["question_type"] == "multiple-choice":
all_choices = sample["all_choices"]
response = random.choice(all_choices)
else:
response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"
if sample["question_type"] == "multiple-choice":
pred_ans = parse_multi_choice_response(
response, sample["all_choices"], sample["index2ans"]
)
else: # open question
pred_ans = response
out_samples[sample["id"]] = pred_ans
# set ground truth answer
answer_dict[sample["id"]] = {
"question_type": sample["question_type"],
"ground_truth": (
sample["correct_choice"]
if "correct_choice" in samples
else sample["answer"]
),
}
process_result(response, sample, answer_dict, out_samples)
args.output_path = f"{args.model_path}_val_sglang.json"
save_json(args.output_path, out_samples)
eval_result(output_path=args.output_path, answer_dict=answer_dict)
eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
backend.shutdown()
if __name__ == "__main__":
......
@@ -143,6 +143,7 @@ def process_single_sample(data):
# DATA SAVING
def save_json(filename, ds):
print(f"answers saved to: {filename}")
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, "w") as f:
json.dump(ds, f, indent=4)
......
@@ -87,6 +87,7 @@ def set_seed(seed_value):
def prepare_samples(eval_args: EvalArgs):
print("preparing samples...")
# Build prompts
set_seed(eval_args.seed)
@@ -110,6 +111,7 @@ def prepare_samples(eval_args: EvalArgs):
eval_args.dataset_path, subject, split=eval_args.split
)
sub_dataset_list.append(sub_dataset)
# break
# merge all dataset
dataset = concatenate_datasets(sub_dataset_list)
@@ -426,9 +428,26 @@ def calculate_ins_level_acc(results: Dict):
return acc / ins_num
def eval_result(output_path, answer_dict):
def process_result(response, sample, answer_dict, out_samples):
if sample["question_type"] == "multiple-choice":
pred_ans = parse_multi_choice_response(
response, sample["all_choices"], sample["index2ans"]
)
else: # open question
pred_ans = response
out_samples[sample["id"]] = pred_ans
# set ground truth answer
answer_dict[sample["id"]] = {
"question_type": sample["question_type"],
"ground_truth": sample["answer"],
}
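
For reference, a minimal usage sketch of the shared process_result helper that both benchmark scripts now call. Only the dict keys come from the code above; the sample values and the response string are made up for illustration.

out_samples, answer_dict = {}, {}
sample = {
    "id": "validation_Art_1",  # hypothetical sample id
    "question_type": "multiple-choice",
    "all_choices": ["A", "B", "C", "D"],
    "index2ans": {"A": "cubism", "B": "baroque", "C": "fauvism", "D": "dada"},
    "answer": "A",
}
# parse_multi_choice_response should map the free-form response to choice "A",
# so out_samples["validation_Art_1"] becomes "A" and answer_dict records the
# question type and ground truth that eval_result later compares against.
process_result("The answer is (A) cubism.", sample, answer_dict, out_samples)
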
def eval_result(model_answer_path, answer_dict):
print("Evaluating...")
output_dict = json.load(open(output_path))
output_dict = json.load(open(model_answer_path))
# answer_dict = json.load(open(answer_path))
# group by category
@@ -521,7 +540,7 @@ def eval_result(output_path, answer_dict):
"acc": overall_acc,
}
pprint.pprint(printable_results)
out = output_path
out = model_answer_path
with open(out, "w", encoding="utf-8") as outfile:
json.dump(printable_results, outfile)
print(f"eval out saved to {out}")
......
@@ -191,7 +191,7 @@ class Conversation:
for i, (role, message) in enumerate(self.messages):
if i % 2 == 0:
ret += f"[Round {i//2 + round_add_n}]{self.sep}"
ret += f"[Round {i // 2 + round_add_n}]{self.sep}"
if message:
ret += f"{role}{message}{self.sep}"
@@ -453,7 +453,6 @@ def generate_chat_conv(
conv.system_message = getattr(message.content[0], "text", "")
elif msg_role == "user":
# Handle the various types of Chat Request content types here.
role = conv.roles[0]
if isinstance(message.content, str):
conv.append_message(conv.roles[0], message.content)
else:
......
@@ -66,6 +66,7 @@ def get_config(
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model, revision=revision)
......
from __future__ import annotations
from functools import lru_cache
from typing import Optional
from typing import Optional, Tuple
import torch
import torch.nn as nn
@@ -22,47 +22,29 @@ from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.utils import add_prefix
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
)
# Copied from transformers, modeling_qwen2_vl.py
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_emb_torch(
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
sin = repeat(
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
return torch.cat(
[
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., ro_dim:],
],
dim=-1,
)
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
return output
def apply_rotary_pos_emb_vision(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
orig_q_dtype = q.dtype
orig_k_dtype = k.dtype
q, k = q.float(), k.float()
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = q_embed.to(orig_q_dtype)
k_embed = k_embed.to(orig_k_dtype)
return q_embed, k_embed
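
A small shape sketch for the rewritten rotary helper, assuming hypothetical dimensions and the usual outer-product construction of cos/sin over the full head size; the unsqueeze(-2) above broadcasts them over the head axis of the (s, head, head_size) views that the attention forward pass builds before calling it.

import torch

s, heads, head_size = 16, 4, 64  # hypothetical sequence length, head count, head size
q = torch.randn(s, heads, head_size)
k = torch.randn(s, heads, head_size)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_size, 2).float() / head_size))
freqs = torch.outer(torch.arange(s).float(), inv_freq)  # (s, head_size // 2)
emb = torch.cat((freqs, freqs), dim=-1)                 # (s, head_size)
cos, sin = emb.cos(), emb.sin()
q_rot, k_rot = apply_rotary_pos_emb_vision(q, k, cos, sin)
assert q_rot.shape == (s, heads, head_size) and k_rot.dtype == k.dtype
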
class VisionAttention(nn.Module):
@@ -75,8 +57,8 @@ class VisionAttention(nn.Module):
use_context_forward (bool, default to True):
if ``True``, a flash_attn style attention will be applied
Otherwise, a full-sequence attention will be applied.
use_full_precision_softmax (bool, default to False):
if ``True``, the softmax will be performed in full-precision
softmax_in_single_precision (bool, default to False):
if ``True``, the softmax will be performed in single-precision
Otherwise, it will be performed in half-precision
"""
@@ -90,7 +72,7 @@ class VisionAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
dropout: float = 0.0,
use_context_forward: bool = True,
use_full_precision_softmax: bool = False,
softmax_in_single_precision: bool = False,
flatten_batch: bool = False,
prefix: str = "",
):
@@ -113,7 +95,7 @@ class VisionAttention(nn.Module):
head_size=self.head_size,
dropout=dropout,
flatten_batch=flatten_batch,
use_full_precision_softmax=use_full_precision_softmax,
softmax_in_single_precision=softmax_in_single_precision,
)
self.use_qkv_parallel = use_qkv_parallel
@@ -143,7 +125,7 @@ class VisionAttention(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None,
rotary_pos_emb: torch.Tensor = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""
@@ -151,21 +133,17 @@ class VisionAttention(nn.Module):
x: [b, s, embed_dim]
cu_seqlens: [b]
Returns:
[s, b, num_heads * head]
[s, b, head * head_size]
"""
bsz, s, _ = x.shape
head = self.num_attention_heads_per_partition
if self.use_qkv_parallel:
# [b, s, embed_dim] --> [b, s, embed_dim]
qkv, _ = self.qkv_proj(x)
q, k, v = qkv.chunk(3, dim=-1)
# [b, s, embed_dim] --> [b * s, num_heads, head_size]
q, k, v = [
x.reshape(
bsz * s, self.num_attention_heads_per_partition, -1
).contiguous()
for x in (q, k, v)
]
# [b, s, embed_dim] --> [b * s, head, head_size]
q, k, v = [x.reshape(bsz * s, head, -1).contiguous() for x in (q, k, v)]
else:
# [b, s, embed_dim] --> [s, b, embed_dim]
x = rearrange(x, "b s ... -> s b ...")
@@ -173,7 +151,7 @@ class VisionAttention(nn.Module):
qkv, _ = self.qkv_proj(x)
# [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
new_x_shape = qkv.size()[:-1] + (
self.num_attention_heads_per_partition,
head,
3 * self.hidden_size_per_attention_head,
)
qkv = qkv.view(*new_x_shape)
@@ -186,9 +164,12 @@ class VisionAttention(nn.Module):
rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
]
if rotary_pos_emb is not None:
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
if position_embeddings is not None:
cos, sin = position_embeddings
original_shape = q.shape
q, k = q.view(s, head, -1), k.view(s, head, -1)
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
q, k = q.reshape(original_shape), k.reshape(original_shape)
if self.use_qkv_parallel:
pass
@@ -230,12 +211,12 @@ class VisionSdpaAttention(nn.Module):
head_size: int,
dropout: float = 0.0,
flatten_batch: bool = False,
use_full_precision_softmax: bool = False,
softmax_in_single_precision: bool = False,
):
super().__init__()
self.head_size = head_size
self.flatten_batch = flatten_batch
self.use_full_precision_softmax = use_full_precision_softmax
self.softmax_in_single_precision = softmax_in_single_precision
self.dropout = dropout
@staticmethod
@@ -319,14 +300,14 @@ class VisionSdpaAttention(nn.Module):
)
if attention_mask is None:
if self.use_full_precision_softmax:
if self.softmax_in_single_precision:
raise RuntimeError("Empty attention mask")
else:
attention_mask = attention_mask.to(device=q.device)
q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
if self.use_full_precision_softmax:
if self.softmax_in_single_precision:
scale = self.head_size**-0.5
k_transposed = rearrange(k, "b h s d -> b h d s")
attn_weights = torch.matmul(q, k_transposed) * scale
......
# TODO: also move pad_input_ids into this module
import asyncio
import concurrent.futures
import dataclasses
import importlib
import logging
import multiprocessing as mp
import os
from abc import ABC, abstractmethod
from typing import List, Optional, Union
import pkgutil
from functools import lru_cache
import numpy as np
import PIL
import transformers
from decord import VideoReader, cpu
from PIL import Image
from transformers import IMAGE_PROCESSOR_MAPPING
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.managers.image_processors.base_image_processor import (
BaseImageProcessor,
DummyImageProcessor,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import load_image
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
global global_processor
def init_global_processor(server_args: ServerArgs):
"""Init the global processor for multi modal models."""
global global_processor
transformers.logging.set_verbosity_error()
global_processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
@dataclasses.dataclass
class BaseImageProcessorOutput:
image_hashes: list[int]
image_sizes: list[int]
all_frames: list[PIL.Image.Image]
# input_text, with each frame of video/image represented with an image_token
input_text: str
class BaseImageProcessor(ABC):
def __init__(self, hf_config, server_args, _processor):
self.hf_config = hf_config
self._processor = _processor
self.server_args = server_args
# FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
initargs=(server_args,),
max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
)
@abstractmethod
async def process_images_async(
self, image_data, input_text, max_req_input_len, **kwargs
):
pass
def get_estimated_frames_list(self, image_data):
"""
estimate the total frame count from all visual input
"""
# Before processing inputs
estimated_frames_list = []
for image in image_data:
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
# Estimate frames for the video
vr = VideoReader(path, ctx=cpu(0))
num_frames = len(vr)
else:
# For images, each contributes one frame
num_frames = 1
estimated_frames_list.append(num_frames)
return estimated_frames_list
def encode_video(self, video_path, frame_count_limit=None):
if not os.path.exists(video_path):
logger.error(f"Video {video_path} does not exist")
return []
if frame_count_limit == 0:
return []
def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]
vr = VideoReader(video_path, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1) # FPS
frame_idx = [i for i in range(0, len(vr), sample_fps)]
if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
frame_idx = uniform_sample(frame_idx, frame_count_limit)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
return frames
def load_images(
self,
max_req_input_len: int,
input_ids: list,
image_data,
image_token: str,
) -> BaseImageProcessorOutput:
"""
Each frame of video/image will be replaced by a single image token
"""
image_hashes, image_sizes = [], []
all_frames = []
new_text_parts = []
if isinstance(input_ids, list):
assert len(input_ids) and isinstance(input_ids[0], int)
input_text = self._processor.tokenizer.decode(input_ids)
else:
input_text = input_ids
text_parts = input_text.split(image_token)
# roughly calculate the max number of frames under the max_req_input_len limit
def calculate_max_num_frames() -> int:
ret = (max_req_input_len - len(input_ids)) // self.NUM_TOKEN_PER_FRAME
return min(ret, 100)
MAX_NUM_FRAMES = calculate_max_num_frames()
estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
total_frame_count = sum(estimated_frames_list)
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
# Process each input with allocated frames
for image_index, (image, estimated_frames) in enumerate(
zip(image_data, estimated_frames_list)
):
if len(all_frames) >= MAX_NUM_FRAMES:
frames_to_process = 0
else:
frames_to_process = max(1, int(estimated_frames * scaling_factor))
if frames_to_process == 0:
frames = []
else:
try:
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
frames = self.encode_video(
path, frame_count_limit=frames_to_process
)
else:
raw_image, _size = load_image(image)
frames = [raw_image]
if len(frames) == 0:
continue
except FileNotFoundError as e:
print(e)
return None
image_sizes += frames[0].size * len(frames)
image_hashes += [hash(image)] * len(frames)
all_frames += frames
new_text_parts.append(text_parts[image_index])
if frames_to_process != 0:
new_text_parts.append(image_token * len(frames))
assert frames_to_process == len(frames)
new_text_parts.append(text_parts[-1])
input_text = "".join(new_text_parts)
return BaseImageProcessorOutput(
image_hashes, image_sizes, all_frames, input_text
)
class DummyImageProcessor(BaseImageProcessor):
def __init__(self):
pass
async def process_images_async(self, *args, **kwargs):
return None
class LlavaImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(
image_data: Union[str, bytes],
image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[str] = None,
image_processor=None,
):
image_processor = image_processor or global_processor.image_processor
try:
image, image_size = load_image(image_data)
if image_size is not None:
# It is a video with multiple images
image_hash = hash(image_data)
pixel_values = image_processor(image)["pixel_values"]
for _ in range(len(pixel_values)):
pixel_values[_] = pixel_values[_].astype(np.float16)
pixel_values = np.stack(pixel_values, axis=0)
return pixel_values, image_hash, image_size
else:
# It is an image
image_hash = hash(image_data)
if image_aspect_ratio == "pad":
image = expand2square(
image,
tuple(int(x * 255) for x in image_processor.image_mean),
)
pixel_values = image_processor(image.convert("RGB"))[
"pixel_values"
][0]
elif image_aspect_ratio == "anyres" or (
image_aspect_ratio is not None
and "anyres_max" in image_aspect_ratio
):
pixel_values = process_anyres_image(
image, image_processor, image_grid_pinpoints
)
else:
pixel_values = image_processor(image)["pixel_values"][0]
if isinstance(pixel_values, np.ndarray):
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash, image.size
except Exception:
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
async def _process_single_image(
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
LlavaImageProcessor._process_single_image_task,
image_data,
aspect_ratio,
grid_pinpoints,
)
else:
return self._process_single_image_task(
image_data, aspect_ratio, grid_pinpoints
)
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
*args,
**kwargs,
):
if not image_data:
return None
modalities = request_obj.modalities or ["image"]
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
grid_pinpoints = (
self.hf_config.image_grid_pinpoints
if hasattr(self.hf_config, "image_grid_pinpoints")
and "anyres" in aspect_ratio
else None
)
if isinstance(image_data, str):
image_data = [image_data]
if isinstance(image_data, list) and len(image_data) > 0:
if "multi-images" in modalities or "video" in modalities:
# Multiple images
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
pixel_values, image_hashes, image_sizes = [], [], []
res = []
for img_data in image_data:
res.append(
self._process_single_image(
img_data, aspect_ratio, grid_pinpoints
)
)
res = await asyncio.gather(*res)
for pixel_v, image_h, image_s in res:
pixel_values.append(pixel_v)
image_hashes.append(image_h)
image_sizes.append(image_s)
if isinstance(pixel_values[0], np.ndarray):
pixel_values = np.stack(pixel_values, axis=0)
else:
# A single image
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
image_hashes = [image_hash]
image_sizes = [image_size]
else:
raise ValueError(f"Invalid image data: {image_data}")
return {
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"image_sizes": image_sizes,
"modalities": request_obj.modalities or ["image"],
}
class MllamaImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(images, input_text):
# 'input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
return global_processor(images, input_text, return_tensors="pt")
async def _process_single_image(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
MllamaImageProcessor._process_single_image_task,
images,
input_text,
)
else:
image_inputs = self._processor(images, input_text, return_tensors="pt")
return image_inputs
async def process_images_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
if not image_data:
return None
if isinstance(input_text, list):
assert len(input_text) and isinstance(input_text[0], int)
input_text = self._processor.tokenizer.decode(input_text)
if not isinstance(image_data, list):
image_data = [image_data]
if len(image_data) > 0:
images = [load_image(image)[0] for image in image_data]
else:
images = load_image(image_data[0])[0]
image_inputs = await self._process_single_image(images, input_text)
image_inputs["image_hashes"] = [hash(str(image_data))]
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
return image_inputs
class MiniCPMVImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "(<image>./</image>)"
@staticmethod
def _process_images_task(images, input_text):
result = global_processor.__call__(
text=input_text, images=images, return_tensors="pt"
)
return {
"input_ids": result.input_ids,
"pixel_values": result.pixel_values,
"tgt_sizes": result.tgt_sizes,
}
async def _process_images(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
MiniCPMVImageProcessor._process_images_task,
images,
input_text,
)
else:
image_inputs = self._processor(
images=images, text=input_text, return_tensors="pt"
)
return image_inputs
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
request_obj,
max_req_input_len,
):
if not image_data:
return None
if not isinstance(image_data, list):
image_data = [image_data]
base_output = self.load_images(
max_req_input_len, input_ids, image_data, self.IMAGE_TOKEN
)
if base_output is None:
return None
if len(base_output.all_frames) == 0:
return None
res = await self._process_images(
images=base_output.all_frames, input_text=base_output.input_text
)
# Collect special token ids
tokenizer = self._processor.tokenizer
im_start_id = [tokenizer.im_start_id]
im_end_id = [tokenizer.im_end_id]
if tokenizer.slice_start_id:
slice_start_id = [tokenizer.slice_start_id]
slice_end_id = [tokenizer.slice_end_id]
return {
"input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": res["pixel_values"],
"tgt_sizes": res["tgt_sizes"],
"image_hashes": base_output.image_hashes,
"modalities": request_obj.modalities or ["image"],
"im_start_id": im_start_id,
"im_end_id": im_end_id,
"slice_start_id": slice_start_id,
"slice_end_id": slice_end_id,
}
class Qwen2VLImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _image_processor):
self.hf_config = hf_config
self._image_processor = _image_processor
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
initargs=(server_args,),
max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
)
@staticmethod
def _process_single_image_task(
image_data: Union[str, bytes],
image_processor=None,
):
image_processor = image_processor or global_processor.image_processor
try:
image, image_size = load_image(image_data)
if image_size is not None:
# It is a video with multiple images
image_hash = hash(image_data)
process_result = image_processor(image)
pixel_values, image_grid_thws = (
process_result["pixel_values"],
process_result["image_grid_thw"][0],
)
for _ in range(len(pixel_values)):
pixel_values[_] = pixel_values[_].astype(np.float16)
pixel_values = np.stack(pixel_values, axis=0)
image_grid_thws = np.stack(image_grid_thws, axis=0)
return pixel_values, image_hash, image_size, image_grid_thws
else:
# It is an image
image_hash = hash(image_data)
process_result = image_processor(image)
pixel_values, image_grid_thws = (
process_result["pixel_values"],
process_result["image_grid_thw"][0],
)
if isinstance(pixel_values, np.ndarray):
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash, image.size, image_grid_thws
except Exception:
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
async def _process_single_image(self, image_data: Union[bytes, str]):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
Qwen2VLImageProcessor._process_single_image_task,
image_data,
)
else:
return self._process_single_image_task(image_data)
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
*args,
**kwargs,
):
if not image_data:
return None
if isinstance(image_data, list) and len(image_data) > 0:
# Multiple images
if len(image_data) > 1:
pixel_values, image_hashes, image_sizes, image_grid_thws = (
[],
[],
[],
[],
)
res = []
for img_data in image_data:
res.append(self._process_single_image(img_data))
res = await asyncio.gather(*res)
for pixel_v, image_h, image_s, image_thw in res:
pixel_values.append(pixel_v)
image_hashes.append(image_h)
image_sizes.append(image_s)
image_grid_thws.append(image_thw)
if isinstance(pixel_values[0], np.ndarray):
pixel_values = np.concatenate(pixel_values, axis=0)
else:
# A single image
pixel_values, image_hash, image_size, image_grid_thw = (
await self._process_single_image(image_data[0])
)
image_hashes = [image_hash]
image_sizes = [image_size]
image_grid_thws = [image_grid_thw]
elif isinstance(image_data, str) or isinstance(image_data, bytes):
# A single image
pixel_values, image_hash, image_size, image_grid_thw = (
await self._process_single_image(image_data)
)
image_hashes = [image_hash]
image_sizes = [image_size]
image_grid_thws = [image_grid_thw]
else:
raise ValueError(f"Invalid image data: {image_data}")
return {
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"image_sizes": image_sizes,
"modalities": request_obj.modalities or ["image"],
"image_grid_thws": image_grid_thws,
}
class Qwen2_5VLImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
self.NUM_TOKEN_PER_FRAME = 770
@staticmethod
def _process_images_task(images, input_text):
result = global_processor.__call__(
text=input_text, images=images, return_tensors="pt"
)
return {
"input_ids": result.input_ids,
"pixel_values": result.pixel_values,
"image_grid_thws": result.image_grid_thw,
}
async def _process_images(self, images, input_text) -> dict:
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
Qwen2_5VLImageProcessor._process_images_task,
images,
input_text,
)
else:
return self._process_images_task(images, input_text)
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
if not image_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
base_output = self.load_images(
max_req_input_len, input_ids, image_data, image_token
)
ret = await self._process_images(base_output.all_frames, base_output.input_text)
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"image_hashes": base_output.image_hashes,
"modalities": request_obj.modalities or ["image"],
"image_grid_thws": ret["image_grid_thws"],
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
}
IMAGE_PROCESSOR_MAPPING = {}
def get_image_processor(
hf_config, server_args: ServerArgs, processor
) -> BaseImageProcessor:
if "MllamaForConditionalGeneration" in hf_config.architectures:
return MllamaImageProcessor(hf_config, server_args, processor)
elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
return Qwen2VLImageProcessor(hf_config, server_args, processor)
elif "Qwen2_5_VLForConditionalGeneration" in hf_config.architectures:
return Qwen2_5VLImageProcessor(hf_config, server_args, processor)
elif "MiniCPMV" in hf_config.architectures:
return MiniCPMVImageProcessor(hf_config, server_args, processor)
else:
return LlavaImageProcessor(hf_config, server_args, processor.image_processor)
for model_cls, processor_cls in IMAGE_PROCESSOR_MAPPING.items():
if model_cls.__name__ in hf_config.architectures:
return processor_cls(hf_config, server_args, processor)
raise ValueError(
f"No image processor found for architecture: {hf_config.architectures}"
)
def get_dummy_image_processor():
return DummyImageProcessor()
@lru_cache()
def import_image_processors():
package_name = "sglang.srt.managers.image_processors"
package = importlib.import_module(package_name)
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
if not ispkg:
try:
module = importlib.import_module(name)
except Exception as e:
logger.warning(f"Ignore import error when loading {name}: " f"{e}")
continue
if hasattr(module, "ImageProcessorMapping"):
entry = module.ImageProcessorMapping
if isinstance(entry, dict):
for processor_name, cls in entry.items():
IMAGE_PROCESSOR_MAPPING[processor_name] = cls
# also register processors
import_image_processors()
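
With this registry, supporting a new vision-language model no longer means editing get_image_processor: a module dropped into sglang.srt.managers.image_processors only has to expose an ImageProcessorMapping dict, which import_image_processors collects into IMAGE_PROCESSOR_MAPPING. A hedged sketch of such a module, where MyVLMForCausalLM and MyVLMImageProcessor are hypothetical stand-ins, not part of this change:

from sglang.srt.managers.image_processor import BaseImageProcessor


class MyVLMForCausalLM:
    # Stand-in for the model class normally imported from sglang.srt.models;
    # get_image_processor matches its __name__ against hf_config.architectures.
    pass


class MyVLMImageProcessor(BaseImageProcessor):
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

    async def process_images_async(
        self, image_data, input_text, max_req_input_len, **kwargs
    ):
        # Would return the dict the scheduler expects (pixel_values, image_hashes,
        # image_sizes, modalities, ...); omitted in this sketch.
        return None


ImageProcessorMapping = {MyVLMForCausalLM: MyVLMImageProcessor}
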
import concurrent
import concurrent.futures
import dataclasses
import logging
import multiprocessing as mp
import os
from abc import ABC, abstractmethod
from typing import Optional
import PIL
import transformers
from decord import VideoReader, cpu
from PIL import Image
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import load_image

logger = logging.getLogger(__name__)

global global_processor
def get_global_processor():
global global_processor
return global_processor
@dataclasses.dataclass
class BaseImageProcessorOutput:
image_hashes: list[int]
image_sizes: list[tuple[int, int]]
all_frames: list[PIL.Image.Image]
# input_text, with each frame of video/image represented as an image_token
input_text: str
class BaseImageProcessor(ABC):
def __init__(self, hf_config, server_args, _processor):
self.hf_config = hf_config
self._processor = _processor
self.server_args = server_args
# FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
initargs=(
self,
server_args,
),
max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
)
def _build_processor(self, server_args):
"""Init the global processor for multi modal models."""
from sglang.srt.hf_transformers_utils import get_processor
return get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
@abstractmethod
async def process_images_async(
self, image_data, input_text, max_req_input_len, **kwargs
):
pass
def get_estimated_frames_list(self, image_data):
"""
estimate the total frame count from all visual input
"""
# Before processing inputs
estimated_frames_list = []
for image in image_data:
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
# Estimate frames for the video
vr = VideoReader(path, ctx=cpu(0))
num_frames = len(vr)
else:
# For images, each contributes one frame
num_frames = 1
estimated_frames_list.append(num_frames)
return estimated_frames_list
@staticmethod
def encode_video(video_path, frame_count_limit=None):
if not os.path.exists(video_path):
logger.error(f"Video {video_path} does not exist")
return []
if frame_count_limit == 0:
return []
def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]
vr = VideoReader(video_path, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1) # FPS
frame_indices = [i for i in range(0, len(vr), sample_fps)]
if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
frame_indices = uniform_sample(frame_indices, frame_count_limit)
frames = vr.get_batch(frame_indices).asnumpy()
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
return frames
def load_images(
self,
input_ids: list,
image_data,
image_token: str,
max_req_input_len: int,
return_text: Optional[bool] = True,
discard_alpha_channel: bool = True,
) -> BaseImageProcessorOutput:
"""
Each frame of video/image will be replaced by a single image token
"""
image_hashes, image_sizes = [], []
all_frames = []
new_text_parts = []
if isinstance(input_ids, list) and return_text:
assert len(input_ids) and isinstance(input_ids[0], int)
input_text = self._processor.tokenizer.decode(input_ids)
else:
input_text = input_ids
if return_text:
text_parts = input_text.split(image_token)
# roughly calculate the max number of frames under the max_req_input_len limit
MAX_NUM_FRAMES = 30
estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
total_frame_count = sum(estimated_frames_list)
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
assert len(image_data) == len(estimated_frames_list)
# Process each input with allocated frames
for image_index, (image, estimated_frames) in enumerate(
zip(image_data, estimated_frames_list)
):
if len(all_frames) >= MAX_NUM_FRAMES:
max_frames_to_process = 0
else:
max_frames_to_process = max(1, int(estimated_frames * scaling_factor))
if max_frames_to_process == 0:
frames = []
else:
try:
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
frames = BaseImageProcessor.encode_video(
path, frame_count_limit=max_frames_to_process
)
else:
raw_image, _size = load_image(image)
if discard_alpha_channel:
raw_image = raw_image.convert("RGB")
frames = [raw_image]
assert len(frames) != 0
except FileNotFoundError as e:
print(e)
return None
image_sizes += [frames[0].size] * len(frames)
image_hashes += [hash(image)] * len(frames)
all_frames += frames
if return_text:
new_text_parts.append(text_parts[image_index])
if max_frames_to_process != 0:
new_text_parts.append(image_token * len(frames))
assert max_frames_to_process >= len(frames)
if return_text:
new_text_parts.append(text_parts[-1])
input_text = "".join(new_text_parts)
return BaseImageProcessorOutput(
image_hashes, image_sizes, all_frames, input_text
)
class DummyImageProcessor(BaseImageProcessor):
def __init__(self):
pass
async def process_images_async(self, *args, **kwargs):
return None
def init_global_processor(
sglang_image_processor: BaseImageProcessor, server_args: ServerArgs
):
"""Init the global processor for multi-modal models."""
global global_processor
transformers.logging.set_verbosity_error()
global_processor = sglang_image_processor._build_processor(server_args=server_args)
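
The base class keeps the HuggingFace processor in a per-worker global: the fork-based ProcessPoolExecutor runs init_global_processor once in each worker, and the *_task staticmethods fetch the result with get_global_processor(). A standalone sketch of that pattern with stand-in names (illustrative only, not part of the diff; fork context assumed, as above):

import concurrent.futures
import multiprocessing as mp

_worker_state = {}


def _init_worker(config_name):
    # Runs once per forked worker, mirroring init_global_processor above.
    _worker_state["processor"] = f"processor-for-{config_name}"


def _task(text):
    # Mirrors the *_task staticmethods: use the processor owned by this process.
    return f"{_worker_state['processor']} handled {text!r}"


if __name__ == "__main__":
    executor = concurrent.futures.ProcessPoolExecutor(
        initializer=_init_worker,
        initargs=("demo",),
        mp_context=mp.get_context("fork"),
        max_workers=2,
    )
    print(executor.submit(_task, "an image request").result())
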
import asyncio
from typing import List, Optional, Union
import numpy as np
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
get_global_processor,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.models.llava import LlavaMistralForCausalLM, LlavaQwenForCausalLM
from sglang.srt.models.llavavid import LlavaVidForCausalLM
from sglang.srt.utils import load_image, logger
from sglang.utils import get_exception_traceback
class LlavaImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(
image_data: Union[str, bytes],
image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[str] = None,
image_processor=None,
):
processor = get_global_processor()
image_processor = image_processor or processor.image_processor
try:
image, image_size = load_image(image_data)
if image_size is not None:
# It is a video with multiple images
image_hash = hash(image_data)
pixel_values = image_processor(image)["pixel_values"]
for _ in range(len(pixel_values)):
pixel_values[_] = pixel_values[_].astype(np.float16)
pixel_values = np.stack(pixel_values, axis=0)
return pixel_values, image_hash, image_size
else:
# It is an image
image_hash = hash(image_data)
if image_aspect_ratio == "pad":
image = expand2square(
image,
tuple(int(x * 255) for x in image_processor.image_mean),
)
pixel_values = image_processor(image.convert("RGB"))[
"pixel_values"
][0]
elif image_aspect_ratio == "anyres" or (
image_aspect_ratio is not None
and "anyres_max" in image_aspect_ratio
):
pixel_values = process_anyres_image(
image, image_processor, image_grid_pinpoints
)
else:
pixel_values = image_processor(image)["pixel_values"][0]
if isinstance(pixel_values, np.ndarray):
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash, image.size
except Exception:
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
async def _process_single_image(
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
LlavaImageProcessor._process_single_image_task,
image_data,
aspect_ratio,
grid_pinpoints,
)
else:
return self._process_single_image_task(
image_data, aspect_ratio, grid_pinpoints
)
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
*args,
**kwargs,
):
if not image_data:
return None
modalities = request_obj.modalities or ["image"]
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
grid_pinpoints = (
self.hf_config.image_grid_pinpoints
if hasattr(self.hf_config, "image_grid_pinpoints")
and "anyres" in aspect_ratio
else None
)
if isinstance(image_data, str):
image_data = [image_data]
if isinstance(image_data, list) and len(image_data) > 0:
if "multi-images" in modalities or "video" in modalities:
# Multiple images
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
pixel_values, image_hashes, image_sizes = [], [], []
res = []
for img_data in image_data:
res.append(
self._process_single_image(
img_data, aspect_ratio, grid_pinpoints
)
)
res = await asyncio.gather(*res)
for pixel_v, image_h, image_s in res:
pixel_values.append(pixel_v)
image_hashes.append(image_h)
image_sizes.append(image_s)
if isinstance(pixel_values[0], np.ndarray):
pixel_values = np.stack(pixel_values, axis=0)
else:
# A single image
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
image_hashes = [image_hash]
image_sizes = [image_size]
else:
raise ValueError(f"Invalid image data: {image_data}")
return {
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"image_sizes": image_sizes,
"modalities": request_obj.modalities or ["image"],
}
ImageProcessorMapping = {
LlavaVidForCausalLM: LlavaImageProcessor,
LlavaQwenForCausalLM: LlavaImageProcessor,
LlavaMistralForCausalLM: LlavaImageProcessor,
}
import asyncio
from typing import List, Union
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
get_global_processor,
)
from sglang.srt.models.minicpmv import MiniCPMV
class MiniCPMVImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "(<image>./</image>)"
@staticmethod
def _process_images_task(images, input_text):
processor = get_global_processor()
result = processor.__call__(text=input_text, images=images, return_tensors="pt")
return {
"input_ids": result.input_ids,
"pixel_values": result.pixel_values,
"tgt_sizes": result.tgt_sizes,
}
async def _process_images(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
MiniCPMVImageProcessor._process_images_task,
images,
input_text,
)
else:
image_inputs = self._processor(
images=images, text=input_text, return_tensors="pt"
)
return image_inputs
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
request_obj,
max_req_input_len,
):
if not image_data:
return None
if not isinstance(image_data, list):
image_data = [image_data]
base_output = self.load_images(
input_ids, image_data, self.IMAGE_TOKEN, max_req_input_len
)
if base_output is None:
return None
if len(base_output.all_frames) == 0:
return None
res = await self._process_images(
images=base_output.all_frames, input_text=base_output.input_text
)
# Collect special token ids
tokenizer = self._processor.tokenizer
im_start_id = tokenizer.im_start_id
im_end_id = tokenizer.im_end_id
if tokenizer.slice_start_id:
slice_start_id = tokenizer.slice_start_id
slice_end_id = tokenizer.slice_end_id
return {
"input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": res["pixel_values"],
"tgt_sizes": res["tgt_sizes"],
"image_hashes": base_output.image_hashes,
"modalities": request_obj.modalities or ["image"],
"im_start_id": im_start_id,
"im_end_id": im_end_id,
"slice_start_id": slice_start_id,
"slice_end_id": slice_end_id,
}
ImageProcessorMapping = {MiniCPMV: MiniCPMVImageProcessor}
import asyncio
from typing import List, Union
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
get_global_processor,
)
from sglang.srt.models.mllama import MllamaForConditionalGeneration
from sglang.srt.utils import load_image
class MllamaImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(images, input_text):
# 'input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
return get_global_processor()(images, input_text, return_tensors="pt")
async def _process_single_image(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
MllamaImageProcessor._process_single_image_task,
images,
input_text,
)
else:
image_inputs = self._processor(images, input_text, return_tensors="pt")
return image_inputs
async def process_images_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
if not image_data:
return None
if isinstance(input_text, list):
assert len(input_text) and isinstance(input_text[0], int)
input_text = self._processor.tokenizer.decode(input_text)
if not isinstance(image_data, list):
image_data = [image_data]
if len(image_data) > 0:
images = [load_image(image)[0] for image in image_data]
else:
images = load_image(image_data[0])[0]
image_inputs = await self._process_single_image(images, input_text)
image_inputs["image_hashes"] = [hash(str(image_data))]
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
return image_inputs
ImageProcessorMapping = {MllamaForConditionalGeneration: MllamaImageProcessor}
import asyncio
import math
from typing import List, Union
from PIL import Image
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
get_global_processor,
)
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
# Compatible with Qwen2VL and Qwen2_5VL
class Qwen2_5VLImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
self.image_token_id = hf_config.image_token_id
self.video_token_id = hf_config.video_token_id
self.NUM_TOKEN_PER_FRAME = 770
self.IMAGE_FACTOR = 28
self.MIN_PIXELS = 4 * 28 * 28
self.MAX_PIXELS = 16384 * 28 * 28
self.MAX_RATIO = 200
@staticmethod
def _process_images_task(images, input_text, _hf_config):
if isinstance(images, list) and len(images) == 0:
images = None
result = get_global_processor().__call__(
text=[input_text], images=images, padding=True, return_tensors="pt"
)
return {
"input_ids": result.input_ids,
"pixel_values": getattr(result, "pixel_values", None),
"image_grid_thw": getattr(result, "image_grid_thw", None),
"second_per_grid_ts": getattr(result, "second_per_grid_ts", None),
"video_grid_thws": getattr(result, "video_grid_thws", None),
}
async def _process_images(self, images, input_text) -> dict:
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
Qwen2_5VLImageProcessor._process_images_task,
images,
input_text,
self.hf_config,
)
else:
return self._process_images_task(images, input_text, self.hf_config)
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
if not image_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
base_output = self.load_images(
input_ids,
image_data,
image_token,
max_req_input_len,
)
def smart_resize(
height: int,
width: int,
factor: int = self.IMAGE_FACTOR,
min_pixels: int = self.MIN_PIXELS,
max_pixels: int = self.MAX_PIXELS,
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > self.MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {self.MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
def resize_image(image, size_factor: int = self.IMAGE_FACTOR) -> Image.Image:
width, height = image.size
min_pixels = self.MIN_PIXELS
max_pixels = self.MAX_PIXELS
resized_height, resized_width = smart_resize(
height,
width,
factor=size_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
image = image.resize((resized_width, resized_height))
return image
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
images = [resize_image(image) for image in base_output.all_frames]
ret = await self._process_images(images, base_output.input_text)
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"image_hashes": base_output.image_hashes,
"modalities": request_obj.modalities or ["image"],
"image_grid_thws": ret["image_grid_thw"],
"video_grid_thws": ret["video_grid_thws"],
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
"im_token_id": self.image_token_id,
"video_token_id": self.video_token_id,
"second_per_grid_ts": ret["second_per_grid_ts"],
}
ImageProcessorMapping = {
Qwen2VLForConditionalGeneration: Qwen2_5VLImageProcessor,
Qwen2_5_VLForConditionalGeneration: Qwen2_5VLImageProcessor,
}
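
A quick arithmetic check of the smart_resize logic above for a hypothetical 1000x700 frame, using the constants of this processor (factor 28, MIN_PIXELS = 4 * 28 * 28, MAX_PIXELS = 16384 * 28 * 28); round_by_factor is restated so the snippet stands alone.

def round_by_factor(number: int, factor: int) -> int:
    return round(number / factor) * factor


height, width, factor = 700, 1000, 28
h_bar = max(factor, round_by_factor(height, factor))  # 700, already a multiple of 28
w_bar = max(factor, round_by_factor(width, factor))   # 1008 = 36 * 28
assert 4 * 28 * 28 <= h_bar * w_bar <= 16384 * 28 * 28  # 705600 pixels, within bounds
# Neither rescaling branch triggers, so resize_image would return a 1008 x 700 image.
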
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.utils import logger
class MultiModalityDataPaddingPattern:
"""
Data tokens (like image tokens) often need special handling during padding
to maintain model compatibility. This class provides the interface for
implementing different padding strategies for data tokens
"""
@abstractmethod
def pad_input_tokens(
self, input_ids: List[int], image_inputs: ImageInputs
) -> List[int]:
"""
Pad the input ids sequence containing data tokens, and replace them with pad_values
"""
pass
class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
"""In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)
This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
"""
def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
self.data_token_id_pairs = data_token_pairs
def pad_input_tokens(
self, input_ids: List[int], image_inputs: ImageInputs
) -> List[int]:
"""
This function will replace the data tokens in between with pad_values accordingly
"""
pad_values = image_inputs.pad_values
data_token_pairs = self.data_token_id_pairs
image_inputs.image_offsets = []
if data_token_pairs is None:
data_token_pairs = [image_inputs.im_start_id, image_inputs.im_end_id]
if data_token_pairs is None:
logger.warning(
"No data_token_pairs provided, RadixAttention might be influenced."
)
return input_ids
start_token_ids = [s for s, _e in data_token_pairs]
end_tokens_ids = [e for _s, e in data_token_pairs]
# First start token marks new data
data_start_token = start_token_ids[0]
padded_ids = []
last_idx = 0
data_idx = -1
start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]
if len(start_indices) != len(end_indices):
return input_ids
for start_idx, end_idx in zip(start_indices, end_indices):
padded_ids.extend(input_ids[last_idx : start_idx + 1])
if input_ids[start_idx] == data_start_token:
data_idx += 1
image_inputs.image_offsets += [start_idx]
num_tokens = end_idx - start_idx - 1
pad_value = pad_values[data_idx]
padded_ids.extend([pad_value] * num_tokens)
last_idx = end_idx
padded_ids.extend(input_ids[last_idx:])
assert len(input_ids) == len(padded_ids)
return padded_ids
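
A worked example of the token-pair pattern, with hypothetical token ids and pad values; the stand-in object below carries only the pad_values attribute the method reads here, and image_offsets is assigned by the method itself.

IM_START, IM_END = 1001, 1002  # hypothetical <image> / </image> token ids


class _FakeImageInputs:
    # Minimal stand-in for ImageInputs in this sketch.
    pad_values = [42]  # one pad value per image


pattern = MultiModalityDataPaddingPatternTokenPairs([(IM_START, IM_END)])
padded = pattern.pad_input_tokens([5, IM_START, 7, 7, 7, IM_END, 9], _FakeImageInputs())
# padded == [5, 1001, 42, 42, 42, 1002, 9]: the three tokens between the pair are
# replaced by this image's pad value, and image_offsets records the start index 1.
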
class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
"""In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
which needs first to be expanded to multiple tokens, then replaced with their padding values
This strategy should be used when a single data token represents content that should
be expanded to multiple tokens during processing.
"""
def __init__(
self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
) -> None:
self.num_data_token_calc_func = num_data_token_calc_func
def pad_input_tokens(
self, input_ids: List[int], image_inputs: ImageInputs
) -> List[int]:
"""
This function will follow the procedure of:
1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
2. the padded data tokens will be replaced with their pad_values
"""
image_grid_thws = image_inputs.image_grid_thws
pad_values = image_inputs.pad_values
image_indices = [
idx
for idx, token in enumerate(input_ids)
if token == image_inputs.im_token_id
]
image_inputs.image_offsets = []
input_ids_with_image = []
for image_cnt, _ in enumerate(image_grid_thws):
print(f"image_cnt {image_cnt}")
num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
if image_cnt == 0:
non_image_tokens = input_ids[: image_indices[image_cnt]]
else:
non_image_tokens = input_ids[
image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
]
input_ids_with_image.extend(non_image_tokens)
image_inputs.image_offsets.append(len(input_ids_with_image))
pad_ids = pad_values * (
(num_image_tokens + len(pad_values)) // len(pad_values)
)
input_ids_with_image.extend(pad_ids[:num_image_tokens])
input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
return input_ids_with_image
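
And a worked example of the single-token pattern, assuming a hypothetical placeholder token id and a calculator that maps a (t, h, w) grid to a token count (a 2x2 spatial merge here).

IMG = 2001  # hypothetical image placeholder token id


class _FakeSingleTokenInputs:
    # Minimal stand-in for ImageInputs in this sketch.
    pad_values = [42]
    im_token_id = IMG
    image_grid_thws = [(1, 4, 4)]


pattern = MultModalityDataPaddingPatternSingleToken(
    num_data_token_calc_func=lambda thw: thw[0] * thw[1] * thw[2] // 4
)
padded = pattern.pad_input_tokens([5, IMG, 9], _FakeSingleTokenInputs())
# The single placeholder expands to 1 * 4 * 4 // 4 = 4 pad-value tokens:
# padded == [5, 42, 42, 42, 42, 9].
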
@@ -158,15 +158,19 @@ class ImageInputs:
image_grid_thws: List[Tuple[int, int, int]] = None
mrope_position_delta: Optional[torch.Tensor] = None
# MiniCPMV related
# The id of the single-image placeholder token
im_token_id: Optional[torch.Tensor] = None
# All the images in the batch should share the same special image
# bound token ids.
im_start_id: Optional[torch.Tensor] = None
im_end_id: Optional[torch.Tensor] = None
slice_start_id: Optional[torch.Tensor] = None
slice_end_id: Optional[torch.Tensor] = None
im_start_id: Optional[int] = None
im_end_id: Optional[int] = None
slice_start_id: Optional[int] = None
slice_end_id: Optional[int] = None
tgt_sizes: Optional[list] = None
# marks which image tokens are valid in each image
images_emb_mask: Optional[torch.BoolTensor] = None
@staticmethod
def from_dict(obj: dict):
ret = ImageInputs(
......@@ -186,11 +190,13 @@ class ImageInputs:
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
"im_token_id",
"im_start_id",
"im_end_id",
"slice_start_id",
"slice_end_id",
"tgt_sizes",
"images_emb_mask",
]
for arg in optional_args:
if arg in obj:
......
......@@ -455,7 +455,7 @@ def pt_weights_iterator(
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
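# weights_only=True restricts unpickling to tensors and plain Python containers,
# avoiding arbitrary code execution when loading untrusted .bin checkpoints.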
state = torch.load(bin_file, map_location="cpu")
state = torch.load(bin_file, map_location="cpu", weights_only=True)
yield from state.items()
del state
torch.cuda.empty_cache()
......
......@@ -41,7 +41,6 @@ from torch import nn
from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig
from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import (
......@@ -51,6 +50,9 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import (
MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.utils import set_default_torch_dtype
......@@ -186,19 +188,16 @@ class Idefics2EncoderLayer(nn.Module):
) -> None:
super().__init__()
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
num_heads_per_partition = divide(self.num_heads, tp_size)
self.self_attn = VisionAttention(
embed_dim=config.hidden_size,
num_heads=num_heads_per_partition,
num_heads=self.num_heads,
projection_size=config.intermediate_size,
use_qkv_parallel=True,
quant_config=quant_config,
dropout=config.attention_dropout,
use_context_forward=False,
use_full_precision_softmax=True,
softmax_in_single_precision=True,
flatten_batch=False,
prefix=add_prefix("self_attn", prefix),
)
......@@ -708,21 +707,21 @@ class MiniCPMVBaseModel(nn.Module):
self,
input_ids: torch.Tensor,
pad_values: List[int],
im_start_id: torch.Tensor,
im_end_id: torch.Tensor,
slice_start_id: Optional[torch.Tensor] = None,
slice_end_id: Optional[torch.Tensor] = None,
im_start_id: int,
im_end_id: int,
slice_start_id: Optional[int] = None,
slice_end_id: Optional[int] = None,
) -> torch.Tensor:
"""
Returns a tensor with the start and end token positions of the images
"""
# All the images in the batch should share the same special image
# bound token ids.
start_cond = input_ids == im_start_id[0]
end_cond = input_ids == im_end_id[0]
start_cond = input_ids == im_start_id
end_cond = input_ids == im_end_id
if slice_start_id is not None:
start_cond |= input_ids == slice_start_id[0]
end_cond |= input_ids == slice_end_id[0]
start_cond |= input_ids == slice_start_id
end_cond |= input_ids == slice_end_id
(image_start_tokens,) = torch.where(start_cond)
image_start_tokens += 1
......@@ -733,6 +732,8 @@ class MiniCPMVBaseModel(nn.Module):
if (
len(image_start_tokens) + 1 == len(image_end_tokens)
and input_ids[0] in pad_values
and len(image_start_tokens) != 0
and len(image_end_tokens) != 0
and image_end_tokens[0] < image_start_tokens[0]
):
image_start_tokens = torch.cat(
......@@ -897,9 +898,12 @@ class MiniCPMVBaseModel(nn.Module):
forward_batch: ForwardBatch,
**kwargs: Any,
) -> torch.Tensor:
if forward_batch.image_inputs is not None and forward_batch.image_inputs != [
None
]:
if (
forward_batch.image_inputs is not None
and len(forward_batch.image_inputs) > 0
and forward_batch.image_inputs[0] is not None
):
# TODO: batch
kwargs.update(
{
"pixel_values": (
......@@ -1135,81 +1139,16 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
return self.resampler(vision_embedding, tgt_sizes)
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
if not isinstance(image_inputs.im_start_id, list) or not isinstance(
image_inputs.im_end_id, list
):
return input_ids
new_input_ids = []
last_idx = 0
image_idx = -1
image_inputs.image_offsets = []
# Get all special token IDs
im_start_id = (
image_inputs.im_start_id[0].item()
if isinstance(image_inputs.im_start_id[0], torch.Tensor)
else image_inputs.im_start_id[0]
)
im_end_id = (
image_inputs.im_end_id[0].item()
if isinstance(image_inputs.im_end_id[0], torch.Tensor)
else image_inputs.im_end_id[0]
)
slice_start_id = (
image_inputs.slice_start_id[0].item()
if isinstance(image_inputs.slice_start_id[0], torch.Tensor)
else image_inputs.slice_start_id[0]
)
slice_end_id = (
image_inputs.slice_end_id[0].item()
if isinstance(image_inputs.slice_end_id[0], torch.Tensor)
else image_inputs.slice_end_id[0]
)
# Find all start and end positions for both types
start_indices = [
i
for i, x in enumerate(input_ids)
if x == im_start_id or x == slice_start_id
]
end_indices = [
i for i, x in enumerate(input_ids) if x == im_end_id or x == slice_end_id
]
if len(start_indices) != len(end_indices):
return input_ids
# Process each region (both image and slice)
for start_idx, end_idx in zip(start_indices, end_indices):
# Add non-image tokens before this region
new_input_ids.extend(
input_ids[last_idx : start_idx + 1]
) # include start token
is_image_start = input_ids[start_idx] == im_start_id
if is_image_start:
image_inputs.image_offsets += [start_idx]
image_idx += 1
num_tokens = end_idx - start_idx - 1 # exclude start and end tokens
# Generate pad_ids
pad_values = [image_inputs.pad_values[image_idx]]
pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
pad_ids = pad_ids[:num_tokens]
# Add pad_ids
new_input_ids.extend(pad_ids)
im_start_id: int = image_inputs.im_start_id
im_end_id: int = image_inputs.im_end_id
slice_start_id: int = image_inputs.slice_start_id
slice_end_id: int = image_inputs.slice_end_id
# Update last_idx to after end token
last_idx = end_idx
media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
# Add remaining tokens after last region
new_input_ids.extend(input_ids[last_idx:])
assert len(input_ids) == len(new_input_ids)
return new_input_ids
return pattern.pad_input_tokens(input_ids, image_inputs)
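# Both the full-image pair (im_start_id, im_end_id) and the slice pair
# (slice_start_id, slice_end_id) go through the same token-pair pattern, so slice
# regions are padded with the pad value of the image they belong to.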
_SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}
......
......@@ -202,7 +202,7 @@ class MllamaVisionEncoderLayer(nn.Module):
quant_config=None,
dropout=0.0,
use_context_forward=False,
use_full_precision_softmax=False,
softmax_in_single_precision=False,
flatten_batch=False,
prefix=add_prefix("self_attn", prefix),
)
......
......@@ -47,6 +47,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.multi_modality_padding import (
MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
......@@ -121,12 +124,12 @@ class Qwen2_5_VisionBlock(nn.Module):
self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
if attn_implementation == "sdpa":
use_context_forward = False
use_full_precision_softmax = False
softmax_in_single_precision = False
elif attn_implementation == "flash_attention_2":
use_full_precision_softmax = False
softmax_in_single_precision = False
use_context_forward = True
elif attn_implementation == "eager":
use_full_precision_softmax = True
softmax_in_single_precision = True
use_context_forward = False
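# In short: "eager" computes softmax in single precision without the context forward,
# "sdpa" uses neither, and "flash_attention_2" uses the context forward with softmax
# in the working precision.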
self.attn = VisionAttention(
......@@ -135,7 +138,7 @@ class Qwen2_5_VisionBlock(nn.Module):
projection_size=dim,
use_qkv_parallel=False,
use_context_forward=use_context_forward,
use_full_precision_softmax=use_full_precision_softmax,
softmax_in_single_precision=softmax_in_single_precision,
flatten_batch=True,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
......@@ -149,12 +152,17 @@ class Qwen2_5_VisionBlock(nn.Module):
)
def forward(
self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.norm1(x)
hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
attn = self.attn(
hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
)
attn = rearrange(attn, "b s ... -> s b ...")
x = x + attn
......@@ -443,6 +451,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
)
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
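# Duplicate the rotary frequencies along the last dim and precompute the (cos, sin)
# tables once, so each block now consumes position_embeddings rather than the raw
# rotary_pos_emb.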
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
......@@ -457,7 +467,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
x = blk(x, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb)
x = blk(
x, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings
)
# adapter
x = self.merger(x)
......@@ -522,50 +534,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
return num_image_tokens
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
new_input_ids = []
last_idx = 0
image_idx = -1
image_inputs.image_offsets = []
# Get all special token IDs
im_start_id = image_inputs.im_start_id
im_end_id = image_inputs.im_end_id
# Find all start and end positions for both types
start_indices = [i for i, x in enumerate(input_ids) if x == im_start_id]
end_indices = [i for i, x in enumerate(input_ids) if x == im_end_id]
if len(start_indices) != len(end_indices):
return input_ids
# Process each region (both image and slice)
for start_idx, end_idx in zip(start_indices, end_indices):
# Add non-image tokens before this region
new_input_ids.extend(input_ids[last_idx : start_idx + 1])
is_image_start = input_ids[start_idx] == im_start_id
if is_image_start:
image_inputs.image_offsets += [start_idx]
image_idx += 1
num_tokens = end_idx - start_idx - 1 # exclude start and end tokens
# Generate pad_ids
pad_values = [image_inputs.pad_values[image_idx]]
pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
pad_ids = pad_ids[:num_tokens]
# Add pad_ids
new_input_ids.extend(pad_ids)
im_start_id: int = image_inputs.im_start_id
im_end_id: int = image_inputs.im_end_id
# Update last_idx to after end token
last_idx = end_idx
media_token_pairs = [(im_start_id, im_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
# Add remaining tokens after last region
new_input_ids.extend(input_ids[last_idx:])
assert len(input_ids) == len(new_input_ids)
return new_input_ids
return pattern.pad_input_tokens(input_ids, image_inputs)
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
......@@ -629,7 +605,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
for i, image in enumerate(forward_batch.image_inputs):
if image is None:
if image is None or image.pixel_values is None:
continue
start_idx = extend_start_loc_cpu[i]
prefix_len = prefix_lens_cpu[i]
......
......@@ -42,6 +42,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.multi_modality_padding import (
MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
......@@ -137,12 +140,12 @@ class Qwen2VisionBlock(nn.Module):
mlp_hidden_dim = int(dim * mlp_ratio)
if attn_implementation == "sdpa":
use_context_forward = False
use_full_precision_softmax = False
softmax_in_single_precision = False
elif attn_implementation == "flash_attention_2":
use_full_precision_softmax = False
softmax_in_single_precision = False
use_context_forward = True
elif attn_implementation == "eager":
use_full_precision_softmax = True
softmax_in_single_precision = True
use_context_forward = False
self.attn = VisionAttention(
......@@ -151,7 +154,7 @@ class Qwen2VisionBlock(nn.Module):
projection_size=dim,
use_qkv_parallel=False,
use_context_forward=use_context_forward,
use_full_precision_softmax=use_full_precision_softmax,
softmax_in_single_precision=softmax_in_single_precision,
flatten_batch=True,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
......@@ -165,12 +168,17 @@ class Qwen2VisionBlock(nn.Module):
)
def forward(
self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.norm1(x)
hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
attn = self.attn(
hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
)
attn = rearrange(attn, "b s ... -> s b ...")
x = x + attn
......@@ -392,7 +400,8 @@ class Qwen2VisionTransformer(nn.Module):
# compute position embedding
rotary_pos_emb = self.rot_pos_emb(grid_thw)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
......@@ -402,7 +411,7 @@ class Qwen2VisionTransformer(nn.Module):
# transformers
x = x.unsqueeze(1)
for blk in self.blocks:
x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
# adapter
x = self.merger(x)
......@@ -425,40 +434,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
)
return num_image_tokens
# Use grid_t * grid_w * grid_h to pad tokens for each image
# add replaced padding by unique image hash
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
image_grid_thws = image_inputs.image_grid_thws
pad_values = image_inputs.pad_values
image_indices = [
idx
for idx, token in enumerate(input_ids)
if token == self.config.image_token_id
]
image_inputs.image_offsets = []
input_ids_with_image = []
for image_cnt, _ in enumerate(image_grid_thws):
num_image_tokens = self.calculate_num_image_tokens(
image_grid_thws[image_cnt]
)
if image_cnt == 0:
non_image_tokens = input_ids[: image_indices[image_cnt]]
else:
non_image_tokens = input_ids[
image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
]
input_ids_with_image.extend(non_image_tokens)
image_inputs.image_offsets.append(len(input_ids_with_image))
pad_ids = pad_values * (
(num_image_tokens + len(pad_values)) // len(pad_values)
)
input_ids_with_image.extend(pad_ids[:num_image_tokens])
input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
return input_ids_with_image
def __init__(
self,
config: Qwen2VLConfig,
......@@ -494,6 +469,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
# Pad the span between each vision start/end token pair for every image
# with that image's unique hash value
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
# Get all special token IDs
im_start_id: int = image_inputs.im_start_id
im_end_id: int = image_inputs.im_end_id
media_token_pairs = [(im_start_id, im_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return pattern.pad_input_tokens(input_ids, image_inputs)
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
......@@ -556,12 +542,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
for i, image in enumerate(forward_batch.image_inputs):
if image is None:
if image is None or image.pixel_values is None:
continue
start_idx = extend_start_loc_cpu[i]
prefix_len = prefix_lens_cpu[i]
pixel_values = image.pixel_values.clone()
pixel_values = torch.tensor(image.pixel_values, device="cuda")
image_grid_thws = torch.tensor(
np.array(image.image_grid_thws), device="cuda"
)
......@@ -579,15 +565,13 @@ class Qwen2VLForConditionalGeneration(nn.Module):
image_grid_thws[idx]
)
left_idx = start_idx + (image_offset - prefix_len)
right_idx = (
start_idx + (image_offset - prefix_len) + num_image_tokens
)
left_idx = start_idx + (image_offset - prefix_len + 1)
right_idx = left_idx + num_image_tokens
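# image_offsets recorded by MultiModalityDataPaddingPatternTokenPairs point at the
# vision-start token itself, so the +1 above skips it before copying image embeddings.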
inputs_embeds[left_idx:right_idx] = image_embeds[
image_embeds_offset : image_embeds_offset + num_image_tokens
]
image_embeds_offset += num_image_tokens
input_ids = None
hidden_states = self.model(
input_ids=input_ids,
......