Commit 8e55a526 authored by Jin Zhen Jiang's avatar Jin Zhen Jiang

feat: add mineru-vlm backend.

parent 6f8a9610
import ast
import math
import re
from functools import partial, reduce
from typing import Dict, Optional, Union
import numpy as np
import torch
from PIL import Image
from transformers.image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_size_dict,
)
from transformers.image_transforms import (
convert_to_rgb,
normalize,
rescale,
resize,
to_channel_dimension_format,
)
from transformers.image_utils import (
ChannelDimension,
PILImageResampling,
to_numpy_array,
)
from transformers.utils import TensorType
def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
original_width, original_height = original_size
best_fit = (0, 0)
max_effective_resolution = 0
min_wasted_resolution = float("inf")
for width, height in possible_resolutions:
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
wasted_resolution = (width * height) - effective_resolution
if effective_resolution > max_effective_resolution or (
effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (width, height)
return best_fit
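# A hedged, illustrative sketch (not part of the original code): for an 800x600
# image and candidate grids of (384, 384), (768, 384), (384, 768) and (768, 768),
# the (768, 768) candidate preserves the most effective pixels after downscaling:
#
#   best = select_best_resolution((800, 600), [(384, 384), (768, 384), (384, 768), (768, 768)])
#   # best == (768, 768)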
def divide_to_patches(image, patch_size):
patches = []
width, height = image.size
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
box = (j, i, j + patch_size, i + patch_size)
patch = image.crop(box)
patches.append(patch)
return patches
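# Illustrative sketch (sizes are assumptions): a 768x768 image with patch_size=384
# yields four 384x384 patches, scanned row by row (top-left, top-right, bottom-left, bottom-right):
#
#   patches = divide_to_patches(Image.new("RGB", (768, 768)), 384)
#   # len(patches) == 4 and patches[0].size == (384, 384)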
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
if pil_img.mode == "L":
pil_img = pil_img.convert("RGB")
if width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
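# Illustrative sketch: a landscape 800x600 image becomes an 800x800 square with the
# original pasted at a vertical offset of 100 pixels over the background color:
#
#   squared = expand2square(Image.new("RGB", (800, 600)), (127, 127, 127))
#   # squared.size == (800, 800)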
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
grid_pinpoints = [
(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
]
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints) # type: ignore
width, height = select_best_resolution(image_size, possible_resolutions)
return width // patch_size, height // patch_size
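# Illustrative sketch (the pinpoints string is an assumption, not necessarily the
# shipped config): with grid_pinpoints "(1x1),(3x3)" and a 384px patch, an 800x600
# image maps to a 3x2 grid of vision patches:
#
#   get_anyres_image_grid_shape((800, 600), "(1x1),(3x3)", 384)
#   # -> (3, 2)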
# This function is not used.
def resize_and_pad_image(image, target_resolution):
original_width, original_height = image.size
target_width, target_height = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
# Resize the image
resized_image = image.resize((new_width, new_height))
new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
new_image.paste(resized_image, (paste_x, paste_y))
return new_image
# DIFFERENT from sglang.srt.mm_utils.process_anyres_image
def process_anyres_image(image, processor, grid_pinpoints):
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
patch_size = processor.crop_size["height"]
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
grid_pinpoints = [
(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
]
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints) # type: ignore
best_resolution = select_best_resolution(image.size, possible_resolutions)
# image_padded = resize_and_pad_image(image, best_resolution)
image_padded = image.resize(best_resolution)
patches = divide_to_patches(image_padded, processor.crop_size["height"])
image_original_resize = image.resize((processor.crop_size["height"], processor.crop_size["height"]))
image_patches = [image_original_resize] + patches
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
return torch.stack(image_patches, dim=0)
def process_images(images, image_processor, model_cfg):
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", "")
new_images = []
if image_aspect_ratio == "pad":
for image in images:
image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
new_images.append(image)
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
for image in images:
image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
new_images.append(image)
else:
return image_processor(images, return_tensors="pt")["pixel_values"]
if all(x.shape == new_images[0].shape for x in new_images):
new_images = torch.stack(new_images, dim=0)
return new_images
class Mineru2ImageProcessor(BaseImageProcessor):
model_input_names = ["pixel_values"]
def __init__(
self,
image_mean=(0.5, 0.5, 0.5),
image_std=(0.5, 0.5, 0.5),
size=(384, 384),
crop_size: Optional[Dict[str, int]] = None,
resample=PILImageResampling.BICUBIC,
rescale_factor=1 / 255,
data_format=ChannelDimension.FIRST,
image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[list] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
self.image_mean = image_mean
self.image_std = image_std
self.size = size
self.resample = resample
self.rescale_factor = rescale_factor
self.data_format = data_format
self.crop_size = crop_size
self.image_aspect_ratio = image_aspect_ratio
self.image_grid_pinpoints = image_grid_pinpoints
self.in_e2e_processing = False
def _preprocess(self, images):
if isinstance(images, Image.Image):
images = [images]
else:
# to adapt video data
images = [to_numpy_array(image) for image in images]
assert isinstance(images, list)
transforms = [
convert_to_rgb,
to_numpy_array,
partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
]
images = reduce(lambda x, f: [*map(f, x)], transforms, images)
return {"pixel_values": images}
def _preprocess_end_to_end(self, images):
image_aspect_ratio = self.image_aspect_ratio
image_grid_pinpoints = self.image_grid_pinpoints
assert image_aspect_ratio is not None
assert image_grid_pinpoints is not None
pixel_values = []
if image_aspect_ratio == "pad":
for image in images:
image = expand2square(image, tuple(int(x * 255) for x in self.image_mean))
image = self._preprocess(image)["pixel_values"][0]
pixel_values.append(image)
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
for image in images:
image = process_anyres_image(image, self, self.image_grid_pinpoints)
pixel_values.append(image.numpy())
else:
pixel_values = self._preprocess(images)["pixel_values"]
if isinstance(pixel_values, list) and all(x.shape == pixel_values[0].shape for x in pixel_values):
pixel_values = np.stack(pixel_values, axis=0)
# CAUTION: image_sizes are recorded as (height, width).
image_sizes = [(image.height, image.width) for image in images]
assert len(pixel_values) == len(image_sizes)
return {"pixel_values": pixel_values, "image_sizes": image_sizes}
def preprocess(
self,
images,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
):
if self.image_aspect_ratio is None or self.in_e2e_processing:
data = self._preprocess(images)
else:
assert self.image_grid_pinpoints is not None
self.in_e2e_processing = True
try:
data = self._preprocess_end_to_end(images)
finally:
self.in_e2e_processing = False
return BatchFeature(data=data, tensor_type=return_tensors)
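# A hedged usage sketch (illustrative only): with the default square preprocessing
# (image_aspect_ratio=None), a single PIL image is resized to 384x384, rescaled and
# normalized, and returned as one channels-first tensor:
#
#   proc = Mineru2ImageProcessor()
#   out = proc.preprocess(Image.new("RGB", (800, 600)), return_tensors="pt")
#   # out["pixel_values"].shape == (1, 3, 384, 384)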
import math
import re
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import (
Qwen2ForCausalLM,
Qwen2Model,
SiglipVisionConfig,
SiglipVisionModel,
)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_mineru2 import Mineru2QwenConfig
from .image_processing_mineru2 import Mineru2ImageProcessor, get_anyres_image_grid_shape
class SiglipVisionTower(nn.Module):
def __init__(self, vision_tower):
super().__init__()
self.config = SiglipVisionConfig.from_pretrained(vision_tower)
assert isinstance(self.config, SiglipVisionConfig)
self.config.num_hidden_layers -= 1 # drop the last hidden layer
self.config.vision_use_head = False
self.vision_tower = SiglipVisionModel(self.config)
self.vision_tower.requires_grad_(False)
self.image_processor = Mineru2ImageProcessor()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(
image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True
)
image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
for p in self.vision_tower.parameters():
return p.dtype
@property
def device(self):
for p in self.vision_tower.parameters():
return p.device
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
@property
def num_patches_per_side(self):
return self.config.image_size // self.config.patch_size
@property
def image_size(self):
return self.config.image_size
def build_vision_tower(config: Mineru2QwenConfig):
vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
if "siglip" in vision_tower.lower():
return SiglipVisionTower(vision_tower)
raise ValueError(f"Unknown vision tower: {vision_tower}")
def build_vision_projector(config: Mineru2QwenConfig):
projector_type = getattr(config, "mm_projector_type", "linear")
if projector_type == "linear":
return nn.Linear(config.mm_hidden_size, config.hidden_size)
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU()) # type: ignore
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
return nn.Sequential(*modules)
if projector_type == "identity":
return nn.Identity()
raise ValueError(f"Unknown projector type: {projector_type}")
class Mineru2QwenModel(Qwen2Model):
config_class = Mineru2QwenConfig
def __init__(self, config: Mineru2QwenConfig):
super(Mineru2QwenModel, self).__init__(config)
self.vision_tower = build_vision_tower(config)
self.mm_projector = build_vision_projector(config)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
class Mineru2QwenForCausalLM(Qwen2ForCausalLM):
config_class = Mineru2QwenConfig
def __init__(self, config: Mineru2QwenConfig):
super(Qwen2ForCausalLM, self).__init__(config)
config.rope_scaling = None
self.model = Mineru2QwenModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.ignore_index = config.ignore_index
self.image_token_index = config.image_token_index
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def encode_images(self, images: torch.Tensor):
image_features = self.get_model().vision_tower(images)
image_features = self.get_model().mm_projector(image_features)
return image_features
def prepare_inputs_labels_for_multimodal(
self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
):
vision_tower = self.get_model().vision_tower
if vision_tower is None or images is None or input_ids.shape[1] == 1:
return input_ids, position_ids, attention_mask, past_key_values, None, labels
if type(images) is list or images.ndim == 5:
if type(images) is list:
images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
concat_images = torch.cat([image for image in images], dim=0)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in images]
image_features = torch.split(image_features, split_sizes, dim=0)
mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
if mm_patch_merge_type == "flat":
image_features = [x.flatten(0, 1) for x in image_features]
elif mm_patch_merge_type.startswith("spatial"):
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.get_model().vision_tower.num_patches_per_side
assert height * width == base_image_feature.shape[0]
if "anyres_max" in image_aspect_ratio:
matched_anyres_max_num_patches = re.match(r"square_anyres_max_(\d+)", image_aspect_ratio)
if matched_anyres_max_num_patches:
max_num_patches = int(matched_anyres_max_num_patches.group(1))
if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.get_model().vision_tower.config.image_size,
)
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
else:
raise NotImplementedError
if (
"unpad" in mm_patch_merge_type
and "anyres_max" in image_aspect_ratio
and matched_anyres_max_num_patches
):
unit = image_feature.shape[2]
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
c, h, w = image_feature.shape
times = math.sqrt(h * w / (max_num_patches * unit**2))
if times > 1.1:
image_feature = image_feature[None]
image_feature = nn.functional.interpolate(
image_feature, [int(h // times), int(w // times)], mode="bilinear"
)[0]
image_feature = torch.cat(
(
image_feature,
self.model.image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
elif "unpad" in mm_patch_merge_type:
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = torch.cat(
(
image_feature,
self.model.image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
else:
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
image_feature = image_feature.flatten(0, 3)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
if "unpad" in mm_patch_merge_type:
image_feature = torch.cat(
(image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0
)
new_image_features.append(image_feature)
image_features = new_image_features
else:
raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
else:
image_features = self.encode_images(images)
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
if labels is None:
labels = torch.full_like(input_ids, self.ignore_index)
# remove the padding using attention_mask -- FIXME
_input_ids = input_ids
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
new_input_embeds = []
new_labels = []
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
num_images = (cur_input_ids == self.image_token_index).sum()
if num_images == 0:
cur_image_features = image_features[cur_image_idx]
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
cur_image_idx += 1
continue
image_token_indices = (
[-1] + torch.where(cur_input_ids == self.image_token_index)[0].tolist() + [cur_input_ids.shape[0]]
)
cur_input_ids_noim = []
cur_labels = labels[batch_idx]
cur_labels_noim = []
for i in range(len(image_token_indices) - 1):
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
split_sizes = [x.shape[0] for x in cur_labels_noim]
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []
for i in range(num_images + 1):
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
cur_new_labels.append(cur_labels_noim[i])
if i < num_images:
cur_image_features = image_features[cur_image_idx]
cur_image_idx += 1
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(
torch.full(
(cur_image_features.shape[0],), self.ignore_index, device=cur_labels.device, dtype=cur_labels.dtype
)
)
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
# Truncate sequences to max length as image embeddings can make the sequence longer
tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
if tokenizer_model_max_length is not None:
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
# Combine them
max_len = max(x.shape[0] for x in new_input_embeds)
batch_size = len(new_input_embeds)
new_input_embeds_padded = []
new_labels_padded = torch.full(
(batch_size, max_len), self.ignore_index, dtype=new_labels[0].dtype, device=new_labels[0].device
)
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
cur_len = cur_new_embed.shape[0]
if getattr(self.config, "tokenizer_padding_side", "right") == "left":
new_input_embeds_padded.append(
torch.cat(
(
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
cur_new_embed,
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
else:
new_input_embeds_padded.append(
torch.cat(
(
cur_new_embed,
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded
if _attention_mask is None:
attention_mask = None
else:
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
if _position_ids is None:
position_ids = None
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
image_sizes: Optional[List[List[int]]] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
if inputs_embeds is None:
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = (
self.prepare_inputs_labels_for_multimodal(
input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes
)
)
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_sizes: Optional[List[List[int]]] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
inputs, position_ids, attention_mask, _, inputs_embeds, _ = self.prepare_inputs_labels_for_multimodal(
inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes
)
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = super().prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
)
if images is not None:
inputs["images"] = images
if image_sizes is not None:
inputs["image_sizes"] = image_sizes
return inputs
from sglang.srt.configs.model_config import multimodal_model_archs
from sglang.srt.models.registry import ModelRegistry
try:
# sglang==0.4.5.post3
from sglang.srt.managers.multimodal_processor import (
PROCESSOR_MAPPING as PROCESSOR_MAPPING,
)
except ImportError:
# sglang==0.4.4.post1
from sglang.srt.managers.image_processor import (
IMAGE_PROCESSOR_MAPPING as PROCESSOR_MAPPING,
)
from .. import vlm_hf_model as _
from .image_processor import Mineru2ImageProcessor
from .model import Mineru2QwenForCausalLM
ModelRegistry.models[Mineru2QwenForCausalLM.__name__] = Mineru2QwenForCausalLM
PROCESSOR_MAPPING[Mineru2QwenForCausalLM] = Mineru2ImageProcessor
multimodal_model_archs.append(Mineru2QwenForCausalLM.__name__)
import asyncio
import time
from types import MethodType
from typing import AsyncIterator, Dict, Iterator, List, Optional, Union
import fastapi
from sglang.srt.entrypoints.engine import Engine as _Engine
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.managers.tokenizer_manager import (
TokenizerManager,
dataclass_to_string_truncated,
logger,
)
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from ...utils.run_async import run_async
from .logit_processor import Mineru2LogitProcessor
class BatchEngine(_Engine):
"""
The engine is patched to support batched multimodal generation and early image preprocessing.
"""
def __init__(self, server_args: ServerArgs, **kwargs):
server_args.enable_custom_logit_processor = True
super().__init__(server_args=server_args, **kwargs)
_patch_tokenizer_manager(self.tokenizer_manager)
def generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
# The image input. It can be a file name, a URL, or a base64-encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None,
return_hidden_states: bool = False,
stream: bool = False,
) -> Union[Dict, Iterator[Dict]]:
"""
The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
Please refer to `GenerateReqInput` for the documentation.
"""
modalities_list = []
# EDIT
if isinstance(image_data, list):
for _ in range(len(image_data)):
modalities_list.append(["image"])
elif image_data is not None:
modalities_list.append("image")
# ADD
if custom_logit_processor is None:
custom_logit_processor = Mineru2LogitProcessor().to_str()
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
sampling_params=sampling_params,
image_data=image_data,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
token_ids_logprob=token_ids_logprob,
lora_path=lora_path,
modalities=modalities_list,
custom_logit_processor=custom_logit_processor,
return_hidden_states=return_hidden_states,
stream=stream,
)
generator = _generate_request(self.tokenizer_manager, obj, None)
if stream:
def generator_wrapper():
while True:
try:
chunk = run_async(generator.__anext__())
yield chunk
except StopAsyncIteration:
break
return generator_wrapper()
else:
ret = run_async(generator.__anext__())
return ret
async def async_generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
# The image input. It can be a file name, a URL, or a base64-encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None,
return_hidden_states: bool = False,
stream: bool = False,
) -> Union[Dict, AsyncIterator[Dict], Iterator[Dict]]:
"""
The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
Please refer to `GenerateReqInput` for the documentation.
"""
modalities_list = []
# EDIT
if isinstance(image_data, list):
for _ in range(len(image_data)):
modalities_list.append(["image"])
elif image_data is not None:
modalities_list.append("image")
# ADD
if custom_logit_processor is None:
custom_logit_processor = Mineru2LogitProcessor().to_str()
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
sampling_params=sampling_params,
image_data=image_data,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
token_ids_logprob=token_ids_logprob,
lora_path=lora_path,
modalities=modalities_list,
custom_logit_processor=custom_logit_processor,
return_hidden_states=return_hidden_states,
stream=stream,
)
generator = _generate_request(self.tokenizer_manager, obj, None)
if stream is True:
return generator
else:
return await generator.__anext__()
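# A hedged usage sketch (model path, prompt template and sampling values are
# assumptions): BatchEngine accepts a batch of prompts with one image each, and
# falls back to Mineru2LogitProcessor when no custom logit processor is given:
#
#   engine = BatchEngine(ServerArgs(model_path="/path/to/mineru2-model"))
#   outputs = engine.generate(
#       prompt=["<image>\nExtract the page content."],
#       image_data=["page_0.png"],
#       sampling_params=[{"max_new_tokens": 1024, "temperature": 0.0}],
#   )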
def _auto_create_handle_loop(self: TokenizerManager):
"""
patch the original `auto_create_handle_loop()` method to reset `no_create_loop`
when the event loop changes.
"""
try:
curr_handle_loop = asyncio.get_running_loop()
except RuntimeError:
curr_handle_loop = None
last_handle_loop = getattr(self, "_last_handle_loop", None)
if last_handle_loop != curr_handle_loop:
self.no_create_loop = False
setattr(self, "_last_handle_loop", curr_handle_loop)
return TokenizerManager.auto_create_handle_loop(self)
def _patch_tokenizer_manager(self: TokenizerManager):
self.auto_create_handle_loop = MethodType(_auto_create_handle_loop, self)
async def _one_request(
self: TokenizerManager,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request],
created_time: Optional[float],
):
tokenized_obj = await self._tokenize_one_request(obj)
self._send_one_request(obj, tokenized_obj, created_time)
async for out in self._wait_one_response(obj, request):
yield out
async def _handle_batch_request(
self: TokenizerManager,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
created_time: Optional[float] = None,
):
batch_size = obj.batch_size
generators = []
rids = []
if getattr(obj, "parallel_sample_num", 1) != 1:
raise Exception("parallel_sample_num != 1 is not supported in this patched code.")
# Send all requests
for i in range(batch_size):
tmp_obj = obj[i]
generators.append(_one_request(self, tmp_obj, request, created_time))
rids.append(tmp_obj.rid)
# Wait for all requests
is_stream = hasattr(obj, "stream") and obj.stream
if not is_stream:
outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
yield outputs
else:
rid_to_index = {rid: i for i, rid in enumerate(rids)}
task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
while task_map:
done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
for task in done:
gen = task_map.pop(task)
try:
result = task.result()
result["index"] = rid_to_index[result["meta_info"]["id"]]
yield result
new_task = asyncio.create_task(gen.__anext__())
task_map[new_task] = gen
except StopAsyncIteration:
pass
async def _generate_request(
self: TokenizerManager,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
):
created_time = time.time()
self.auto_create_handle_loop()
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
raise ValueError(
"This model does not appear to be an embedding model by default. "
"Please add `--is-embedding` when launching the server or try another model."
)
obj.normalize_batch_and_arguments()
if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata
logger.info(f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}")
async with self.model_update_lock.reader_lock:
is_single = obj.is_single
if is_single:
tokenized_obj = await self._tokenize_one_request(obj)
self._send_one_request(obj, tokenized_obj, created_time)
async for response in self._wait_one_response(obj, request):
yield response
else:
async for response in _handle_batch_request(self, obj, request, created_time):
yield response
import ast
import asyncio
import re
from typing import List, Optional, Union
import numpy as np
try:
# sglang==0.4.5.post3
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor as BaseProcessor,
)
get_global_processor = None
except ImportError:
# sglang==0.4.4.post1
from sglang.srt.managers.image_processors.base_image_processor import (
BaseImageProcessor as BaseProcessor,
get_global_processor,
)
from sglang.srt.mm_utils import divide_to_patches, expand2square, select_best_resolution
from sglang.srt.utils import load_image, logger
from sglang.utils import get_exception_traceback
from .model import Mineru2QwenForCausalLM
# image_best_res is only resized (not padded).
def process_anyres_image(image, processor, grid_pinpoints):
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
patch_size = processor.crop_size["height"]
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
grid_pinpoints = [
(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
]
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints)
best_resolution = select_best_resolution(image.size, possible_resolutions)
image_best_res = image.resize(best_resolution) # <<<<<<< Here changed
patches = divide_to_patches(image_best_res, processor.crop_size["height"])
image_original_resize = image.resize((processor.crop_size["height"], processor.crop_size["height"]))
image_patches = [image_original_resize] + patches
image_patches = [processor.preprocess(image_patch)["pixel_values"][0] for image_patch in image_patches]
return np.stack(image_patches, axis=0)
class Mineru2ImageProcessor(BaseProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(
image_data: Union[str, bytes],
image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[str] = None,
image_processor=None,
):
if image_processor is None:
assert get_global_processor is not None
image_processor = get_global_processor().image_processor
try:
image, image_size = load_image(image_data)
if image_size is not None:
# It is a video with multiple images
image_hash = hash(image_data)
pixel_values = image_processor(image)["pixel_values"]
pixel_values = np.stack(pixel_values, axis=0)
return pixel_values, image_hash, image_size
else:
# It is an image
image_hash = hash(image_data)
if image_aspect_ratio == "pad":
image = expand2square(
image,
tuple(int(x * 255) for x in image_processor.image_mean),
)
pixel_values = image_processor(image.convert("RGB"))["pixel_values"][0]
elif image_aspect_ratio == "anyres" or (image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio):
pixel_values = process_anyres_image(image, image_processor, image_grid_pinpoints)
else:
pixel_values = image_processor(image)["pixel_values"][0]
return pixel_values, image_hash, image.size
except Exception:
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
async def _process_single_image(self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str):
if hasattr(self, "cpu_executor"):
executor = self.cpu_executor
else:
executor = self.executor
if get_global_processor is not None:
image_processor = None # save ipc cost
else:
image_processor = self._processor.image_processor
if executor is not None:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
executor,
Mineru2ImageProcessor._process_single_image_task,
image_data,
aspect_ratio,
grid_pinpoints,
image_processor,
)
else:
return self._process_single_image_task(
image_data,
aspect_ratio,
grid_pinpoints,
image_processor,
)
# sglang==0.4.4.post1
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
*args,
**kwargs,
):
if not image_data:
return None
modalities = request_obj.modalities or ["image"]
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", "")
grid_pinpoints = (
self.hf_config.image_grid_pinpoints
if hasattr(self.hf_config, "image_grid_pinpoints") and "anyres" in aspect_ratio
else None
)
if isinstance(image_data, str):
image_data = [image_data]
if isinstance(image_data, list) and len(image_data) > 0:
if "multi-images" in modalities or "video" in modalities:
# Multiple images
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
pixel_values, image_hashes, image_sizes = [], [], []
res = []
for img_data in image_data:
res.append(self._process_single_image(img_data, aspect_ratio, grid_pinpoints))
res = await asyncio.gather(*res)
for pixel_v, image_h, image_s in res:
pixel_values.append(pixel_v)
image_hashes.append(image_h)
image_sizes.append(image_s)
if isinstance(pixel_values[0], np.ndarray):
pixel_values = np.stack(pixel_values, axis=0)
else:
# A single image
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
image_hashes = [image_hash]
image_sizes = [image_size]
else:
raise ValueError(f"Invalid image data: {image_data}")
return {
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"image_sizes": image_sizes,
"modalities": request_obj.modalities or ["image"],
}
# sglang==0.4.5.post3
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
*args,
**kwargs,
):
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
result = await self.process_images_async(image_data, input_text, request_obj, *args, **kwargs)
if result is None:
return None
modality = Modality.IMAGE
if isinstance(request_obj.modalities, list):
if request_obj.modalities[0] == "multi-images":
modality = Modality.MULTI_IMAGES
elif request_obj.modalities[0] == "video":
modality = Modality.VIDEO
return {
"mm_items": [
MultimodalDataItem(
pixel_values=result["pixel_values"],
image_sizes=result["image_sizes"],
modality=modality,
)
],
}
ImageProcessorMapping = {Mineru2QwenForCausalLM: Mineru2ImageProcessor}
from typing import List
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
class Mineru2LogitProcessor(CustomLogitProcessor):
"""
Stateless logit processor for Mineru2.
(base-class: sglang.srt.sampling.custom_logit_processor.CustomLogitProcessor)
This processor applies token-level constraints to prevent repetition during generation.
It supports two main constraints:
- no_repeat_ngram_size (int):
Prevents repeating the same n-gram of specified size in the output.
Inspired by Hugging Face's NoRepeatNGramLogitsProcessor.
This implementation is slower due to its lack of specialized optimization.
- no_repeat_token_count (int):
(Placeholder for future logic)
Intended to prevent repeating the same token multiple times.
Not yet implemented in this version.
"""
def __init__(self) -> None:
super().__init__()
self._generated_ngrams = {} # Cache of generated n-grams by request ID
self._time = {} # Timestamp of the last update for each request
self._gen_step = 0 # Global generation step counter
def __call__(self, logits, batch_info: List[dict]):
"""
Applies repetition constraints to the logits before sampling tokens.
Args:
logits (FloatTensor): A tensor of shape (batch_size, vocab_size) containing raw token logits.
batch_info (List[dict]): A list of metadata dicts for each sample in the batch. Each dict must include:
- "__req__": Request object containing request ID and output_ids.
- "no_repeat_ngram_size": Size of n-gram to avoid repeating.
Returns:
FloatTensor: The modified logits tensor with banned token logits set to -inf.
"""
from sglang.srt.managers.schedule_batch import Req
self._gen_step += 1 # Update global generation step
for idx, info in enumerate(batch_info):
if not isinstance(info, dict) or "__req__" not in info:
continue
req: Req = info["__req__"]
rid = req.rid
output_ids = req.output_ids
ngram_size = info.get("no_repeat_ngram_size", 0)
# Skip if there are not enough tokens to form an n-gram
if ngram_size <= 0 or len(output_ids) < ngram_size:
continue
# Record the current step for cache cleanup tracking
self._time[rid] = self._gen_step
# Initialize n-gram cache for this request if it doesn't exist
if rid not in self._generated_ngrams:
self._generated_ngrams[rid] = {}
# Get the n-gram prefix (all but the last token)
prev_ngram = tuple(output_ids[-ngram_size:-1])
last_token = output_ids[-1]
# Store this n-gram occurrence
self._generated_ngrams[rid][prev_ngram] = self._generated_ngrams[rid].get(prev_ngram, []) + [last_token]
# Get the next-token candidates to ban based on current prefix
current_prefix = tuple(output_ids[-ngram_size + 1 :])
banned_tokens = self._generated_ngrams[rid].get(current_prefix, [])
# Set the logits of banned tokens to negative infinity
for token in banned_tokens:
logits[idx][token] = -float("inf")
# Clean up cache for expired requests
expired_rids = [rid for rid, last_used in self._time.items() if last_used < self._gen_step]
for rid in expired_rids:
self._generated_ngrams.pop(rid, None)
self._time.pop(rid, None)
return logits
import math
import re
from typing import Iterable, List, Optional, Tuple
import numpy as np
import torch
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape, # unpad_image, unpad_image_shape
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.utils import add_prefix
from torch import nn
from transformers import (
CLIPVisionConfig,
CLIPVisionModel,
SiglipVisionConfig,
SiglipVisionModel,
)
from ..vlm_hf_model.configuration_mineru2 import Mineru2QwenConfig
from ..vlm_hf_model.modeling_mineru2 import build_vision_projector
def flatten_nested_list(nested_list):
if isinstance(nested_list, list):
return [item for sublist in nested_list for item in flatten_nested_list(sublist)]
else:
return [nested_list]
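# Illustrative sketch: arbitrarily nested lists are flattened depth-first:
#
#   flatten_nested_list([[1, 2], [3, [4]]])
#   # -> [1, 2, 3, 4]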
def downgrade_modality(modality):
modality_str = str(modality)
if "MULTI_IMAGES" in modality_str:
return "multi-images"
if "IMAGE" in modality_str:
return "image"
if "VIDEO" in modality_str:
return "video"
if "AUDIO" in modality_str:
return "audio"
raise ValueError(f"Unexpected modality: {modality_str}")
class Mineru2QwenForCausalLM(nn.Module):
def __init__(
self,
config: Mineru2QwenConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
if getattr(self.config, "projector_hidden_act", None) is None:
self.config.projector_hidden_act = "gelu"
if getattr(self.config, "image_token_index", None) is None:
self.config.image_token_index = 151646
# load vision tower
mm_vision_tower = self.config.mm_vision_tower
if "clip" in mm_vision_tower:
vision_config = CLIPVisionConfig.from_pretrained(mm_vision_tower)
self.vision_tower = CLIPVisionModel(vision_config) # type: ignore
elif "siglip" in mm_vision_tower:
vision_config = SiglipVisionConfig.from_pretrained(mm_vision_tower)
self.vision_tower = SiglipVisionModel(vision_config) # type: ignore
# Siglip needs all feature tokens
self.config.mm_vision_select_feature = "full"
else:
raise ValueError(f"Unexpected mm_vision_tower: {mm_vision_tower}")
### EDIT: change projector
# the name `projector` contains `proj`, a substring commonly matched for attention layers, which can trigger bugs in quantization.
self.multi_modal_mlp = build_vision_projector(config)
self.language_model = Qwen2ForCausalLM(
config,
quant_config=quant_config,
prefix=add_prefix("language_model", prefix),
)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
self.language_model.model.image_newline = nn.Parameter(torch.empty(config.hidden_size))
language_model_device = next(self.language_model.parameters()).device
self.vision_tower = self.vision_tower.to(language_model_device)
self.vision_tower.eval()
self.vision_feature_layer = self.config.mm_vision_select_layer
self.vision_feature_select_strategy = self.config.mm_vision_select_feature
self.image_size = self.vision_tower.config.image_size
self.patch_size = self.vision_tower.config.patch_size
self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
if self.vision_feature_select_strategy in ("patch", "full"):
pass
elif self.vision_feature_select_strategy == "cls_patch":
self.image_feature_len += 1
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")
def pad_input_ids(self, input_ids: List[int], image_inputs):
if hasattr(image_inputs, "mm_items"): # MultimodalInputs
# sglang==0.4.5.post3
image_sizes = flatten_nested_list([item.image_sizes for item in image_inputs.mm_items])
pad_values = [item.pad_value for item in image_inputs.mm_items]
else: # ImageInputs
# sglang==0.4.4.post1
image_sizes = image_inputs.image_sizes
pad_values = image_inputs.pad_values
# hardcode for spatial_unpad + anyres
# if image_inputs.modalities is not None and (
# "multi-images" in image_inputs.modalities or "video" in image_inputs.modalities
# ):
# image_aspect_ratio = "pad"
# else:
# image_aspect_ratio = "anyres"
offset_list = []
image_inputs.image_pad_len = []
for image_idx, image_s in enumerate(image_sizes):
if len(image_sizes) > 16:
# 2x2 pooling with stride 2
new_image_feature_len = math.ceil(self.image_size / self.patch_size / 2) ** 2
else:
new_image_feature_len = self.image_feature_len # multiimage
height = width = self.num_patches_per_side
if "anyres" in self.config.image_aspect_ratio:
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_s,
self.image_grid_pinpoints,
self.vision_tower.config.image_size,
)
h = num_patch_height * height
w = num_patch_width * width
### EDIT: remove `unpad_image_shape`
# new_h, new_w = unpad_image_shape(h, w, image_s)
new_h, new_w = h, w
if "anyres_max" in self.config.image_aspect_ratio:
matched_anyres_max_num_patches = re.match(r".*anyres_max_(\d+)", self.config.image_aspect_ratio)
if matched_anyres_max_num_patches:
max_num_patches = int(matched_anyres_max_num_patches.group(1))
times = math.sqrt(new_h * new_w / (max_num_patches * self.image_feature_len))
if times > 1.1:
new_h = int(new_h // times)
new_w = int(new_w // times)
new_image_feature_len += new_h * (new_w + 1)
try:
offset = input_ids.index(self.config.image_token_index)
except ValueError:
offset = 0
# old_len + pad_len - 1, because we need to remove image_token_id
input_ids = input_ids[:offset] + [pad_values[image_idx]] * new_image_feature_len + input_ids[offset + 1 :]
offset_list.append(offset)
image_inputs.image_pad_len.append(new_image_feature_len)
image_inputs.image_offsets = offset_list
return input_ids
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
pixel_values = pixel_values.to(device=self.vision_tower.device, dtype=self.vision_tower.dtype)
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
# NOTE: This is not memory efficient; output_hidden_states=True keeps all the hidden states.
selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
if self.vision_feature_select_strategy in ["default", "patch"]:
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(f"Unexpected select feature strategy: {self.vision_feature_select_strategy}")
image_features = self.multi_modal_mlp(selected_image_feature)
return image_features
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if hasattr(forward_batch, "mm_inputs"):
# sglang==0.4.5.post3
image_inputs = forward_batch.mm_inputs
is_sglang_mm_inputs = True
else:
# sglang==0.4.4.post1
image_inputs = forward_batch.image_inputs
is_sglang_mm_inputs = False
if image_inputs is None:
image_inputs = []
if forward_batch.forward_mode.is_extend():
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# Their values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
# Embed text inputs
input_embeds = self.language_model.model.embed_tokens(input_ids)
# Got List[List[str]]; extend it into a flat List[str].
# The length of the list should equal the batch size.
modalities_list = []
max_image_offset = []
for im in image_inputs:
if im:
if hasattr(im, "mm_items"):
# sglang==0.4.5.post3
modalities_list.extend([downgrade_modality(item.modality) for item in im.mm_items])
elif im.modalities is not None:
# sglang==0.4.4.post1
modalities_list.extend(im.modalities)
if im and im.image_offsets:
max_image_offset.append(np.max(np.array(im.image_offsets) + np.array(im.image_pad_len)))
else:
max_image_offset.append(-1)
start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
need_vision = start_positions <= np.array(max_image_offset)
if need_vision.any():
bs = forward_batch.batch_size
if is_sglang_mm_inputs:
# sglang==0.4.5.post3
pixel_values = flatten_nested_list(
[[item.pixel_values for item in image_inputs[i].mm_items] for i in range(bs) if need_vision[i]]
) # image_inputs[batch_idx].mm_items[item_idx].pixel_values is Tensor
image_sizes = [
flatten_nested_list([item.image_sizes for item in image_inputs[i].mm_items])
for i in range(bs)
if need_vision[i]
] # image_inputs[batch_idx].mm_items[item_idx].image_sizes should be tuple, but is list of tuple for now.
else:
# sglang==0.4.4.post1
pixel_values = [image_inputs[i].pixel_values for i in range(bs) if need_vision[i]]
image_sizes = [image_inputs[i].image_sizes for i in range(bs) if need_vision[i]]
########## Encode Image ########
if pixel_values[0].ndim == 4:
# llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
# concatenate along the patch dimension; the result keeps ndim=4
concat_images = torch.tensor(
np.concatenate(pixel_values, axis=0),
device=self.vision_tower.device,
)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in pixel_values]
image_features = torch.split(image_features, split_sizes, dim=0)
# hd image_features: BS, num_patch, 576, 4096
else:
# normal pixel: BS, C=3, H=336, W=336
pixel_values = torch.tensor(np.array(pixel_values), device=self.vision_tower.device)
image_features = self.encode_images(pixel_values)
# image_features: BS, 576, 4096
if self.mm_patch_merge_type.startswith("spatial"):
new_image_features = []
height = width = self.num_patches_per_side
for image_idx, image_feature in enumerate(image_features):
if modalities_list[image_idx] == "image":
image_aspect_ratio = self.config.image_aspect_ratio # single image
elif modalities_list[image_idx] == "multi-images" or modalities_list[image_idx] == "video":
image_aspect_ratio = "pad" # multi image
# image_aspect_ratio = (
# "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
# )
if (
image_feature.shape[0] > 1
and "anyres" in image_aspect_ratio
and modalities_list[image_idx] == "image"
):
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
assert height * width == base_image_feature.shape[0]
if "anyres_max" in image_aspect_ratio:
matched_anyres_max_num_patches = re.match(r".*anyres_max_(\d+)", image_aspect_ratio)
if matched_anyres_max_num_patches:
max_num_patches = int(matched_anyres_max_num_patches.group(1))
if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
vision_tower_image_size = self.image_size
try:
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx][0],
self.config.image_grid_pinpoints,
vision_tower_image_size,
)
except Exception as e:
print(f"Error: {e}")
num_patch_width, num_patch_height = 2, 2
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
else:
image_feature = image_feature.view(2, 2, height, width, -1)
if "unpad" in self.mm_patch_merge_type:
unit = image_feature.shape[2]
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
### EDIT: remove `unpad_image`
# image_feature = unpad_image(image_feature, image_sizes[image_idx][0])
if "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
c, h, w = image_feature.shape
times = math.sqrt(h * w / (max_num_patches * unit**2))
if times > 1.1:
image_feature = image_feature[None]
image_feature = nn.functional.interpolate(
image_feature,
[int(h // times), int(w // times)],
mode="bilinear",
)[0]
image_feature = torch.cat(
(
image_feature,
self.language_model.model.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1
),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
else:
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
image_feature = image_feature.flatten(0, 3)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
image_feature = image_feature.unsqueeze(0)
else:
if modalities_list[image_idx] == "video": # video
# 2x2 pooling
num_of_frames = image_feature.shape[0]
image_feature = image_feature.view(num_of_frames, height, width, -1)
image_feature = image_feature.permute(0, 3, 1, 2).contiguous() # N, C, H, W
height, width = image_feature.shape[2:]
scaled_shape = [
math.ceil(height / 2),
math.ceil(width / 2),
]
image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode="bilinear")
image_feature = image_feature.flatten(2).transpose(1, 2).contiguous() # N, C, H*W
if "unpad" in self.mm_patch_merge_type:
image_feature = torch.cat(
(
image_feature,
# Expand to (bs, 1, hidden_dim) and concat at the end of the image tokens
self.language_model.model.image_newline[None, None].expand(
image_feature.shape[0],
1,
image_feature.shape[-1],
),
),
dim=1,
)
new_image_features.append(image_feature)
image_features = new_image_features
# Fill in the placeholder for the image
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
pt = 0
for i in range(bs):
if not need_vision[i]:
continue
start_idx = extend_start_loc_cpu[i]
seq_len = extend_seq_lens[i]
prefix_len = prefix_lens_cpu[i]
# Multiple images
for image_idx, image_offset in enumerate(image_inputs[i].image_offsets):
if image_offset + image_inputs[i].image_pad_len[image_idx] <= prefix_len:
continue
if image_offset >= prefix_len + seq_len:
break
tmp_image_feature = image_features[pt][image_idx]
pad_len = tmp_image_feature.shape[0]
input_offset = image_offset - prefix_len
left_idx = start_idx + input_offset
right_idx = left_idx + pad_len
assert right_idx > start_idx
if input_offset < 0:
left_idx = start_idx
tmp_image_feature = tmp_image_feature[-input_offset:]
if right_idx > start_idx + seq_len:
tmp_image_feature = tmp_image_feature[: start_idx + seq_len - right_idx]
right_idx = start_idx + seq_len
try:
input_embeds[left_idx:right_idx] = tmp_image_feature
except RuntimeError as e:
print(f"RuntimeError in image encoding: {e}")
print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
print(f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}")
pt += 1
return self.language_model(input_ids, positions, forward_batch, input_embeds=input_embeds)
elif forward_batch.forward_mode.is_decode():
return self.language_model(input_ids, positions, forward_batch)
else:
raise ValueError(f"Unexpected forward mode: {forward_batch.forward_mode}")
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
projector_weights = {
"model.mm_projector": "multi_modal_mlp",
"model.vision_tower.vision_tower": "vision_tower",
# Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
"model.image_newline": "language_model.model.image_newline",
}
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "projector" in name or "vision_tower" in name or "image_newline" in name:
for weight_name, param_name in projector_weights.items():
if weight_name in name:
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
else:
self.language_model.load_weights([(name, loaded_weight)])
@property
def num_patches_per_side(self):
return self.image_size // self.patch_size
EntryClass = [Mineru2QwenForCausalLM]
import os
import sys
from fastapi import Request
from sglang.srt.entrypoints.http_server import app, generate_request, launch_server
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_process_tree
from .logit_processor import Mineru2LogitProcessor
_custom_logit_processor_str = Mineru2LogitProcessor().to_str()
# remove the existing /generate route
for route in app.routes[:]:
if hasattr(route, "path") and getattr(route, "path") == "/generate":
app.routes.remove(route)
# add the custom /generate route
@app.api_route("/generate", methods=["POST", "PUT"])
async def custom_generate_request(obj: GenerateReqInput, request: Request):
if obj.custom_logit_processor is None:
obj.custom_logit_processor = _custom_logit_processor_str
return await generate_request(obj, request)
def main():
server_args = prepare_server_args(sys.argv[1:])
if server_args.chat_template is None:
server_args.chat_template = "chatml"
server_args.enable_custom_logit_processor = True
try:
launch_server(server_args)
finally:
kill_process_tree(os.getpid(), include_parent=False)
if __name__ == "__main__":
main()
# Copyright (c) Opendatalab. All rights reserved.
import base64
from io import BytesIO
from loguru import logger
from PIL import Image
from pypdfium2 import PdfBitmap, PdfDocument, PdfPage
def page_to_image(
page: PdfPage,
dpi: int = 144, # changed from 200 to 144
max_width_or_height: int = 2560, # changed from 4500 to 2560
) -> tuple[Image.Image, float]:
scale = dpi / 72
long_side_length = max(*page.get_size())
if long_side_length > max_width_or_height:
scale = max_width_or_height / long_side_length
bitmap: PdfBitmap = page.render(scale=scale) # type: ignore
try:
image = bitmap.to_pil()
finally:
try:
bitmap.close()
except Exception:
pass
return image, scale
def image_to_bytes(
image: Image.Image,
image_format: str = "PNG",  # "JPEG" also works
) -> bytes:
with BytesIO() as image_buffer:
image.save(image_buffer, format=image_format)
return image_buffer.getvalue()
def image_to_b64str(
image: Image.Image,
image_format: str = "PNG",  # "JPEG" also works
) -> str:
image_bytes = image_to_bytes(image, image_format)
return base64.b64encode(image_bytes).decode("utf-8")
def pdf_to_images(
pdf: str | bytes | PdfDocument,
dpi: int = 144,
max_width_or_height: int = 2560,
start_page_id: int = 0,
end_page_id: int | None = None,
) -> list[Image.Image]:
doc = pdf if isinstance(pdf, PdfDocument) else PdfDocument(pdf)
page_num = len(doc)
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else page_num - 1
if end_page_id > page_num - 1:
logger.warning("end_page_id is out of range, use images length")
end_page_id = page_num - 1
images = []
try:
for i in range(start_page_id, end_page_id + 1):
image, _ = page_to_image(doc[i], dpi, max_width_or_height)
images.append(image)
finally:
try:
doc.close()
except Exception:
pass
return images
def pdf_to_images_bytes(
pdf: str | bytes | PdfDocument,
dpi: int = 144,
max_width_or_height: int = 2560,
start_page_id: int = 0,
end_page_id: int | None = None,
image_format: str = "PNG",
) -> list[bytes]:
images = pdf_to_images(pdf, dpi, max_width_or_height, start_page_id, end_page_id)
return [image_to_bytes(image, image_format) for image in images]
def pdf_to_images_b64strs(
pdf: str | bytes | PdfDocument,
dpi: int = 144,
max_width_or_height: int = 2560,
start_page_id: int = 0,
end_page_id: int | None = None,
image_format: str = "PNG",
) -> list[str]:
images = pdf_to_images(pdf, dpi, max_width_or_height, start_page_id, end_page_id)
return [image_to_b64str(image, image_format) for image in images]
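# A hedged usage sketch (file name is illustrative): render each PDF page to a
# base64-encoded PNG, ready to be passed as `image_data` to the backend:
#
#   pages = pdf_to_images_b64strs("sample.pdf", dpi=144, max_width_or_height=2560)
#   # pages[0] is the base64 string for page 1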
import asyncio
import threading
from queue import Queue
from typing import Any, AsyncIterable, Coroutine, Iterable, TypeVar
T = TypeVar("T")
def run_async(coroutine: Coroutine[Any, Any, T]) -> T:
if not asyncio.iscoroutine(coroutine):
raise ValueError("a coroutine was expected, got {!r}".format(coroutine))
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None:
return loop.run_until_complete(coroutine)
else:
return asyncio.run(coroutine)
def iter_async(iterable: AsyncIterable[T]) -> Iterable[T]:
if not isinstance(iterable, AsyncIterable):
raise ValueError("an async iterable was expected, got {!r}".format(iterable))
queue = Queue()
async def async_helper():
try:
async for chunk in iterable:
queue.put(chunk)
queue.put(None)
except Exception as e:
queue.put(e)
def helper():
run_async(async_helper())
thread = threading.Thread(target=helper, daemon=True)
thread.start()
while True:
chunk = queue.get()
if chunk is None:
break
if isinstance(chunk, Exception):
raise chunk
yield chunk
thread.join()
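# Illustrative sketch: iter_async drives an async iterator from a helper thread so
# it can be consumed from synchronous code:
#
#   async def numbers():
#       for i in range(3):
#           yield i
#
#   list(iter_async(numbers()))
#   # -> [0, 1, 2]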
[tool.black]
line-length = 128
[tool.ruff]
line-length = 128