Unverified Commit 0c7a0882 authored by Xiaomeng Zhao, committed by GitHub

Merge pull request #2611 from myhloli/dev

Dev
parents 3bd0ecf1 a392f445
 # Copyright (c) Opendatalab. All rights reserved.
 import os
-from pathlib import Path
 import cv2
 import numpy as np
-import torch
 from loguru import logger
 from rapid_table import RapidTable, RapidTableInput
-from rapid_table.main import ModelType
-from magic_pdf.libs.config_reader import get_device
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path


 class RapidTableModel(object):
-    def __init__(self, ocr_engine, table_sub_model_name='slanet_plus'):
-        sub_model_list = [model.value for model in ModelType]
-        if table_sub_model_name is None:
-            input_args = RapidTableInput()
-        elif table_sub_model_name in sub_model_list:
-            if torch.cuda.is_available() and table_sub_model_name == "unitable":
-                input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True, device=get_device())
-            else:
-                root_dir = Path(__file__).absolute().parent.parent.parent.parent.parent
-                slanet_plus_model_path = os.path.join(root_dir, 'resources', 'slanet_plus', 'slanet-plus.onnx')
-                input_args = RapidTableInput(model_type=table_sub_model_name, model_path=slanet_plus_model_path)
-        else:
-            raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. It must be one of {sub_model_list}")
+    def __init__(self, ocr_engine):
+        slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
+        input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
         self.table_model = RapidTable(input_args)
-        # self.ocr_model_name = "RapidOCR"
-        # if torch.cuda.is_available():
-        #     from rapidocr_paddle import RapidOCR
-        #     self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
-        # else:
-        #     from rapidocr_onnxruntime import RapidOCR
-        #     self.ocr_engine = RapidOCR()
-        #     self.ocr_model_name = "PaddleOCR"
         self.ocr_engine = ocr_engine
......
from transformers import AutoConfig, AutoImageProcessor, AutoModelForCausalLM
from .configuration_mineru2 import Mineru2QwenConfig
from .image_processing_mineru2 import Mineru2ImageProcessor
from .modeling_mineru2 import Mineru2QwenForCausalLM
AutoConfig.register(Mineru2QwenConfig.model_type, Mineru2QwenConfig)
AutoModelForCausalLM.register(Mineru2QwenConfig, Mineru2QwenForCausalLM)
AutoImageProcessor.register(Mineru2QwenConfig, slow_image_processor_class=Mineru2ImageProcessor)
from transformers import Qwen2Config
class Mineru2QwenConfig(Qwen2Config):
model_type = "mineru2_qwen"
def __init__(
self,
ignore_index=-100,
image_aspect_ratio="square_anyres_max_9",
image_grid_pinpoints="(1x1),...,(4x4)",
image_token_index=151646,
mm_hidden_size=1152,
mm_patch_merge_type="spatial_unpad",
mm_projector_type="mlp2x_gelu",
mm_vision_select_feature="full",
mm_vision_select_layer=-2,
mm_vision_tower="google/siglip-so400m-patch14-384",
tie_word_embeddings=False,
tokenizer_model_max_length=16384,
tokenizer_padding_side="right",
unfreeze_mm_vision_tower=True,
**kwargs,
):
self.ignore_index = ignore_index
self.image_aspect_ratio = image_aspect_ratio
self.image_grid_pinpoints = image_grid_pinpoints
self.image_token_index = image_token_index
self.mm_hidden_size = mm_hidden_size
self.mm_patch_merge_type = mm_patch_merge_type
self.mm_projector_type = mm_projector_type
self.mm_vision_select_feature = mm_vision_select_feature
self.mm_vision_select_layer = mm_vision_select_layer
self.mm_vision_tower = mm_vision_tower
self.tokenizer_model_max_length = tokenizer_model_max_length
self.tokenizer_padding_side = tokenizer_padding_side
self.unfreeze_mm_vision_tower = unfreeze_mm_vision_tower
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
import ast
import math
import re
from functools import partial, reduce
from typing import Dict, Optional, Union
import numpy as np
import torch
from PIL import Image
from transformers.image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_size_dict,
)
from transformers.image_transforms import (
convert_to_rgb,
normalize,
rescale,
resize,
to_channel_dimension_format,
)
from transformers.image_utils import (
ChannelDimension,
PILImageResampling,
to_numpy_array,
)
from transformers.utils import TensorType
def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
original_width, original_height = original_size
best_fit = (0, 0)
max_effective_resolution = 0
min_wasted_resolution = float("inf")
for width, height in possible_resolutions:
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
wasted_resolution = (width * height) - effective_resolution
if effective_resolution > max_effective_resolution or (
effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (width, height)
return best_fit
def divide_to_patches(image, patch_size):
patches = []
width, height = image.size
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
box = (j, i, j + patch_size, i + patch_size)
patch = image.crop(box)
patches.append(patch)
return patches
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
if pil_img.mode == "L":
pil_img = pil_img.convert("RGB")
if width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
grid_pinpoints = [
(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
]
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints) # type: ignore
width, height = select_best_resolution(image_size, possible_resolutions)
return width // patch_size, height // patch_size
# This function is not used.
def resize_and_pad_image(image, target_resolution):
original_width, original_height = image.size
target_width, target_height = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
# Resize the image
resized_image = image.resize((new_width, new_height))
new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
new_image.paste(resized_image, (paste_x, paste_y))
return new_image
# DIFFERENT from sglang.srt.mm_utils.process_anyres_image
def process_anyres_image(image, processor, grid_pinpoints):
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
patch_size = processor.crop_size["height"]
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
grid_pinpoints = [
(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
]
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints) # type: ignore
best_resolution = select_best_resolution(image.size, possible_resolutions)
# image_padded = resize_and_pad_image(image, best_resolution)
image_padded = image.resize(best_resolution)
patches = divide_to_patches(image_padded, processor.crop_size["height"])
image_original_resize = image.resize((processor.crop_size["height"], processor.crop_size["height"]))
image_patches = [image_original_resize] + patches
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
return torch.stack(image_patches, dim=0)
def process_images(images, image_processor, model_cfg):
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", "")
new_images = []
if image_aspect_ratio == "pad":
for image in images:
image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
new_images.append(image)
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
for image in images:
image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
new_images.append(image)
else:
return image_processor(images, return_tensors="pt")["pixel_values"]
if all(x.shape == new_images[0].shape for x in new_images):
new_images = torch.stack(new_images, dim=0)
return new_images
class Mineru2ImageProcessor(BaseImageProcessor):
model_input_names = ["pixel_values"]
def __init__(
self,
image_mean=(0.5, 0.5, 0.5),
image_std=(0.5, 0.5, 0.5),
size=(384, 384),
crop_size: Optional[Dict[str, int]] = None,
resample=PILImageResampling.BICUBIC,
rescale_factor=1 / 255,
data_format=ChannelDimension.FIRST,
image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[list] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
self.image_mean = image_mean
self.image_std = image_std
self.size = size
self.resample = resample
self.rescale_factor = rescale_factor
self.data_format = data_format
self.crop_size = crop_size
self.image_aspect_ratio = image_aspect_ratio
self.image_grid_pinpoints = image_grid_pinpoints
self.in_e2e_processing = False
def _preprocess(self, images):
if isinstance(images, Image.Image):
images = [images]
else:
# to adapt video data
images = [to_numpy_array(image) for image in images]
assert isinstance(images, list)
transforms = [
convert_to_rgb,
to_numpy_array,
partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
]
images = reduce(lambda x, f: [*map(f, x)], transforms, images)
return {"pixel_values": images}
def _preprocess_end_to_end(self, images):
image_aspect_ratio = self.image_aspect_ratio
image_grid_pinpoints = self.image_grid_pinpoints
assert image_aspect_ratio is not None
assert image_grid_pinpoints is not None
pixel_values = []
if image_aspect_ratio == "pad":
for image in images:
image = expand2square(image, tuple(int(x * 255) for x in self.image_mean))
image = self._preprocess(image)["pixel_values"][0]
pixel_values.append(image)
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
for image in images:
image = process_anyres_image(image, self, self.image_grid_pinpoints)
pixel_values.append(image.numpy())
else:
pixel_values = self._preprocess(images)["pixel_values"]
if isinstance(pixel_values, list) and all(x.shape == pixel_values[0].shape for x in pixel_values):
pixel_values = np.stack(pixel_values, axis=0)
# CAUTION: image_sizes are stored as (height, width), not (width, height).
image_sizes = [(image.height, image.width) for image in images]
assert len(pixel_values) == len(image_sizes)
return {"pixel_values": pixel_values, "image_sizes": image_sizes}
def preprocess(
self,
images,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
):
if self.image_aspect_ratio is None or self.in_e2e_processing:
data = self._preprocess(images)
else:
assert self.image_grid_pinpoints is not None
self.in_e2e_processing = True
try:
data = self._preprocess_end_to_end(images)
finally:
self.in_e2e_processing = False
return BatchFeature(data=data, tensor_type=return_tensors)
import math
import re
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import (
Qwen2ForCausalLM,
Qwen2Model,
SiglipVisionConfig,
SiglipVisionModel,
)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_mineru2 import Mineru2QwenConfig
from .image_processing_mineru2 import Mineru2ImageProcessor, get_anyres_image_grid_shape
class SiglipVisionTower(nn.Module):
def __init__(self, vision_tower):
super().__init__()
self.config = SiglipVisionConfig.from_pretrained(vision_tower)
assert isinstance(self.config, SiglipVisionConfig)
self.config.num_hidden_layers -= 1 # drop the last hidden layer
self.config.vision_use_head = False
self.vision_tower = SiglipVisionModel(self.config)
self.vision_tower.requires_grad_(False)
self.image_processor = Mineru2ImageProcessor()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(
image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True
)
image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
for p in self.vision_tower.parameters():
return p.dtype
@property
def device(self):
for p in self.vision_tower.parameters():
return p.device
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
@property
def num_patches_per_side(self):
return self.config.image_size // self.config.patch_size
@property
def image_size(self):
return self.config.image_size
def build_vision_tower(config: Mineru2QwenConfig):
vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
model_path = getattr(config, "_name_or_path", "")
if "siglip" in vision_tower.lower():
if model_path:
return SiglipVisionTower(f"{model_path}/{vision_tower}")
else:
return SiglipVisionTower(vision_tower)
raise ValueError(f"Unknown vision tower: {vision_tower}")
def build_vision_projector(config: Mineru2QwenConfig):
projector_type = getattr(config, "mm_projector_type", "linear")
if projector_type == "linear":
return nn.Linear(config.mm_hidden_size, config.hidden_size)
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU()) # type: ignore
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
return nn.Sequential(*modules)
if projector_type == "identity":
return nn.Identity()
raise ValueError(f"Unknown projector type: {projector_type}")
class Mineru2QwenModel(Qwen2Model):
config_class = Mineru2QwenConfig
def __init__(self, config: Mineru2QwenConfig):
super(Mineru2QwenModel, self).__init__(config)
self.vision_tower = build_vision_tower(config)
self.mm_projector = build_vision_projector(config)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
class Mineru2QwenForCausalLM(Qwen2ForCausalLM):
config_class = Mineru2QwenConfig
def __init__(self, config: Mineru2QwenConfig):
super(Qwen2ForCausalLM, self).__init__(config)
config.rope_scaling = None
self.model = Mineru2QwenModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.ignore_index = config.ignore_index
self.image_token_index = config.image_token_index
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def encode_images(self, images: torch.Tensor):
image_features = self.get_model().vision_tower(images)
image_features = self.get_model().mm_projector(image_features)
return image_features
def prepare_inputs_labels_for_multimodal(
self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
):
vision_tower = self.get_model().vision_tower
if vision_tower is None or images is None or input_ids.shape[1] == 1:
return input_ids, position_ids, attention_mask, past_key_values, None, labels
if type(images) is list or images.ndim == 5:
if type(images) is list:
images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
concat_images = torch.cat([image for image in images], dim=0)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in images]
image_features = torch.split(image_features, split_sizes, dim=0)
mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
if mm_patch_merge_type == "flat":
image_features = [x.flatten(0, 1) for x in image_features]
elif mm_patch_merge_type.startswith("spatial"):
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.get_model().vision_tower.num_patches_per_side
assert height * width == base_image_feature.shape[0]
if "anyres_max" in image_aspect_ratio:
matched_anyres_max_num_patches = re.match(r"square_anyres_max_(\d+)", image_aspect_ratio)
if matched_anyres_max_num_patches:
max_num_patches = int(matched_anyres_max_num_patches.group(1))
if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.get_model().vision_tower.config.image_size,
)
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
else:
raise NotImplementedError
if (
"unpad" in mm_patch_merge_type
and "anyres_max" in image_aspect_ratio
and matched_anyres_max_num_patches
):
unit = image_feature.shape[2]
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
c, h, w = image_feature.shape
times = math.sqrt(h * w / (max_num_patches * unit**2))
if times > 1.1:
image_feature = image_feature[None]
image_feature = nn.functional.interpolate(
image_feature, [int(h // times), int(w // times)], mode="bilinear"
)[0]
image_feature = torch.cat(
(
image_feature,
self.model.image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
elif "unpad" in mm_patch_merge_type:
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = torch.cat(
(
image_feature,
self.model.image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
else:
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
image_feature = image_feature.flatten(0, 3)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
if "unpad" in mm_patch_merge_type:
image_feature = torch.cat(
(image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0
)
new_image_features.append(image_feature)
image_features = new_image_features
else:
raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
else:
image_features = self.encode_images(images)
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
if labels is None:
labels = torch.full_like(input_ids, self.ignore_index)
# remove the padding using attention_mask -- FIXME
_input_ids = input_ids
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
new_input_embeds = []
new_labels = []
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
num_images = (cur_input_ids == self.image_token_index).sum()
if num_images == 0:
cur_image_features = image_features[cur_image_idx]
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
cur_image_idx += 1
continue
image_token_indices = (
[-1] + torch.where(cur_input_ids == self.image_token_index)[0].tolist() + [cur_input_ids.shape[0]]
)
cur_input_ids_noim = []
cur_labels = labels[batch_idx]
cur_labels_noim = []
for i in range(len(image_token_indices) - 1):
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
split_sizes = [x.shape[0] for x in cur_labels_noim]
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []
for i in range(num_images + 1):
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
cur_new_labels.append(cur_labels_noim[i])
if i < num_images:
cur_image_features = image_features[cur_image_idx]
cur_image_idx += 1
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(
torch.full(
(cur_image_features.shape[0],), self.ignore_index, device=cur_labels.device, dtype=cur_labels.dtype
)
)
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
# Truncate sequences to max length as image embeddings can make the sequence longer
tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
if tokenizer_model_max_length is not None:
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
# Combine them
max_len = max(x.shape[0] for x in new_input_embeds)
batch_size = len(new_input_embeds)
new_input_embeds_padded = []
new_labels_padded = torch.full(
(batch_size, max_len), self.ignore_index, dtype=new_labels[0].dtype, device=new_labels[0].device
)
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
cur_len = cur_new_embed.shape[0]
if getattr(self.config, "tokenizer_padding_side", "right") == "left":
new_input_embeds_padded.append(
torch.cat(
(
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
cur_new_embed,
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
else:
new_input_embeds_padded.append(
torch.cat(
(
cur_new_embed,
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded
if _attention_mask is None:
attention_mask = None
else:
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
if _position_ids is None:
position_ids = None
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
image_sizes: Optional[List[List[int]]] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
if inputs_embeds is None:
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = (
self.prepare_inputs_labels_for_multimodal(
input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes
)
)
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_sizes: Optional[List[List[int]]] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
inputs, position_ids, attention_mask, _, inputs_embeds, _ = self.prepare_inputs_labels_for_multimodal(
inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes
)
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = super().prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
)
if images is not None:
inputs["images"] = images
if image_sizes is not None:
inputs["image_sizes"] = image_sizes
return inputs
from sglang.srt.configs.model_config import multimodal_model_archs
from sglang.srt.models.registry import ModelRegistry
try:
# sglang==0.4.5.post3
from sglang.srt.managers.multimodal_processor import (
PROCESSOR_MAPPING as PROCESSOR_MAPPING,
)
except ImportError:
# sglang==0.4.4.post1
from sglang.srt.managers.image_processor import (
IMAGE_PROCESSOR_MAPPING as PROCESSOR_MAPPING,
)
from .. import vlm_hf_model as _
from .image_processor import Mineru2ImageProcessor
from .model import Mineru2QwenForCausalLM
ModelRegistry.models[Mineru2QwenForCausalLM.__name__] = Mineru2QwenForCausalLM
PROCESSOR_MAPPING[Mineru2QwenForCausalLM] = Mineru2ImageProcessor
multimodal_model_archs.append(Mineru2QwenForCausalLM.__name__)
import asyncio
import time
from types import MethodType
from typing import AsyncIterator, Dict, Iterator, List, Optional, Union
import fastapi
from sglang.srt.entrypoints.engine import Engine as _Engine
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.managers.tokenizer_manager import (
TokenizerManager,
dataclass_to_string_truncated,
logger,
)
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from ...utils.run_async import run_async
from .logit_processor import Mineru2LogitProcessor
class BatchEngine(_Engine):
"""
The engine is patched to support batch multi-modal generation and early image preprocessing.
"""
def __init__(self, server_args: ServerArgs, **kwargs):
server_args.enable_custom_logit_processor = True
super().__init__(server_args=server_args, **kwargs)
_patch_tokenizer_manager(self.tokenizer_manager)
def generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None,
return_hidden_states: bool = False,
stream: bool = False,
) -> Union[Dict, Iterator[Dict]]:
"""
The arguments of this function are the same as those of `sglang/srt/managers/io_struct.py::GenerateReqInput`.
Please refer to `GenerateReqInput` for the documentation.
"""
modalities_list = []
# EDIT
if isinstance(image_data, list):
for _ in range(len(image_data)):
modalities_list.append(["image"])
elif image_data is not None:
modalities_list.append("image")
# ADD
if custom_logit_processor is None:
custom_logit_processor = Mineru2LogitProcessor().to_str()
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
sampling_params=sampling_params,
image_data=image_data,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
token_ids_logprob=token_ids_logprob,
lora_path=lora_path,
modalities=modalities_list,
custom_logit_processor=custom_logit_processor,
return_hidden_states=return_hidden_states,
stream=stream,
)
generator = _generate_request(self.tokenizer_manager, obj, None)
if stream:
def generator_wrapper():
while True:
try:
chunk = run_async(generator.__anext__())
yield chunk
except StopAsyncIteration:
break
return generator_wrapper()
else:
ret = run_async(generator.__anext__())
return ret
async def async_generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None,
return_hidden_states: bool = False,
stream: bool = False,
) -> Union[Dict, AsyncIterator[Dict], Iterator[Dict]]:
"""
The arguments of this function are the same as those of `sglang/srt/managers/io_struct.py::GenerateReqInput`.
Please refer to `GenerateReqInput` for the documentation.
"""
modalities_list = []
# EDIT
if isinstance(image_data, list):
for _ in range(len(image_data)):
modalities_list.append(["image"])
elif image_data is not None:
modalities_list.append("image")
# ADD
if custom_logit_processor is None:
custom_logit_processor = Mineru2LogitProcessor().to_str()
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
sampling_params=sampling_params,
image_data=image_data,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
token_ids_logprob=token_ids_logprob,
lora_path=lora_path,
modalities=modalities_list,
custom_logit_processor=custom_logit_processor,
return_hidden_states=return_hidden_states,
stream=stream,
)
generator = _generate_request(self.tokenizer_manager, obj, None)
if stream is True:
return generator
else:
return await generator.__anext__()
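# ------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the patched engine).
# It shows how BatchEngine could be driven for batch multi-modal
# generation. The model path, image file names, prompt text, and
# sampling parameters below are placeholders/assumptions, not the
# values MinerU actually uses.
#
#   engine = BatchEngine(ServerArgs(model_path="/path/to/mineru2-model"))
#   outputs = engine.generate(
#       prompt=["<image>\nExtract the text.", "<image>\nExtract the text."],
#       image_data=["page_0.png", "page_1.png"],  # one image per prompt
#       sampling_params=[{"temperature": 0.0, "max_new_tokens": 1024}] * 2,
#   )
#   for out in outputs:
#       print(out["text"])
# ------------------------------------------------------------------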
def _auto_create_handle_loop(self: TokenizerManager):
"""
patch the original `auto_create_handle_loop()` method to reset `no_create_loop`
when the event loop changes.
"""
try:
curr_handle_loop = asyncio.get_running_loop()
except RuntimeError:
curr_handle_loop = None
last_handle_loop = getattr(self, "_last_handle_loop", None)
if last_handle_loop != curr_handle_loop:
self.no_create_loop = False
setattr(self, "_last_handle_loop", curr_handle_loop)
return TokenizerManager.auto_create_handle_loop(self)
def _patch_tokenizer_manager(self: TokenizerManager):
self.auto_create_handle_loop = MethodType(_auto_create_handle_loop, self)
async def _one_request(
self: TokenizerManager,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request],
created_time: Optional[float],
):
tokenized_obj = await self._tokenize_one_request(obj)
self._send_one_request(obj, tokenized_obj, created_time)
async for out in self._wait_one_response(obj, request):
yield out
async def _handle_batch_request(
self: TokenizerManager,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
created_time: Optional[float] = None,
):
batch_size = obj.batch_size
generators = []
rids = []
if getattr(obj, "parallel_sample_num", 1) != 1:
raise Exception("parallel_sample_num != 1 is not supported in this patched code.")
# Send all requests
for i in range(batch_size):
tmp_obj = obj[i]
generators.append(_one_request(self, tmp_obj, request, created_time))
rids.append(tmp_obj.rid)
# Wait for all requests
is_stream = hasattr(obj, "stream") and obj.stream
if not is_stream:
outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
yield outputs
else:
rid_to_index = {rid: i for i, rid in enumerate(rids)}
task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
while task_map:
done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
for task in done:
gen = task_map.pop(task)
try:
result = task.result()
result["index"] = rid_to_index[result["meta_info"]["id"]]
yield result
new_task = asyncio.create_task(gen.__anext__())
task_map[new_task] = gen
except StopAsyncIteration:
pass
async def _generate_request(
self: TokenizerManager,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
):
created_time = time.time()
self.auto_create_handle_loop()
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
raise ValueError(
"This model does not appear to be an embedding model by default. "
"Please add `--is-embedding` when launching the server or try another model."
)
obj.normalize_batch_and_arguments()
if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata
logger.info(f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}")
async with self.model_update_lock.reader_lock:
is_single = obj.is_single
if is_single:
tokenized_obj = await self._tokenize_one_request(obj)
self._send_one_request(obj, tokenized_obj, created_time)
async for response in self._wait_one_response(obj, request):
yield response
else:
async for response in _handle_batch_request(self, obj, request, created_time):
yield response
import ast
import asyncio
import re
from typing import List, Optional, Union
import numpy as np
try:
# sglang==0.4.5.post3
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor as BaseProcessor,
)
get_global_processor = None
except ImportError:
# sglang==0.4.4.post1
from sglang.srt.managers.image_processors.base_image_processor import (
BaseImageProcessor as BaseProcessor,
get_global_processor,
)
from sglang.srt.mm_utils import divide_to_patches, expand2square, select_best_resolution
from sglang.srt.utils import load_image, logger
from sglang.utils import get_exception_traceback
from .model import Mineru2QwenForCausalLM
# image_best_res is only resized (not padded).
def process_anyres_image(image, processor, grid_pinpoints):
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
patch_size = processor.crop_size["height"]
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
grid_pinpoints = [
(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
]
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints)
best_resolution = select_best_resolution(image.size, possible_resolutions)
image_best_res = image.resize(best_resolution) # <<<<<<< Here changed
patches = divide_to_patches(image_best_res, processor.crop_size["height"])
image_original_resize = image.resize((processor.crop_size["height"], processor.crop_size["height"]))
image_patches = [image_original_resize] + patches
image_patches = [processor.preprocess(image_patch)["pixel_values"][0] for image_patch in image_patches]
return np.stack(image_patches, axis=0)
class Mineru2ImageProcessor(BaseProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(
image_data: Union[str, bytes],
image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[str] = None,
image_processor=None,
):
if image_processor is None:
assert get_global_processor is not None
image_processor = get_global_processor().image_processor
try:
image, image_size = load_image(image_data)
if image_size is not None:
# It is a video with multiple images
image_hash = hash(image_data)
pixel_values = image_processor(image)["pixel_values"]
pixel_values = np.stack(pixel_values, axis=0)
return pixel_values, image_hash, image_size
else:
# It is an image
image_hash = hash(image_data)
if image_aspect_ratio == "pad":
image = expand2square(
image,
tuple(int(x * 255) for x in image_processor.image_mean),
)
pixel_values = image_processor(image.convert("RGB"))["pixel_values"][0]
elif image_aspect_ratio == "anyres" or (image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio):
pixel_values = process_anyres_image(image, image_processor, image_grid_pinpoints)
else:
pixel_values = image_processor(image)["pixel_values"][0]
return pixel_values, image_hash, image.size
except Exception:
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
async def _process_single_image(self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str):
if hasattr(self, "cpu_executor"):
executor = self.cpu_executor
else:
executor = self.executor
if get_global_processor is not None:
image_processor = None # save ipc cost
else:
image_processor = self._processor.image_processor
if executor is not None:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
executor,
Mineru2ImageProcessor._process_single_image_task,
image_data,
aspect_ratio,
grid_pinpoints,
image_processor,
)
else:
return self._process_single_image_task(
image_data,
aspect_ratio,
grid_pinpoints,
image_processor,
)
# sglang==0.4.4.post1
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
*args,
**kwargs,
):
if not image_data:
return None
modalities = request_obj.modalities or ["image"]
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", "")
grid_pinpoints = (
self.hf_config.image_grid_pinpoints
if hasattr(self.hf_config, "image_grid_pinpoints") and "anyres" in aspect_ratio
else None
)
if isinstance(image_data, str):
image_data = [image_data]
if isinstance(image_data, list) and len(image_data) > 0:
if "multi-images" in modalities or "video" in modalities:
# Multiple images
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
pixel_values, image_hashes, image_sizes = [], [], []
res = []
for img_data in image_data:
res.append(self._process_single_image(img_data, aspect_ratio, grid_pinpoints))
res = await asyncio.gather(*res)
for pixel_v, image_h, image_s in res:
pixel_values.append(pixel_v)
image_hashes.append(image_h)
image_sizes.append(image_s)
if isinstance(pixel_values[0], np.ndarray):
pixel_values = np.stack(pixel_values, axis=0)
else:
# A single image
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
image_hashes = [image_hash]
image_sizes = [image_size]
else:
raise ValueError(f"Invalid image data: {image_data}")
return {
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"image_sizes": image_sizes,
"modalities": request_obj.modalities or ["image"],
}
# sglang==0.4.5.post3
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
*args,
**kwargs,
):
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
result = await self.process_images_async(image_data, input_text, request_obj, *args, **kwargs)
if result is None:
return None
modality = Modality.IMAGE
if isinstance(request_obj.modalities, list):
if request_obj.modalities[0] == "multi-images":
modality = Modality.MULTI_IMAGES
elif request_obj.modalities[0] == "video":
modality = Modality.VIDEO
return {
"mm_items": [
MultimodalDataItem(
pixel_values=result["pixel_values"],
image_sizes=result["image_sizes"],
modality=modality,
)
],
}
ImageProcessorMapping = {Mineru2QwenForCausalLM: Mineru2ImageProcessor}
from typing import List
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
class Mineru2LogitProcessor(CustomLogitProcessor):
"""
Stateless logit processor for Mineru2.
(base-class: sglang.srt.sampling.custom_logit_processor.CustomLogitProcessor)
This processor applies token-level constraints to prevent repetition during generation.
It supports two main constraints:
- no_repeat_ngram_size (int):
Prevents repeating the same n-gram of specified size in the output.
Inspired by Hugging Face's NoRepeatNGramLogitsProcessor.
This implementation is slower due to its lack of specialized optimization.
- no_repeat_token_count (int):
(Placeholder for future logic)
Intended to prevent repeating the same token multiple times.
Not yet implemented in this version.
"""
def __init__(self) -> None:
super().__init__()
self._generated_ngrams = {} # Cache of generated n-grams by request ID
self._time = {} # Timestamp of the last update for each request
self._gen_step = 0 # Global generation step counter
def __call__(self, logits, batch_info: List[dict]):
"""
Applies repetition constraints to the logits before sampling tokens.
Args:
logits (FloatTensor): A tensor of shape (batch_size, vocab_size) containing raw token logits.
batch_info (List[dict]): A list of metadata dicts for each sample in the batch. Each dict must include:
- "__req__": Request object containing request ID and output_ids.
- "no_repeat_ngram_size": Size of n-gram to avoid repeating.
Returns:
FloatTensor: The modified logits tensor with banned token logits set to -inf.
"""
from sglang.srt.managers.schedule_batch import Req
self._gen_step += 1 # Update global generation step
for idx, info in enumerate(batch_info):
if not isinstance(info, dict) or "__req__" not in info:
continue
req: Req = info["__req__"]
rid = req.rid
output_ids = req.output_ids
ngram_size = info.get("no_repeat_ngram_size", 0)
# Skip if there are not enough tokens to form an n-gram
if ngram_size <= 0 or len(output_ids) < ngram_size:
continue
# Record the current step for cache cleanup tracking
self._time[rid] = self._gen_step
# Initialize n-gram cache for this request if it doesn't exist
if rid not in self._generated_ngrams:
self._generated_ngrams[rid] = {}
# Get the n-gram prefix (all but the last token)
prev_ngram = tuple(output_ids[-ngram_size:-1])
last_token = output_ids[-1]
# Store this n-gram occurrence
self._generated_ngrams[rid][prev_ngram] = self._generated_ngrams[rid].get(prev_ngram, []) + [last_token]
# Get the next-token candidates to ban based on current prefix
current_prefix = tuple(output_ids[-ngram_size + 1 :])
banned_tokens = self._generated_ngrams[rid].get(current_prefix, [])
# Set the logits of banned tokens to negative infinity
for token in banned_tokens:
logits[idx][token] = -float("inf")
# Clean up cache for expired requests
expired_rids = [rid for rid, last_used in self._time.items() if last_used < self._gen_step]
for rid in expired_rids:
self._generated_ngrams.pop(rid, None)
self._time.pop(rid, None)
return logits
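# Worked example of the n-gram constraint above (illustrative, assuming
# no_repeat_ngram_size=3 for a single request):
#
#   step t:   output_ids = [5, 7, 9]
#             prev_ngram = (5, 7), last_token = 9   -> cache[(5, 7)] = [9]
#   later:    output_ids = [5, 7, 9, 5, 7]
#             current_prefix = (5, 7)               -> banned_tokens = [9]
#             logits[idx][9] = -inf, so the 3-gram (5, 7, 9) cannot repeat.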
import math
import re
from typing import Iterable, List, Optional, Tuple
import numpy as np
import torch
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape, # unpad_image, unpad_image_shape
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.utils import add_prefix
from torch import nn
from transformers import (
CLIPVisionConfig,
CLIPVisionModel,
SiglipVisionConfig,
SiglipVisionModel,
)
from ..vlm_hf_model.configuration_mineru2 import Mineru2QwenConfig
from ..vlm_hf_model.modeling_mineru2 import build_vision_projector
def flatten_nested_list(nested_list):
if isinstance(nested_list, list):
return [item for sublist in nested_list for item in flatten_nested_list(sublist)]
else:
return [nested_list]
def downgrade_modality(modality):
modality_str = str(modality)
if "MULTI_IMAGES" in modality_str:
return "multi-images"
if "IMAGE" in modality_str:
return "image"
if "VIDEO" in modality_str:
return "video"
if "AUDIO" in modality_str:
return "audio"
raise ValueError(f"Unexpected modality: {modality_str}")
class Mineru2QwenForCausalLM(nn.Module):
def __init__(
self,
config: Mineru2QwenConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
if getattr(self.config, "projector_hidden_act", None) is None:
self.config.projector_hidden_act = "gelu"
if getattr(self.config, "image_token_index", None) is None:
self.config.image_token_index = 151646
# load vision tower
mm_vision_tower = self.config.mm_vision_tower
if "clip" in mm_vision_tower:
vision_config = CLIPVisionConfig.from_pretrained(mm_vision_tower)
self.vision_tower = CLIPVisionModel(vision_config) # type: ignore
elif "siglip" in mm_vision_tower:
vision_config = SiglipVisionConfig.from_pretrained(mm_vision_tower)
self.vision_tower = SiglipVisionModel(vision_config) # type: ignore
# Siglip needs all feature tokens
self.config.mm_vision_select_feature = "full"
else:
raise ValueError(f"Unexpected mm_vision_tower: {mm_vision_tower}")
### EDIT: change projector
# The name `projector` contains `proj`, which is often used in attention layers and can cause bugs in quantization.
self.multi_modal_mlp = build_vision_projector(config)
self.language_model = Qwen2ForCausalLM(
config,
quant_config=quant_config,
prefix=add_prefix("language_model", prefix),
)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
self.language_model.model.image_newline = nn.Parameter(torch.empty(config.hidden_size))
language_model_device = next(self.language_model.parameters()).device
self.vision_tower = self.vision_tower.to(language_model_device)
self.vision_tower.eval()
self.vision_feature_layer = self.config.mm_vision_select_layer
self.vision_feature_select_strategy = self.config.mm_vision_select_feature
self.image_size = self.vision_tower.config.image_size
self.patch_size = self.vision_tower.config.patch_size
self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
if self.vision_feature_select_strategy in ("patch", "full"):
pass
elif self.vision_feature_select_strategy == "cls_patch":
self.image_feature_len += 1
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")
def pad_input_ids(self, input_ids: List[int], image_inputs):
if hasattr(image_inputs, "mm_items"): # MultimodalInputs
# sglang==0.4.5.post3
image_sizes = flatten_nested_list([item.image_sizes for item in image_inputs.mm_items])
pad_values = [item.pad_value for item in image_inputs.mm_items]
else: # ImageInputs
# sglang==0.4.4.post1
image_sizes = image_inputs.image_sizes
pad_values = image_inputs.pad_values
# hardcode for spatial_unpad + anyres
# if image_inputs.modalities is not None and (
# "multi-images" in image_inputs.modalities or "video" in image_inputs.modalities
# ):
# image_aspect_ratio = "pad"
# else:
# image_aspect_ratio = "anyres"
offset_list = []
image_inputs.image_pad_len = []
for image_idx, image_s in enumerate(image_sizes):
if len(image_sizes) > 16:
# 2x2 pooling with stride 2
new_image_feature_len = math.ceil(self.image_size / self.patch_size / 2) ** 2
else:
new_image_feature_len = self.image_feature_len # multiimage
height = width = self.num_patches_per_side
if "anyres" in self.config.image_aspect_ratio:
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_s,
self.image_grid_pinpoints,
self.vision_tower.config.image_size,
)
h = num_patch_height * height
w = num_patch_width * width
### EDIT: remove `unpad_image_shape`
# new_h, new_w = unpad_image_shape(h, w, image_s)
new_h, new_w = h, w
if "anyres_max" in self.config.image_aspect_ratio:
matched_anyres_max_num_patches = re.match(r".*anyres_max_(\d+)", self.config.image_aspect_ratio)
if matched_anyres_max_num_patches:
max_num_patches = int(matched_anyres_max_num_patches.group(1))
times = math.sqrt(new_h * new_w / (max_num_patches * self.image_feature_len))
if times > 1.1:
new_h = int(new_h // times)
new_w = int(new_w // times)
new_image_feature_len += new_h * (new_w + 1)
try:
offset = input_ids.index(self.config.image_token_index)
except ValueError:
offset = 0
# old_len + pad_len - 1, because we need to remove image_token_id
input_ids = input_ids[:offset] + [pad_values[image_idx]] * new_image_feature_len + input_ids[offset + 1 :]
offset_list.append(offset)
image_inputs.image_pad_len.append(new_image_feature_len)
image_inputs.image_offsets = offset_list
return input_ids
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
pixel_values = pixel_values.to(device=self.vision_tower.device, dtype=self.vision_tower.dtype)
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
# NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden states.
selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
if self.vision_feature_select_strategy in ["default", "patch"]:
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(f"Unexpected select feature strategy: {self.vision_feature_select_strategy}")
image_features = self.multi_modal_mlp(selected_image_feature)
return image_features
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if hasattr(forward_batch, "mm_inputs"):
# sglang==0.4.5.post3
image_inputs = forward_batch.mm_inputs
is_sglang_mm_inputs = True
else:
# sglang==0.4.4.post1
image_inputs = forward_batch.image_inputs
is_sglang_mm_inputs = False
if image_inputs is None:
image_inputs = []
if forward_batch.forward_mode.is_extend():
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# These values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
# Embed text inputs
input_embeds = self.language_model.model.embed_tokens(input_ids)
# Got List[List[str]] extend it to List[str]
# The length of the List should be equal to batch size
modalities_list = []
max_image_offset = []
for im in image_inputs:
if im:
if hasattr(im, "mm_items"):
# sglang==0.4.5.post3
modalities_list.extend([downgrade_modality(item.modality) for item in im.mm_items])
elif im.modalities is not None:
# sglang==0.4.4.post1
modalities_list.extend(im.modalities)
if im and im.image_offsets:
max_image_offset.append(np.max(np.array(im.image_offsets) + np.array(im.image_pad_len)))
else:
max_image_offset.append(-1)
start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
need_vision = start_positions <= np.array(max_image_offset)
if need_vision.any():
bs = forward_batch.batch_size
if is_sglang_mm_inputs:
# sglang==0.4.5.post3
pixel_values = flatten_nested_list(
[[item.pixel_values for item in image_inputs[i].mm_items] for i in range(bs) if need_vision[i]]
) # image_inputs[batch_idx].mm_items[item_idx].pixel_values is Tensor
image_sizes = [
flatten_nested_list([item.image_sizes for item in image_inputs[i].mm_items])
for i in range(bs)
if need_vision[i]
] # image_inputs[batch_idx].mm_items[item_idx].image_sizes should be tuple, but is list of tuple for now.
else:
# sglang==0.4.4.post1
pixel_values = [image_inputs[i].pixel_values for i in range(bs) if need_vision[i]]
image_sizes = [image_inputs[i].image_sizes for i in range(bs) if need_vision[i]]
########## Encode Image ########
if pixel_values[0].ndim == 4:
# llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
np.concatenate(pixel_values, axis=0)
# ndim=4
concat_images = torch.tensor(
np.concatenate(pixel_values, axis=0),
device=self.vision_tower.device,
)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in pixel_values]
image_features = torch.split(image_features, split_sizes, dim=0)
# hd image_features: BS, num_patch, 576, 4096
else:
# normal pixel: BS, C=3, H=336, W=336
pixel_values = torch.tensor(np.array(pixel_values), device=self.vision_tower.device)
image_features = self.encode_images(pixel_values)
# image_features: BS, 576, 4096
if self.mm_patch_merge_type.startswith("spatial"):
new_image_features = []
height = width = self.num_patches_per_side
for image_idx, image_feature in enumerate(image_features):
if modalities_list[image_idx] == "image":
image_aspect_ratio = self.config.image_aspect_ratio # single image
elif modalities_list[image_idx] == "multi-images" or modalities_list[image_idx] == "video":
image_aspect_ratio = "pad" # multi image
# image_aspect_ratio = (
# "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
# )
if (
image_feature.shape[0] > 1
and "anyres" in image_aspect_ratio
and modalities_list[image_idx] == "image"
):
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
assert height * width == base_image_feature.shape[0]
if "anyres_max" in image_aspect_ratio:
matched_anyres_max_num_patches = re.match(r".*anyres_max_(\d+)", image_aspect_ratio)
if matched_anyres_max_num_patches:
max_num_patches = int(matched_anyres_max_num_patches.group(1))
if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
vision_tower_image_size = self.image_size
try:
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx][0],
self.config.image_grid_pinpoints,
vision_tower_image_size,
)
except Exception as e:
print(f"Error: {e}")
num_patch_width, num_patch_height = 2, 2
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
else:
image_feature = image_feature.view(2, 2, height, width, -1)
if "unpad" in self.mm_patch_merge_type:
unit = image_feature.shape[2]
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
### EDIT: remove `unpad_image`
# image_feature = unpad_image(image_feature, image_sizes[image_idx][0])
if "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
c, h, w = image_feature.shape
times = math.sqrt(h * w / (max_num_patches * unit**2))
if times > 1.1:
image_feature = image_feature[None]
image_feature = nn.functional.interpolate(
image_feature,
[int(h // times), int(w // times)],
mode="bilinear",
)[0]
image_feature = torch.cat(
(
image_feature,
self.language_model.model.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1
),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
else:
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
image_feature = image_feature.flatten(0, 3)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
image_feature = image_feature.unsqueeze(0)
else:
if modalities_list[image_idx] == "video": # video
# 2x2 pooling
num_of_frames = image_feature.shape[0]
image_feature = image_feature.view(num_of_frames, height, width, -1)
image_feature = image_feature.permute(0, 3, 1, 2).contiguous() # N, C, H, W
height, weight = image_feature.shape[2:]
scaled_shape = [
math.ceil(height / 2),
math.ceil(weight / 2),
]
image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode="bilinear")
image_feature = image_feature.flatten(2).transpose(1, 2).contiguous() # N, C, H*W
if "unpad" in self.mm_patch_merge_type:
image_feature = torch.cat(
(
image_feature,
# Expand to (bs, 1, hidden_dim) and concat at the end of the image tokens
self.language_model.model.image_newline[None, None].expand(
image_feature.shape[0],
1,
image_feature.shape[-1],
),
),
dim=1,
)
new_image_features.append(image_feature)
image_features = new_image_features
# Fill in the placeholder for the image
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
pt = 0
for i in range(bs):
if not need_vision[i]:
continue
start_idx = extend_start_loc_cpu[i]
seq_len = extend_seq_lens[i]
prefix_len = prefix_lens_cpu[i]
# Multiple images
for image_idx, image_offset in enumerate(image_inputs[i].image_offsets):
if image_offset + image_inputs[i].image_pad_len[image_idx] <= prefix_len:
continue
if image_offset >= prefix_len + seq_len:
break
tmp_image_feature = image_features[pt][image_idx]
pad_len = tmp_image_feature.shape[0]
input_offset = image_offset - prefix_len
left_idx = start_idx + input_offset
right_idx = left_idx + pad_len
assert right_idx > start_idx
if input_offset < 0:
left_idx = start_idx
tmp_image_feature = tmp_image_feature[-input_offset:]
if right_idx > start_idx + seq_len:
tmp_image_feature = tmp_image_feature[: start_idx + seq_len - right_idx]
right_idx = start_idx + seq_len
try:
input_embeds[left_idx:right_idx] = tmp_image_feature
except RuntimeError as e:
print(f"RuntimeError in image encoding: {e}")
print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
print(f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}")
pt += 1
return self.language_model(input_ids, positions, forward_batch, input_embeds=input_embeds)
elif forward_batch.forward_mode.is_decode():
return self.language_model(input_ids, positions, forward_batch)
else:
raise ValueError(f"Unexpected forward mode: {forward_batch.forward_mode}")
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
projector_weights = {
"model.mm_projector": "multi_modal_mlp",
"model.vision_tower.vision_tower": "vision_tower",
# Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
"model.image_newline": "language_model.model.image_newline",
}
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "projector" in name or "vision_tower" in name or "image_newline" in name:
for weight_name, param_name in projector_weights.items():
if weight_name in name:
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
else:
self.language_model.load_weights([(name, loaded_weight)])
@property
def num_patches_per_side(self):
return self.image_size // self.patch_size
EntryClass = [Mineru2QwenForCausalLM]
import os
import sys
from fastapi import Request
from sglang.srt.entrypoints.http_server import app, generate_request, launch_server
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_process_tree
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
from .logit_processor import Mineru2LogitProcessor
_custom_logit_processor_str = Mineru2LogitProcessor().to_str()
# remove the existing /generate route
for route in app.routes[:]:
if hasattr(route, "path") and getattr(route, "path") == "/generate":
app.routes.remove(route)
# add the custom /generate route
@app.api_route("/generate", methods=["POST", "PUT"])
async def custom_generate_request(obj: GenerateReqInput, request: Request):
if obj.custom_logit_processor is None:
obj.custom_logit_processor = _custom_logit_processor_str
return await generate_request(obj, request)
def main():
server_args = prepare_server_args(sys.argv[1:])
if server_args.chat_template is None:
server_args.chat_template = "chatml"
server_args.enable_custom_logit_processor = True
if server_args.model_path is None:
server_args.model_path = auto_download_and_get_model_root_path("/","vlm")
try:
launch_server(server_args)
finally:
kill_process_tree(os.getpid(), include_parent=False)
if __name__ == "__main__":
main()
# Copyright (c) Opendatalab. All rights reserved.
from mineru.utils.boxbase import (
calculate_iou,
calculate_overlap_area_in_bbox1_area_ratio,
calculate_vertical_projection_overlap_ratio,
get_minbox_if_overlap_by_ratio
)
from mineru.utils.enum_class import BlockType
def process_groups(groups, body_key, caption_key, footnote_key):
body_blocks = []
caption_blocks = []
footnote_blocks = []
maybe_text_image_blocks = []
for i, group in enumerate(groups):
if body_key == 'image_body' and len(group[caption_key]) == 0 and len(group[footnote_key]) == 0:
# If there is no caption and no footnote, the group_id does not need to be added to image_body
group[body_key]['group_id'] = i
maybe_text_image_blocks.append(group[body_key])
continue
else:
group[body_key]['group_id'] = i
body_blocks.append(group[body_key])
for caption_block in group[caption_key]:
caption_block['group_id'] = i
caption_blocks.append(caption_block)
for footnote_block in group[footnote_key]:
footnote_block['group_id'] = i
footnote_blocks.append(footnote_block)
return body_blocks, caption_blocks, footnote_blocks, maybe_text_image_blocks
def prepare_block_bboxes(
img_body_blocks,
img_caption_blocks,
img_footnote_blocks,
@@ -73,15 +47,15 @@ def ocr_prepare_bboxes_for_layout_split_v2(
):
all_bboxes = []
add_bboxes(img_body_blocks, BlockType.IMAGE_BODY, all_bboxes)
add_bboxes(img_caption_blocks, BlockType.IMAGE_CAPTION, all_bboxes)
add_bboxes(img_footnote_blocks, BlockType.IMAGE_CAPTION, all_bboxes)
add_bboxes(table_body_blocks, BlockType.TABLE_BODY, all_bboxes)
add_bboxes(table_caption_blocks, BlockType.TABLE_CAPTION, all_bboxes)
add_bboxes(table_footnote_blocks, BlockType.TABLE_FOOTNOTE, all_bboxes)
add_bboxes(text_blocks, BlockType.TEXT, all_bboxes)
add_bboxes(title_blocks, BlockType.TITLE, all_bboxes)
add_bboxes(interline_equation_blocks, BlockType.INTERLINE_EQUATION, all_bboxes)
"""Resolve block nesting issues"""
"""When a text box overlaps a title box, trust the text box first"""
@@ -97,7 +71,7 @@ def ocr_prepare_bboxes_for_layout_split_v2(
"""discarded_blocks"""
all_discarded_blocks = []
add_bboxes(discarded_blocks, BlockType.DISCARDED, all_discarded_blocks)
"""Footnote detection: wider than 1/3 of the page width, taller than 10, and located in the bottom 30% of the page"""
footnote_blocks = []
@@ -122,63 +96,31 @@ def ocr_prepare_bboxes_for_layout_split_v2(
return all_bboxes, all_discarded_blocks, footnote_blocks
def add_bboxes(blocks, block_type, bboxes):
for block in blocks:
x0, y0, x1, y1 = block['bbox']
if block_type in [
BlockType.IMAGE_BODY,
BlockType.IMAGE_CAPTION,
BlockType.IMAGE_FOOTNOTE,
BlockType.TABLE_BODY,
BlockType.TABLE_CAPTION,
BlockType.TABLE_FOOTNOTE,
]:
bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block['score'], block['group_id']])
else:
bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block['score']])
def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes):
# 先提取所有text和interline block
text_blocks = []
for block in all_bboxes:
if block[7] == BlockType.Text:
text_blocks.append(block)
interline_equation_blocks = []
for block in all_bboxes:
if block[7] == BlockType.InterlineEquation:
interline_equation_blocks.append(block)
need_remove = []
for interline_equation_block in interline_equation_blocks:
for text_block in text_blocks:
interline_equation_block_bbox = interline_equation_block[:4]
text_block_bbox = text_block[:4]
if calculate_iou(interline_equation_block_bbox, text_block_bbox) > 0.8:
if text_block not in need_remove:
need_remove.append(text_block)
if len(need_remove) > 0:
for block in need_remove:
all_bboxes.remove(block)
return all_bboxes
def fix_text_overlap_title_blocks(all_bboxes):
# First collect all text and title blocks
text_blocks = []
for block in all_bboxes:
if block[7] == BlockType.TEXT:
text_blocks.append(block)
title_blocks = []
for block in all_bboxes:
if block[7] == BlockType.TITLE:
title_blocks.append(block)
need_remove = []
@@ -219,6 +161,54 @@ def remove_need_drop_blocks(all_bboxes, discarded_blocks):
return all_bboxes
def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes):
# First collect all text and interline-equation blocks
text_blocks = []
for block in all_bboxes:
if block[7] == BlockType.TEXT:
text_blocks.append(block)
interline_equation_blocks = []
for block in all_bboxes:
if block[7] == BlockType.INTERLINE_EQUATION:
interline_equation_blocks.append(block)
need_remove = []
for interline_equation_block in interline_equation_blocks:
for text_block in text_blocks:
interline_equation_block_bbox = interline_equation_block[:4]
text_block_bbox = text_block[:4]
if calculate_iou(interline_equation_block_bbox, text_block_bbox) > 0.8:
if text_block not in need_remove:
need_remove.append(text_block)
if len(need_remove) > 0:
for block in need_remove:
all_bboxes.remove(block)
return all_bboxes
def find_blocks_under_footnote(all_bboxes, footnote_blocks):
need_remove_blocks = []
for block in all_bboxes:
block_x0, block_y0, block_x1, block_y1 = block[:4]
for footnote_bbox in footnote_blocks:
footnote_x0, footnote_y0, footnote_x1, footnote_y1 = footnote_bbox
# If the footnote's vertical projection covers at least 80% of the block's vertical projection and the block's y0 is greater than or equal to the footnote's y1
if (
block_y0 >= footnote_y1
and calculate_vertical_projection_overlap_ratio(
(block_x0, block_y0, block_x1, block_y1), footnote_bbox
)
>= 0.8
):
if block not in need_remove_blocks:
need_remove_blocks.append(block)
break
return need_remove_blocks
def remove_overlaps_min_blocks(all_bboxes):
# Overlapping blocks: the smaller one cannot simply be deleted; it needs to be merged with the larger one into a bigger block.
# Remove the smaller ones among overlapping blocks
@@ -254,4 +244,4 @@ def remove_overlaps_min_blocks(all_bboxes):
for block in need_remove:
all_bboxes.remove(block)
return all_bboxes
\ No newline at end of file
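A brief, illustrative sketch of the group structure process_groups expects. The caption and footnote list keys ('image_caption_list', 'image_footnote_list') are hypothetical placeholders; only the 'image_body' body key is taken from the code above.
groups = [
    {
        'image_body': {'bbox': [10, 10, 200, 150], 'score': 0.98},
        'image_caption_list': [{'bbox': [10, 155, 200, 170], 'score': 0.95}],  # hypothetical key
        'image_footnote_list': [],                                             # hypothetical key
    },
]
body, captions, footnotes, maybe_text = process_groups(
    groups, 'image_body', 'image_caption_list', 'image_footnote_list'
)
# Each returned block now carries a 'group_id' linking it back to its originating group.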
# Copyright (c) Opendatalab. All rights reserved.
import copy
import os
import statistics
import warnings
from typing import List
import torch
from loguru import logger
from mineru.utils.config_reader import get_device
from mineru.utils.enum_class import BlockType, ModelPath
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
def sort_blocks_by_bbox(blocks, page_w, page_h, footnote_blocks):
"""获取所有line并计算正文line的高度"""
line_height = get_line_height(blocks)
"""获取所有line并对line排序"""
sorted_bboxes = sort_lines_by_model(blocks, page_w, page_h, line_height, footnote_blocks)
"""根据line的中位数算block的序列关系"""
blocks = cal_block_index(blocks, sorted_bboxes)
"""将image和table的block还原回group形式参与后续流程"""
blocks = revert_group_blocks(blocks)
"""重排block"""
sorted_blocks = sorted(blocks, key=lambda b: b['index'])
"""block内重排(img和table的block内多个caption或footnote的排序)"""
for block in sorted_blocks:
if block['type'] in [BlockType.IMAGE, BlockType.TABLE]:
block['blocks'] = sorted(block['blocks'], key=lambda b: b['index'])
return sorted_blocks
def get_line_height(blocks):
page_line_height_list = []
for block in blocks:
if block['type'] in [
BlockType.TEXT, BlockType.TITLE,
BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE,
BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE
]:
for line in block['lines']:
bbox = line['bbox']
page_line_height_list.append(int(bbox[3] - bbox[1]))
if len(page_line_height_list) > 0:
return statistics.median(page_line_height_list)
else:
return 10
def sort_lines_by_model(fix_blocks, page_w, page_h, line_height, footnote_blocks):
page_line_list = []
def add_lines_to_block(b):
line_bboxes = insert_lines_into_block(b['bbox'], line_height, page_w, page_h)
b['lines'] = []
for line_bbox in line_bboxes:
b['lines'].append({'bbox': line_bbox, 'spans': []})
page_line_list.extend(line_bboxes)
for block in fix_blocks:
if block['type'] in [
BlockType.TEXT, BlockType.TITLE,
BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE,
BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE
]:
if len(block['lines']) == 0:
add_lines_to_block(block)
elif block['type'] in [BlockType.TITLE] and len(block['lines']) == 1 and (block['bbox'][3] - block['bbox'][1]) > line_height * 2:
block['real_lines'] = copy.deepcopy(block['lines'])
add_lines_to_block(block)
else:
for line in block['lines']:
bbox = line['bbox']
page_line_list.append(bbox)
elif block['type'] in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY, BlockType.INTERLINE_EQUATION]:
block['real_lines'] = copy.deepcopy(block['lines'])
add_lines_to_block(block)
for block in footnote_blocks:
footnote_block = {'bbox': block[:4]}
add_lines_to_block(footnote_block)
if len(page_line_list) > 200:  # layoutreader supports at most 512 lines
return None
# Sort with layoutreader
x_scale = 1000.0 / page_w
y_scale = 1000.0 / page_h
boxes = []
# logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_line_list)}")
for left, top, right, bottom in page_line_list:
if left < 0:
logger.warning(
f'left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
left = 0
if right > page_w:
logger.warning(
f'right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
right = page_w
if top < 0:
logger.warning(
f'top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
top = 0
if bottom > page_h:
logger.warning(
f'bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
bottom = page_h
left = round(left * x_scale)
top = round(top * y_scale)
right = round(right * x_scale)
bottom = round(bottom * y_scale)
assert (
1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
), f'Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}' # noqa: E126, E121
boxes.append([left, top, right, bottom])
model_manager = ModelSingleton()
model = model_manager.get_model('layoutreader')
with torch.no_grad():
orders = do_predict(boxes, model)
sorted_bboxes = [page_line_list[i] for i in orders]
return sorted_bboxes
def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
# block_bbox is a tuple (x0, y0, x1, y1), where (x0, y0) is the lower-left corner and (x1, y1) is the upper-right corner
x0, y0, x1, y1 = block_bbox
block_height = y1 - y0
block_weight = x1 - x0
# If the block is taller than two body-text lines, split it; otherwise return the block's bbox directly
if line_height * 2 < block_height:
if (
block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
):  # probably a two-column structure, so it can be sliced finer
lines = int(block_height / line_height)
else:
# If the block width exceeds 0.4 of the page width, split the block into 3 lines (a complex layout; figures should not be sliced too finely)
if block_weight > page_w * 0.4:
lines = 3
elif block_weight > page_w * 0.25:  # (possibly a three-column structure, also slice finer)
lines = int(block_height / line_height)
else:  # check the aspect ratio
if block_height / block_weight > 1.2:  # tall, narrow blocks are not split
return [[x0, y0, x1, y1]]
else:  # otherwise split into two lines
lines = 2
line_height = (y1 - y0) / lines
# Determine the y position from which to start drawing lines
current_y = y0
# Store the line positions as [(x0, y), ...]
lines_positions = []
for i in range(lines):
lines_positions.append([x0, current_y, x1, current_y + line_height])
current_y += line_height
return lines_positions
else:
return [[x0, y0, x1, y1]]
def model_init(model_name: str):
from transformers import LayoutLMv3ForTokenClassification
device_name = get_device()
bf_16_support = False
if device_name.startswith("cuda"):
bf_16_support = torch.cuda.is_bf16_supported()
elif device_name.startswith("mps"):
bf_16_support = True
device = torch.device(device_name)
if model_name == 'layoutreader':
# 检测modelscope的缓存目录是否存在
layoutreader_model_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.layout_reader), ModelPath.layout_reader)
if os.path.exists(layoutreader_model_dir):
model = LayoutLMv3ForTokenClassification.from_pretrained(
layoutreader_model_dir
)
else:
logger.warning(
'local layoutreader model not exists, use online model from huggingface'
)
model = LayoutLMv3ForTokenClassification.from_pretrained(
'hantian/layoutreader'
)
if bf_16_support:
model.to(device).eval().bfloat16()
else:
model.to(device).eval()
else:
logger.error('model name not allow')
exit(1)
return model
class ModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(self, model_name: str):
if model_name not in self._models:
self._models[model_name] = model_init(model_name=model_name)
return self._models[model_name]
def do_predict(boxes: List[List[int]], model) -> List[int]:
from mineru.model.reading_order.layout_reader import (
boxes2inputs, parse_logits, prepare_inputs)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
inputs = boxes2inputs(boxes)
inputs = prepare_inputs(inputs, model)
logits = model(**inputs).logits.cpu().squeeze(0)
return parse_logits(logits, len(boxes))
def cal_block_index(fix_blocks, sorted_bboxes):
if sorted_bboxes is not None:
# 使用layoutreader排序
for block in fix_blocks:
line_index_list = []
if len(block['lines']) == 0:
block['index'] = sorted_bboxes.index(block['bbox'])
else:
for line in block['lines']:
line['index'] = sorted_bboxes.index(line['bbox'])
line_index_list.append(line['index'])
median_value = statistics.median(line_index_list)
block['index'] = median_value
# 删除图表body block中的虚拟line信息, 并用real_lines信息回填
if block['type'] in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY, BlockType.TITLE, BlockType.INTERLINE_EQUATION]:
if 'real_lines' in block:
block['virtual_lines'] = copy.deepcopy(block['lines'])
block['lines'] = copy.deepcopy(block['real_lines'])
del block['real_lines']
else:
# 使用xycut排序
block_bboxes = []
for block in fix_blocks:
# 如果block['bbox']任意值小于0,将其置为0
block['bbox'] = [max(0, x) for x in block['bbox']]
block_bboxes.append(block['bbox'])
# 删除图表body block中的虚拟line信息, 并用real_lines信息回填
if block['type'] in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY, BlockType.TITLE, BlockType.INTERLINE_EQUATION]:
if 'real_lines' in block:
block['virtual_lines'] = copy.deepcopy(block['lines'])
block['lines'] = copy.deepcopy(block['real_lines'])
del block['real_lines']
import numpy as np
from mineru.model.reading_order.xycut import recursive_xy_cut
random_boxes = np.array(block_bboxes)
np.random.shuffle(random_boxes)
res = []
recursive_xy_cut(np.asarray(random_boxes).astype(int), np.arange(len(block_bboxes)), res)
assert len(res) == len(block_bboxes)
sorted_boxes = random_boxes[np.array(res)].tolist()
for i, block in enumerate(fix_blocks):
block['index'] = sorted_boxes.index(block['bbox'])
# Generate line indexes
sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])
line_index = 1
for block in sorted_blocks:
for line in block['lines']:
line['index'] = line_index
line_index += 1
return fix_blocks
def revert_group_blocks(blocks):
image_groups = {}
table_groups = {}
new_blocks = []
for block in blocks:
if block['type'] in [BlockType.IMAGE_BODY, BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE]:
group_id = block['group_id']
if group_id not in image_groups:
image_groups[group_id] = []
image_groups[group_id].append(block)
elif block['type'] in [BlockType.TABLE_BODY, BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE]:
group_id = block['group_id']
if group_id not in table_groups:
table_groups[group_id] = []
table_groups[group_id].append(block)
else:
new_blocks.append(block)
for group_id, blocks in image_groups.items():
new_blocks.append(process_block_list(blocks, BlockType.IMAGE_BODY, BlockType.IMAGE))
for group_id, blocks in table_groups.items():
new_blocks.append(process_block_list(blocks, BlockType.TABLE_BODY, BlockType.TABLE))
return new_blocks
def process_block_list(blocks, body_type, block_type):
indices = [block['index'] for block in blocks]
median_index = statistics.median(indices)
body_bbox = next((block['bbox'] for block in blocks if block.get('type') == body_type), [])
return {
'type': block_type,
'bbox': body_bbox,
'blocks': blocks,
'index': median_index,
}
\ No newline at end of file
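As a side note, a minimal sketch of the xy-cut fallback used by cal_block_index above; the three boxes are invented for illustration, and the call pattern simply mirrors the one in the function.
import numpy as np
from mineru.model.reading_order.xycut import recursive_xy_cut

boxes = np.array([
    [50, 400, 500, 450],   # a lower block
    [50, 50, 500, 100],    # a top block
    [50, 200, 500, 260],   # a middle block
])
res = []
recursive_xy_cut(boxes.astype(int), np.arange(len(boxes)), res)
reading_order = [boxes[i].tolist() for i in res]  # boxes reordered top-to-bottom by the cut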
import math


def _is_in_or_part_overlap(box1, box2) -> bool:
"""两个bbox是否有部分重叠或者包含."""
if box1 is None or box2 is None:
return False
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
return not (x1_1 < x0_2 or # box1在box2的左边
x0_1 > x1_2 or # box1在box2的右边
y1_1 < y0_2 or # box1在box2的上边
y0_1 > y1_2) # box1在box2的下边
def _is_in_or_part_overlap_with_area_ratio(box1,
box2,
area_ratio_threshold=0.6):
"""判断box1是否在box2里面,或者box1和box2有部分重叠,且重叠面积占box1的比例超过area_ratio_threshold."""
if box1 is None or box2 is None:
return False
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
if not _is_in_or_part_overlap(box1, box2):
return False
# 计算重叠面积
x_left = max(x0_1, x0_2)
y_top = max(y0_1, y0_2)
x_right = min(x1_1, x1_2)
y_bottom = min(y1_1, y1_2)
overlap_area = (x_right - x_left) * (y_bottom - y_top)
# 计算box1的面积
box1_area = (x1_1 - x0_1) * (y1_1 - y0_1)
return overlap_area / box1_area > area_ratio_threshold
def _is_in(box1, box2) -> bool:
"""Whether box1 is completely inside box2."""
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
return (x0_1 >= x0_2 and  # box1's left edge is not outside box2's left edge
y0_1 >= y0_2 and  # box1's top edge is not outside box2's top edge
x1_1 <= x1_2 and  # box1's right edge is not outside box2's right edge
y1_1 <= y1_2)  # box1's bottom edge is not outside box2's bottom edge


def is_in(box1, box2) -> bool:
"""Whether box1 is completely inside box2."""
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
return (
x0_1 >= x0_2  # box1's left edge is not outside box2's left edge
and y0_1 >= y0_2  # box1's top edge is not outside box2's top edge
and x1_1 <= x1_2  # box1's right edge is not outside box2's right edge
and y1_1 <= y1_2
)  # box1's bottom edge is not outside box2's bottom edge
def _is_part_overlap(box1, box2) -> bool:
"""两个bbox是否有部分重叠,但不完全包含."""
if box1 is None or box2 is None:
return False
return _is_in_or_part_overlap(box1, box2) and not _is_in(box1, box2)
def _left_intersect(left_box, right_box):
"""检查两个box的左边界是否有交集,也就是left_box的右边界是否在right_box的左边界内."""
if left_box is None or right_box is None:
return False
x0_1, y0_1, x1_1, y1_1 = left_box
x0_2, y0_2, x1_2, y1_2 = right_box
return x1_1 > x0_2 and x0_1 < x0_2 and (y0_1 <= y0_2 <= y1_1
or y0_1 <= y1_2 <= y1_1)
def _right_intersect(left_box, right_box):
"""Check whether the boxes intersect at the right boundary, i.e. whether left_box's left edge lies inside right_box's right edge."""
if left_box is None or right_box is None:
return False
x0_1, y0_1, x1_1, y1_1 = left_box
x0_2, y0_2, x1_2, y1_2 = right_box
return x0_1 < x1_2 and x1_1 > x1_2 and (y0_1 <= y0_2 <= y1_1
or y0_1 <= y1_2 <= y1_1)
def _is_vertical_full_overlap(box1, box2, x_torlence=2):
"""x方向上:要么box1包含box2, 要么box2包含box1。不能部分包含 y方向上:box1和box2有重叠."""
# 解析box的坐标
x11, y11, x12, y12 = box1 # 左上角和右下角的坐标 (x1, y1, x2, y2)
x21, y21, x22, y22 = box2
# 在x轴方向上,box1是否包含box2 或 box2包含box1
contains_in_x = (x11 - x_torlence <= x21 and x12 + x_torlence >= x22) or (
x21 - x_torlence <= x11 and x22 + x_torlence >= x12)
# 在y轴方向上,box1和box2是否有重叠
overlap_in_y = not (y12 < y21 or y11 > y22)
return contains_in_x and overlap_in_y


def bbox_relative_pos(bbox1, bbox2):
"""Determine the relative position of two rectangles.
Args:
bbox1: a 4-tuple (x1, y1, x1b, y1b) giving the top-left and bottom-right corners of the first rectangle
bbox2: a 4-tuple (x2, y2, x2b, y2b) giving the top-left and bottom-right corners of the second rectangle
Returns:
a 4-tuple (left, right, bottom, top) describing rectangle 1's position relative to rectangle 2,
where left means rectangle 1 is to the left of rectangle 2, right means it is to the right,
bottom means it is below, and top means it is above
"""
x1, y1, x1b, y1b = bbox1
x2, y2, x2b, y2b = bbox2
left = x2b < x1
right = x1b < x2
bottom = y2b < y1
top = y1b < y2
return left, right, bottom, top
def _is_bottom_full_overlap(box1, box2, y_tolerance=2):
"""检查box1下方和box2的上方有轻微的重叠,轻微程度收到y_tolerance的限制 这个函数和_is_vertical-
full_overlap的区别是,这个函数允许box1和box2在x方向上有轻微的重叠,允许一定的模糊度."""
if box1 is None or box2 is None:
return False
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
tolerance_margin = 2
is_xdir_full_overlap = (
(x0_1 - tolerance_margin <= x0_2 <= x1_1 + tolerance_margin
and x0_1 - tolerance_margin <= x1_2 <= x1_1 + tolerance_margin)
or (x0_2 - tolerance_margin <= x0_1 <= x1_2 + tolerance_margin
and x0_2 - tolerance_margin <= x1_1 <= x1_2 + tolerance_margin))
return y0_2 < y1_1 and 0 < (y1_1 -
y0_2) < y_tolerance and is_xdir_full_overlap


def _is_left_overlap(
box1,
box2,
):
"""检查box1的左侧是否和box2有重叠 在Y方向上可以是部分重叠或者是完全重叠。不分box1和box2的上下关系,也就是无论box1在box2下
方还是box2在box1下方,都可以检测到重叠。 X方向上."""
def __overlap_y(Ay1, Ay2, By1, By2):
return max(0, min(Ay2, By2) - max(Ay1, By1))
if box1 is None or box2 is None:
return False
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
y_overlap_len = __overlap_y(y0_1, y1_1, y0_2, y1_2)
ratio_1 = 1.0 * y_overlap_len / (y1_1 - y0_1) if y1_1 - y0_1 != 0 else 0
ratio_2 = 1.0 * y_overlap_len / (y1_2 - y0_2) if y1_2 - y0_2 != 0 else 0
vertical_overlap_cond = ratio_1 >= 0.5 or ratio_2 >= 0.5
# vertical_overlap_cond = y0_1<=y0_2<=y1_1 or y0_1<=y1_2<=y1_1 or y0_2<=y0_1<=y1_2 or y0_2<=y1_1<=y1_2
return x0_1 <= x0_2 <= x1_1 and vertical_overlap_cond


def bbox_distance(bbox1, bbox2):
"""Compute the distance between two rectangles.
Args:
bbox1 (tuple): coordinates of the first rectangle as (x1, y1, x2, y2), where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner.
bbox2 (tuple): coordinates of the second rectangle as (x1, y1, x2, y2), where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner.
Returns:
float: the distance between the rectangles.
"""
def dist(point1, point2):
return math.sqrt((point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2)
x1, y1, x1b, y1b = bbox1
x2, y2, x2b, y2b = bbox2
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
if top and left:
return dist((x1, y1b), (x2b, y2))
elif left and bottom:
return dist((x1, y1), (x2b, y2b))
elif bottom and right:
return dist((x1b, y1), (x2, y2b))
elif right and top:
return dist((x1b, y1b), (x2, y2))
elif left:
return x1 - x2b
elif right:
return x2 - x1b
elif bottom:
return y1 - y2b
elif top:
return y2 - y1b
return 0.0


def get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio):
"""通过calculate_overlap_area_2_minbox_area_ratio计算两个bbox重叠的面积占最小面积的box的比例
如果比例大于ratio,则返回小的那个bbox, 否则返回None."""
x1_min, y1_min, x1_max, y1_max = bbox1
x2_min, y2_min, x2_max, y2_max = bbox2
area1 = (x1_max - x1_min) * (y1_max - y1_min)
area2 = (x2_max - x2_min) * (y2_max - y2_min)
overlap_ratio = calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2)
if overlap_ratio > ratio:
if area1 <= area2:
return bbox1
else:
return bbox2
else:
return None
def __is_overlaps_y_exceeds_threshold(bbox1,
bbox2,
overlap_ratio_threshold=0.8):
"""Check whether two bboxes overlap on the y-axis and whether the overlap height exceeds 80% of the shorter bbox's height."""
_, y0_1, _, y1_1 = bbox1
_, y0_2, _, y1_2 = bbox2
overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2))
height1, height2 = y1_1 - y0_1, y1_2 - y0_2
# max_height = max(height1, height2)
min_height = min(height1, height2)
return (overlap / min_height) > overlap_ratio_threshold


def calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2):
"""Compute the ratio of the overlap area of box1 and box2 to the area of the smaller box."""
# Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0])
y_top = max(bbox1[1], bbox2[1])
x_right = min(bbox1[2], bbox2[2])
y_bottom = min(bbox1[3], bbox2[3])
if x_right < x_left or y_bottom < y_top:
return 0.0
# The area of overlap area
intersection_area = (x_right - x_left) * (y_bottom - y_top)
min_box_area = min([(bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]),
(bbox2[3] - bbox2[1]) * (bbox2[2] - bbox2[0])])
if min_box_area == 0:
return 0
else:
return intersection_area / min_box_area
def calculate_iou(bbox1, bbox2):
@@ -195,27 +148,6 @@ def calculate_iou(bbox1, bbox2):
return iou
def calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2):
"""计算box1和box2的重叠面积占最小面积的box的比例."""
# Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0])
y_top = max(bbox1[1], bbox2[1])
x_right = min(bbox1[2], bbox2[2])
y_bottom = min(bbox1[3], bbox2[3])
if x_right < x_left or y_bottom < y_top:
return 0.0
# The area of overlap area
intersection_area = (x_right - x_left) * (y_bottom - y_top)
min_box_area = min([(bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]),
(bbox2[3] - bbox2[1]) * (bbox2[2] - bbox2[0])])
if min_box_area == 0:
return 0
else:
return intersection_area / min_box_area
def calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2):
"""Compute the ratio of the overlap area of box1 and box2 to the area of bbox1."""
# Determine the coordinates of the intersection rectangle
@@ -236,220 +168,6 @@ def calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2):
return intersection_area / bbox1_area
def get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio):
"""通过calculate_overlap_area_2_minbox_area_ratio计算两个bbox重叠的面积占最小面积的box的比例
如果比例大于ratio,则返回小的那个bbox, 否则返回None."""
x1_min, y1_min, x1_max, y1_max = bbox1
x2_min, y2_min, x2_max, y2_max = bbox2
area1 = (x1_max - x1_min) * (y1_max - y1_min)
area2 = (x2_max - x2_min) * (y2_max - y2_min)
overlap_ratio = calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2)
if overlap_ratio > ratio:
if area1 <= area2:
return bbox1
else:
return bbox2
else:
return None
def get_bbox_in_boundary(bboxes: list, boundary: tuple) -> list:
x0, y0, x1, y1 = boundary
new_boxes = [
box for box in bboxes
if box[0] >= x0 and box[1] >= y0 and box[2] <= x1 and box[3] <= y1
]
return new_boxes
def is_vbox_on_side(bbox, width, height, side_threshold=0.2):
"""判断一个bbox是否在pdf页面的边缘."""
x0, x1 = bbox[0], bbox[2]
if x1 <= width * side_threshold or x0 >= width * (1 - side_threshold):
return True
return False
def find_top_nearest_text_bbox(pymu_blocks, obj_bbox):
tolerance_margin = 4
top_boxes = [
box for box in pymu_blocks
if obj_bbox[1] - box['bbox'][3] >= -tolerance_margin
and not _is_in(box['bbox'], obj_bbox)
]
# 然后找到X方向上有互相重叠的
top_boxes = [
box for box in top_boxes if any([
obj_bbox[0] - tolerance_margin <= box['bbox'][0] <= obj_bbox[2] +
tolerance_margin, obj_bbox[0] -
tolerance_margin <= box['bbox'][2] <= obj_bbox[2] +
tolerance_margin, box['bbox'][0] -
tolerance_margin <= obj_bbox[0] <= box['bbox'][2] +
tolerance_margin, box['bbox'][0] -
tolerance_margin <= obj_bbox[2] <= box['bbox'][2] +
tolerance_margin
])
]
# 然后找到y1最大的那个
if len(top_boxes) > 0:
top_boxes.sort(key=lambda x: x['bbox'][3], reverse=True)
return top_boxes[0]
else:
return None
def find_bottom_nearest_text_bbox(pymu_blocks, obj_bbox):
bottom_boxes = [
box for box in pymu_blocks if box['bbox'][1] -
obj_bbox[3] >= -2 and not _is_in(box['bbox'], obj_bbox)
]
# 然后找到X方向上有互相重叠的
bottom_boxes = [
box for box in bottom_boxes if any([
obj_bbox[0] - 2 <= box['bbox'][0] <= obj_bbox[2] + 2, obj_bbox[0] -
2 <= box['bbox'][2] <= obj_bbox[2] + 2, box['bbox'][0] -
2 <= obj_bbox[0] <= box['bbox'][2] + 2, box['bbox'][0] -
2 <= obj_bbox[2] <= box['bbox'][2] + 2
])
]
# 然后找到y0最小的那个
if len(bottom_boxes) > 0:
bottom_boxes.sort(key=lambda x: x['bbox'][1], reverse=False)
return bottom_boxes[0]
else:
return None
def find_left_nearest_text_bbox(pymu_blocks, obj_bbox):
"""寻找左侧最近的文本block."""
left_boxes = [
box for box in pymu_blocks if obj_bbox[0] -
box['bbox'][2] >= -2 and not _is_in(box['bbox'], obj_bbox)
]
# 然后找到X方向上有互相重叠的
left_boxes = [
box for box in left_boxes if any([
obj_bbox[1] - 2 <= box['bbox'][1] <= obj_bbox[3] + 2, obj_bbox[1] -
2 <= box['bbox'][3] <= obj_bbox[3] + 2, box['bbox'][1] -
2 <= obj_bbox[1] <= box['bbox'][3] + 2, box['bbox'][1] -
2 <= obj_bbox[3] <= box['bbox'][3] + 2
])
]
# 然后找到x1最大的那个
if len(left_boxes) > 0:
left_boxes.sort(key=lambda x: x['bbox'][2], reverse=True)
return left_boxes[0]
else:
return None
def find_right_nearest_text_bbox(pymu_blocks, obj_bbox):
"""寻找右侧最近的文本block."""
right_boxes = [
box for box in pymu_blocks if box['bbox'][0] -
obj_bbox[2] >= -2 and not _is_in(box['bbox'], obj_bbox)
]
# 然后找到X方向上有互相重叠的
right_boxes = [
box for box in right_boxes if any([
obj_bbox[1] - 2 <= box['bbox'][1] <= obj_bbox[3] + 2, obj_bbox[1] -
2 <= box['bbox'][3] <= obj_bbox[3] + 2, box['bbox'][1] -
2 <= obj_bbox[1] <= box['bbox'][3] + 2, box['bbox'][1] -
2 <= obj_bbox[3] <= box['bbox'][3] + 2
])
]
# 然后找到x0最小的那个
if len(right_boxes) > 0:
right_boxes.sort(key=lambda x: x['bbox'][0], reverse=False)
return right_boxes[0]
else:
return None
def bbox_relative_pos(bbox1, bbox2):
"""判断两个矩形框的相对位置关系.
Args:
bbox1: 一个四元组,表示第一个矩形框的左上角和右下角的坐标,格式为(x1, y1, x1b, y1b)
bbox2: 一个四元组,表示第二个矩形框的左上角和右下角的坐标,格式为(x2, y2, x2b, y2b)
Returns:
一个四元组,表示矩形框1相对于矩形框2的位置关系,格式为(left, right, bottom, top)
其中,left表示矩形框1是否在矩形框2的左侧,right表示矩形框1是否在矩形框2的右侧,
bottom表示矩形框1是否在矩形框2的下方,top表示矩形框1是否在矩形框2的上方
"""
x1, y1, x1b, y1b = bbox1
x2, y2, x2b, y2b = bbox2
left = x2b < x1
right = x1b < x2
bottom = y2b < y1
top = y1b < y2
return left, right, bottom, top
def bbox_distance(bbox1, bbox2):
"""计算两个矩形框的距离。
Args:
bbox1 (tuple): 第一个矩形框的坐标,格式为 (x1, y1, x2, y2),其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。
bbox2 (tuple): 第二个矩形框的坐标,格式为 (x1, y1, x2, y2),其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。
Returns:
float: 矩形框之间的距离。
"""
def dist(point1, point2):
return math.sqrt((point1[0] - point2[0])**2 +
(point1[1] - point2[1])**2)
x1, y1, x1b, y1b = bbox1
x2, y2, x2b, y2b = bbox2
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
if top and left:
return dist((x1, y1b), (x2b, y2))
elif left and bottom:
return dist((x1, y1), (x2b, y2b))
elif bottom and right:
return dist((x1b, y1), (x2, y2b))
elif right and top:
return dist((x1b, y1b), (x2, y2))
elif left:
return x1 - x2b
elif right:
return x2 - x1b
elif bottom:
return y1 - y2b
elif top:
return y2 - y1b
return 0.0
def box_area(bbox):
return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
def get_overlap_area(bbox1, bbox2):
"""计算box1和box2的重叠面积占bbox1的比例."""
# Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0])
y_top = max(bbox1[1], bbox2[1])
x_right = min(bbox1[2], bbox2[2])
y_bottom = min(bbox1[3], bbox2[3])
if x_right < x_left or y_bottom < y_top:
return 0.0
# The area of overlap area
return (x_right - x_left) * (y_bottom - y_top)
def calculate_vertical_projection_overlap_ratio(block1, block2):
"""
Calculate the proportion of the x-axis covered by the vertical projection of two blocks.
@@ -482,4 +200,4 @@ def calculate_vertical_projection_overlap_ratio(block1, block2):
# Proportion of the x-axis covered by the intersection
# logger.info(f"intersection_length: {intersection_length}, block1_length: {block1_length}")
return intersection_length / block1_length
\ No newline at end of file
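For orientation, a minimal sketch of how the surviving boxbase helpers compose. The boxes are made-up (x0, y0, x1, y1) values and the expected results in the comments are hand-computed under that assumption.
from mineru.utils.boxbase import calculate_iou, bbox_distance, get_minbox_if_overlap_by_ratio

box_a = (0, 0, 100, 100)
box_b = (50, 50, 150, 150)

iou = calculate_iou(box_a, box_b)                             # ~0.143: 2500 overlap / 17500 union
dist = bbox_distance(box_a, box_b)                            # 0.0, because the boxes overlap
smaller = get_minbox_if_overlap_by_ratio(box_a, box_b, 0.2)   # (0, 0, 100, 100): overlap covers 25% of the smaller box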
"""根据bucket的名字返回对应的s3 AK, SK,endpoint三元组.""" # Copyright (c) Opendatalab. All rights reserved.
import json import json
import os import os
import torch
from loguru import logger from loguru import logger
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.libs.commons import parse_bucket_key
# 定义配置文件名常量 # 定义配置文件名常量
CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'magic-pdf.json') CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'mineru.json')
def read_config(): def read_config():
...@@ -20,11 +17,12 @@ def read_config(): ...@@ -20,11 +17,12 @@ def read_config():
config_file = os.path.join(home_dir, CONFIG_FILE_NAME) config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
if not os.path.exists(config_file): if not os.path.exists(config_file):
raise FileNotFoundError(f'{config_file} not found') logger.warning(f'{config_file} not found, using default configuration')
return None
with open(config_file, 'r', encoding='utf-8') as f: else:
config = json.load(f) with open(config_file, 'r', encoding='utf-8') as f:
return config config = json.load(f)
return config
def get_s3_config(bucket_name: str): def get_s3_config(bucket_name: str):
...@@ -55,85 +53,73 @@ def get_bucket_name(path): ...@@ -55,85 +53,73 @@ def get_bucket_name(path):
return bucket return bucket
def get_local_models_dir(): def parse_bucket_key(s3_full_path: str):
config = read_config() """
models_dir = config.get('models-dir') 输入 s3://bucket/path/to/my/file.txt
if models_dir is None: 输出 bucket, path/to/my/file.txt
logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default") """
return '/tmp/models' s3_full_path = s3_full_path.strip()
else: if s3_full_path.startswith("s3://"):
return models_dir s3_full_path = s3_full_path[5:]
if s3_full_path.startswith("/"):
s3_full_path = s3_full_path[1:]
def get_local_layoutreader_model_dir(): bucket, key = s3_full_path.split("/", 1)
config = read_config() return bucket, key
layoutreader_model_dir = config.get('layoutreader-model-dir')
if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
home_dir = os.path.expanduser('~')
layoutreader_at_modelscope_dir_path = os.path.join(home_dir, '.cache/modelscope/hub/ppaanngggg/layoutreader')
logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
return layoutreader_at_modelscope_dir_path
else:
return layoutreader_model_dir
def get_device(): def get_device():
config = read_config() device_mode = os.getenv('MINERU_DEVICE_MODE', None)
device = config.get('device-mode') if device_mode is not None:
if device is None: return device_mode
logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
return 'cpu'
else: else:
return device if torch.cuda.is_available():
return "cuda"
if torch.backends.mps.is_available():
return "mps"
return "cpu"
def get_table_recog_config(): def get_table_recog_config():
config = read_config() table_enable = os.getenv('MINERU_TABLE_ENABLE', None)
table_config = config.get('table-config') if table_enable is not None:
if table_config is None: return json.loads(f'{{"enable": {table_enable}}}')
logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
return json.loads(f'{{"model": "{MODEL_NAME.RAPID_TABLE}","enable": false, "max_time": 400}}')
else: else:
return table_config # logger.warning(f"not found 'MINERU_TABLE_ENABLE' in environment variable, use 'true' as default.")
return json.loads(f'{{"enable": true}}')
def get_layout_config(): def get_formula_config():
config = read_config() formula_enable = os.getenv('MINERU_FORMULA_ENABLE', None)
layout_config = config.get('layout-config') if formula_enable is not None:
if layout_config is None: return json.loads(f'{{"enable": {formula_enable}}}')
logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
return json.loads(f'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}')
else: else:
return layout_config # logger.warning(f"not found 'MINERU_FORMULA_ENABLE' in environment variable, use 'true' as default.")
return json.loads(f'{{"enable": true}}')
def get_formula_config(): def get_latex_delimiter_config():
config = read_config() config = read_config()
formula_config = config.get('formula-config') latex_delimiter_config = config.get('latex-delimiter-config')
if formula_config is None: if latex_delimiter_config is None:
logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default") # logger.warning(f"'latex-delimiter-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
return json.loads(f'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}') return None
else: else:
return formula_config return latex_delimiter_config
def get_llm_aided_config(): def get_llm_aided_config():
config = read_config() config = read_config()
llm_aided_config = config.get('llm-aided-config') llm_aided_config = config.get('llm-aided-config')
if llm_aided_config is None: if llm_aided_config is None:
logger.warning(f"'llm-aided-config' not found in {CONFIG_FILE_NAME}, use 'None' as default") # logger.warning(f"'llm-aided-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
return None return None
else: else:
return llm_aided_config return llm_aided_config
def get_latex_delimiter_config():
config = read_config()
latex_delimiter_config = config.get('latex-delimiter-config')
if latex_delimiter_config is None:
logger.warning(f"'latex-delimiter-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
return None
else:
return latex_delimiter_config
if __name__ == '__main__': def get_local_models_dir():
ak, sk, endpoint = get_s3_config('llm-raw') config = read_config()
models_dir = config.get('models-dir')
if models_dir is None:
logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use None as default")
return models_dir
\ No newline at end of file
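For reference, a small illustrative sketch of how the rewritten config helpers above are driven by environment variables; the values are assumptions, and the import path mirrors the `from mineru.utils.config_reader import get_device` usage that appears earlier in this diff.
import os

os.environ['MINERU_DEVICE_MODE'] = 'cpu'      # force CPU even if CUDA/MPS is available (illustrative)
os.environ['MINERU_TABLE_ENABLE'] = 'false'   # must be valid JSON ('true'/'false'), since it is spliced into json.loads

from mineru.utils.config_reader import get_device, get_table_recog_config

print(get_device())              # 'cpu'
print(get_table_recog_config())  # {'enable': False}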
from loguru import logger
from .pdf_image_tools import cut_image
def cut_image_and_table(span, page_pil_img, page_img_md5, page_id, image_writer, scale=2):
def return_path(path_type):
return f"{path_type}/{page_img_md5}"
span_type = span["type"]
if not check_img_bbox(span["bbox"]) or not image_writer:
span["image_path"] = ""
else:
span["image_path"] = cut_image(
span["bbox"], page_id, page_pil_img, return_path=return_path(span_type), image_writer=image_writer, scale=scale
)
return span
def check_img_bbox(bbox) -> bool:
if any([bbox[0] >= bbox[2], bbox[1] >= bbox[3]]):
logger.warning(f"image_bboxes: 错误的box, {bbox}")
return False
return True