Unverified commit 0c7a0882, authored by Xiaomeng Zhao, committed by GitHub

Merge pull request #2611 from myhloli/dev

Dev
parents 3bd0ecf1 a392f445
# Copyright (c) Opendatalab. All rights reserved.
import os
from pathlib import Path
import cv2
import numpy as np
import torch
from loguru import logger
from rapid_table import RapidTable, RapidTableInput
from rapid_table.main import ModelType
from magic_pdf.libs.config_reader import get_device
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
class RapidTableModel(object):
def __init__(self, ocr_engine, table_sub_model_name='slanet_plus'):
sub_model_list = [model.value for model in ModelType]
if table_sub_model_name is None:
input_args = RapidTableInput()
elif table_sub_model_name in sub_model_list:
if torch.cuda.is_available() and table_sub_model_name == "unitable":
input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True, device=get_device())
else:
root_dir = Path(__file__).absolute().parent.parent.parent.parent.parent
slanet_plus_model_path = os.path.join(root_dir, 'resources', 'slanet_plus', 'slanet-plus.onnx')
input_args = RapidTableInput(model_type=table_sub_model_name, model_path=slanet_plus_model_path)
else:
raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. It must be one of {sub_model_list}")
def __init__(self, ocr_engine):
slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
self.table_model = RapidTable(input_args)
# self.ocr_model_name = "RapidOCR"
# if torch.cuda.is_available():
# from rapidocr_paddle import RapidOCR
# self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
# else:
# from rapidocr_onnxruntime import RapidOCR
# self.ocr_engine = RapidOCR()
# self.ocr_model_name = "PaddleOCR"
self.ocr_engine = ocr_engine
......
from transformers import AutoConfig, AutoImageProcessor, AutoModelForCausalLM
from .configuration_mineru2 import Mineru2QwenConfig
from .image_processing_mineru2 import Mineru2ImageProcessor
from .modeling_mineru2 import Mineru2QwenForCausalLM
AutoConfig.register(Mineru2QwenConfig.model_type, Mineru2QwenConfig)
AutoModelForCausalLM.register(Mineru2QwenConfig, Mineru2QwenForCausalLM)
AutoImageProcessor.register(Mineru2QwenConfig, slow_image_processor_class=Mineru2ImageProcessor)
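# A minimal loading sketch (hedged): once the three registrations above have run, a checkpoint whose
# config declares model_type "mineru2_qwen" can be loaded through the standard Auto* APIs. The path
# below is a placeholder, not a real checkpoint.
#
#   config = AutoConfig.from_pretrained("path/to/mineru2-checkpoint")
#   model = AutoModelForCausalLM.from_pretrained("path/to/mineru2-checkpoint", config=config)
#   image_processor = AutoImageProcessor.from_pretrained("path/to/mineru2-checkpoint")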
from transformers import Qwen2Config
class Mineru2QwenConfig(Qwen2Config):
model_type = "mineru2_qwen"
def __init__(
self,
ignore_index=-100,
image_aspect_ratio="square_anyres_max_9",
image_grid_pinpoints="(1x1),...,(4x4)",
image_token_index=151646,
mm_hidden_size=1152,
mm_patch_merge_type="spatial_unpad",
mm_projector_type="mlp2x_gelu",
mm_vision_select_feature="full",
mm_vision_select_layer=-2,
mm_vision_tower="google/siglip-so400m-patch14-384",
tie_word_embeddings=False,
tokenizer_model_max_length=16384,
tokenizer_padding_side="right",
unfreeze_mm_vision_tower=True,
**kwargs,
):
self.ignore_index = ignore_index
self.image_aspect_ratio = image_aspect_ratio
self.image_grid_pinpoints = image_grid_pinpoints
self.image_token_index = image_token_index
self.mm_hidden_size = mm_hidden_size
self.mm_patch_merge_type = mm_patch_merge_type
self.mm_projector_type = mm_projector_type
self.mm_vision_select_feature = mm_vision_select_feature
self.mm_vision_select_layer = mm_vision_select_layer
self.mm_vision_tower = mm_vision_tower
self.tokenizer_model_max_length = tokenizer_model_max_length
self.tokenizer_padding_side = tokenizer_padding_side
self.unfreeze_mm_vision_tower = unfreeze_mm_vision_tower
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
import ast
import math
import re
from functools import partial, reduce
from typing import Dict, Optional, Union
import numpy as np
import torch
from PIL import Image
from transformers.image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_size_dict,
)
from transformers.image_transforms import (
convert_to_rgb,
normalize,
rescale,
resize,
to_channel_dimension_format,
)
from transformers.image_utils import (
ChannelDimension,
PILImageResampling,
to_numpy_array,
)
from transformers.utils import TensorType
def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
original_width, original_height = original_size
best_fit = (0, 0)
max_effective_resolution = 0
min_wasted_resolution = float("inf")
for width, height in possible_resolutions:
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
wasted_resolution = (width * height) - effective_resolution
if effective_resolution > max_effective_resolution or (
effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (width, height)
return best_fit
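# Worked example (illustrative values): for an 800x600 input and candidate resolutions
# [(384, 384), (768, 384), (768, 768)], the effective resolutions are 110592, 196608 and 442368
# pixels respectively, so select_best_resolution returns (768, 768).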
def divide_to_patches(image, patch_size):
patches = []
width, height = image.size
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
box = (j, i, j + patch_size, i + patch_size)
patch = image.crop(box)
patches.append(patch)
return patches
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
if pil_img.mode == "L":
pil_img = pil_img.convert("RGB")
if width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
grid_pinpoints = [
(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
]
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints) # type: ignore
width, height = select_best_resolution(image_size, possible_resolutions)
return width // patch_size, height // patch_size
# This function is not used.
def resize_and_pad_image(image, target_resolution):
original_width, original_height = image.size
target_width, target_height = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
# Resize the image
resized_image = image.resize((new_width, new_height))
new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
new_image.paste(resized_image, (paste_x, paste_y))
return new_image
# DIFFERENT from sglang.srt.mm_utils.process_anyres_image
def process_anyres_image(image, processor, grid_pinpoints):
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
patch_size = processor.crop_size["height"]
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
grid_pinpoints = [
(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
]
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints) # type: ignore
best_resolution = select_best_resolution(image.size, possible_resolutions)
# image_padded = resize_and_pad_image(image, best_resolution)
image_padded = image.resize(best_resolution)
patches = divide_to_patches(image_padded, processor.crop_size["height"])
image_original_resize = image.resize((processor.crop_size["height"], processor.crop_size["height"]))
image_patches = [image_original_resize] + patches
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
return torch.stack(image_patches, dim=0)
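# Worked example (illustrative): for an 800x600 image with crop_size 384 and grid_pinpoints
# "(1x1),...,(4x4)", the best-fit resolution is (1152, 768); the resized image yields 3x2 = 6
# patches which, together with the 384x384 global resize, give a tensor of shape (7, 3, 384, 384).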
def process_images(images, image_processor, model_cfg):
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", "")
new_images = []
if image_aspect_ratio == "pad":
for image in images:
image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
new_images.append(image)
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
for image in images:
image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
new_images.append(image)
else:
return image_processor(images, return_tensors="pt")["pixel_values"]
if all(x.shape == new_images[0].shape for x in new_images):
new_images = torch.stack(new_images, dim=0)
return new_images
class Mineru2ImageProcessor(BaseImageProcessor):
model_input_names = ["pixel_values"]
def __init__(
self,
image_mean=(0.5, 0.5, 0.5),
image_std=(0.5, 0.5, 0.5),
size=(384, 384),
crop_size: Optional[Dict[str, int]] = None,
resample=PILImageResampling.BICUBIC,
rescale_factor=1 / 255,
data_format=ChannelDimension.FIRST,
image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[list] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
self.image_mean = image_mean
self.image_std = image_std
self.size = size
self.resample = resample
self.rescale_factor = rescale_factor
self.data_format = data_format
self.crop_size = crop_size
self.image_aspect_ratio = image_aspect_ratio
self.image_grid_pinpoints = image_grid_pinpoints
self.in_e2e_processing = False
def _preprocess(self, images):
if isinstance(images, Image.Image):
images = [images]
else:
# to adapt video data
images = [to_numpy_array(image) for image in images]
assert isinstance(images, list)
transforms = [
convert_to_rgb,
to_numpy_array,
partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
]
images = reduce(lambda x, f: [*map(f, x)], transforms, images)
return {"pixel_values": images}
def _preprocess_end_to_end(self, images):
image_aspect_ratio = self.image_aspect_ratio
image_grid_pinpoints = self.image_grid_pinpoints
assert image_aspect_ratio is not None
assert image_grid_pinpoints is not None
pixel_values = []
if image_aspect_ratio == "pad":
for image in images:
image = expand2square(image, tuple(int(x * 255) for x in self.image_mean))
image = self._preprocess(image)["pixel_values"][0]
pixel_values.append(image)
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
for image in images:
image = process_anyres_image(image, self, self.image_grid_pinpoints)
pixel_values.append(image.numpy())
else:
pixel_values = self._preprocess(images)["pixel_values"]
if isinstance(pixel_values, list) and all(x.shape == pixel_values[0].shape for x in pixel_values):
pixel_values = np.stack(pixel_values, axis=0)
# CAUTION: (height, width) order is used here.
image_sizes = [(image.height, image.width) for image in images]
assert len(pixel_values) == len(image_sizes)
return {"pixel_values": pixel_values, "image_sizes": image_sizes}
def preprocess(
self,
images,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
):
if self.image_aspect_ratio is None or self.in_e2e_processing:
data = self._preprocess(images)
else:
assert self.image_grid_pinpoints is not None
self.in_e2e_processing = True
try:
data = self._preprocess_end_to_end(images)
finally:
self.in_e2e_processing = False
return BatchFeature(data=data, tensor_type=return_tensors)
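# A minimal usage sketch (hedged; the aspect-ratio and grid values mirror Mineru2QwenConfig's
# defaults, and "page.png" is a placeholder image path):
#
#   processor = Mineru2ImageProcessor(
#       image_aspect_ratio="square_anyres_max_9",
#       image_grid_pinpoints="(1x1),...,(4x4)",
#   )
#   batch = processor.preprocess([Image.open("page.png").convert("RGB")], return_tensors="pt")
#   # batch["pixel_values"]: stacked anyres patches; batch["image_sizes"]: (height, width) per image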
import math
import re
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import (
Qwen2ForCausalLM,
Qwen2Model,
SiglipVisionConfig,
SiglipVisionModel,
)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_mineru2 import Mineru2QwenConfig
from .image_processing_mineru2 import Mineru2ImageProcessor, get_anyres_image_grid_shape
class SiglipVisionTower(nn.Module):
def __init__(self, vision_tower):
super().__init__()
self.config = SiglipVisionConfig.from_pretrained(vision_tower)
assert isinstance(self.config, SiglipVisionConfig)
self.config.num_hidden_layers -= 1 # drop the last hidden layer
self.config.vision_use_head = False
self.vision_tower = SiglipVisionModel(self.config)
self.vision_tower.requires_grad_(False)
self.image_processor = Mineru2ImageProcessor()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(
image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True
)
image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
for p in self.vision_tower.parameters():
return p.dtype
@property
def device(self):
for p in self.vision_tower.parameters():
return p.device
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
@property
def num_patches_per_side(self):
return self.config.image_size // self.config.patch_size
@property
def image_size(self):
return self.config.image_size
def build_vision_tower(config: Mineru2QwenConfig):
vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
model_path = getattr(config, "_name_or_path", "")
if "siglip" in vision_tower.lower():
if model_path:
return SiglipVisionTower(f"{model_path}/{vision_tower}")
else:
return SiglipVisionTower(vision_tower)
raise ValueError(f"Unknown vision tower: {vision_tower}")
def build_vision_projector(config: Mineru2QwenConfig):
projector_type = getattr(config, "mm_projector_type", "linear")
if projector_type == "linear":
return nn.Linear(config.mm_hidden_size, config.hidden_size)
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU()) # type: ignore
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
return nn.Sequential(*modules)
if projector_type == "identity":
return nn.Identity()
raise ValueError(f"Unknown projector type: {projector_type}")
class Mineru2QwenModel(Qwen2Model):
config_class = Mineru2QwenConfig
def __init__(self, config: Mineru2QwenConfig):
super(Mineru2QwenModel, self).__init__(config)
self.vision_tower = build_vision_tower(config)
self.mm_projector = build_vision_projector(config)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
class Mineru2QwenForCausalLM(Qwen2ForCausalLM):
config_class = Mineru2QwenConfig
def __init__(self, config: Mineru2QwenConfig):
super(Qwen2ForCausalLM, self).__init__(config)
config.rope_scaling = None
self.model = Mineru2QwenModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.ignore_index = config.ignore_index
self.image_token_index = config.image_token_index
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def encode_images(self, images: torch.Tensor):
image_features = self.get_model().vision_tower(images)
image_features = self.get_model().mm_projector(image_features)
return image_features
def prepare_inputs_labels_for_multimodal(
self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
):
vision_tower = self.get_model().vision_tower
if vision_tower is None or images is None or input_ids.shape[1] == 1:
return input_ids, position_ids, attention_mask, past_key_values, None, labels
if type(images) is list or images.ndim == 5:
if type(images) is list:
images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
concat_images = torch.cat([image for image in images], dim=0)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in images]
image_features = torch.split(image_features, split_sizes, dim=0)
mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
if mm_patch_merge_type == "flat":
image_features = [x.flatten(0, 1) for x in image_features]
elif mm_patch_merge_type.startswith("spatial"):
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.get_model().vision_tower.num_patches_per_side
assert height * width == base_image_feature.shape[0]
if "anyres_max" in image_aspect_ratio:
matched_anyres_max_num_patches = re.match(r"square_anyres_max_(\d+)", image_aspect_ratio)
if matched_anyres_max_num_patches:
max_num_patches = int(matched_anyres_max_num_patches.group(1))
if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.get_model().vision_tower.config.image_size,
)
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
else:
raise NotImplementedError
if (
"unpad" in mm_patch_merge_type
and "anyres_max" in image_aspect_ratio
and matched_anyres_max_num_patches
):
unit = image_feature.shape[2]
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
c, h, w = image_feature.shape
times = math.sqrt(h * w / (max_num_patches * unit**2))
if times > 1.1:
image_feature = image_feature[None]
image_feature = nn.functional.interpolate(
image_feature, [int(h // times), int(w // times)], mode="bilinear"
)[0]
image_feature = torch.cat(
(
image_feature,
self.model.image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
elif "unpad" in mm_patch_merge_type:
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = torch.cat(
(
image_feature,
self.model.image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
else:
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
image_feature = image_feature.flatten(0, 3)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
if "unpad" in mm_patch_merge_type:
image_feature = torch.cat(
(image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0
)
new_image_features.append(image_feature)
image_features = new_image_features
else:
raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
else:
image_features = self.encode_images(images)
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
if labels is None:
labels = torch.full_like(input_ids, self.ignore_index)
# remove the padding using attention_mask -- FIXME
_input_ids = input_ids
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
new_input_embeds = []
new_labels = []
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
num_images = (cur_input_ids == self.image_token_index).sum()
if num_images == 0:
cur_image_features = image_features[cur_image_idx]
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
cur_image_idx += 1
continue
image_token_indices = (
[-1] + torch.where(cur_input_ids == self.image_token_index)[0].tolist() + [cur_input_ids.shape[0]]
)
cur_input_ids_noim = []
cur_labels = labels[batch_idx]
cur_labels_noim = []
for i in range(len(image_token_indices) - 1):
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
split_sizes = [x.shape[0] for x in cur_labels_noim]
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []
for i in range(num_images + 1):
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
cur_new_labels.append(cur_labels_noim[i])
if i < num_images:
cur_image_features = image_features[cur_image_idx]
cur_image_idx += 1
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(
torch.full(
(cur_image_features.shape[0],), self.ignore_index, device=cur_labels.device, dtype=cur_labels.dtype
)
)
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
# Truncate sequences to max length as image embeddings can make the sequence longer
tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
if tokenizer_model_max_length is not None:
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
# Combine them
max_len = max(x.shape[0] for x in new_input_embeds)
batch_size = len(new_input_embeds)
new_input_embeds_padded = []
new_labels_padded = torch.full(
(batch_size, max_len), self.ignore_index, dtype=new_labels[0].dtype, device=new_labels[0].device
)
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
cur_len = cur_new_embed.shape[0]
if getattr(self.config, "tokenizer_padding_side", "right") == "left":
new_input_embeds_padded.append(
torch.cat(
(
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
cur_new_embed,
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
else:
new_input_embeds_padded.append(
torch.cat(
(
cur_new_embed,
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded
if _attention_mask is None:
attention_mask = None
else:
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
if _position_ids is None:
position_ids = None
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
image_sizes: Optional[List[List[int]]] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
if inputs_embeds is None:
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = (
self.prepare_inputs_labels_for_multimodal(
input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes
)
)
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_sizes: Optional[List[List[int]]] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
inputs, position_ids, attention_mask, _, inputs_embeds, _ = self.prepare_inputs_labels_for_multimodal(
inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes
)
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = super().prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
)
if images is not None:
inputs["images"] = images
if image_sizes is not None:
inputs["image_sizes"] = image_sizes
return inputs
from sglang.srt.configs.model_config import multimodal_model_archs
from sglang.srt.models.registry import ModelRegistry
try:
# sglang==0.4.5.post3
from sglang.srt.managers.multimodal_processor import (
PROCESSOR_MAPPING as PROCESSOR_MAPPING,
)
except ImportError:
# sglang==0.4.4.post1
from sglang.srt.managers.image_processor import (
IMAGE_PROCESSOR_MAPPING as PROCESSOR_MAPPING,
)
from .. import vlm_hf_model as _
from .image_processor import Mineru2ImageProcessor
from .model import Mineru2QwenForCausalLM
ModelRegistry.models[Mineru2QwenForCausalLM.__name__] = Mineru2QwenForCausalLM
PROCESSOR_MAPPING[Mineru2QwenForCausalLM] = Mineru2ImageProcessor
multimodal_model_archs.append(Mineru2QwenForCausalLM.__name__)
import asyncio
import time
from types import MethodType
from typing import AsyncIterator, Dict, Iterator, List, Optional, Union
import fastapi
from sglang.srt.entrypoints.engine import Engine as _Engine
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.managers.tokenizer_manager import (
TokenizerManager,
dataclass_to_string_truncated,
logger,
)
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from ...utils.run_async import run_async
from .logit_processor import Mineru2LogitProcessor
class BatchEngine(_Engine):
"""
The engine is patched to support batched multi-modal generation and early image preprocessing (see the usage sketch after this class).
"""
def __init__(self, server_args: ServerArgs, **kwargs):
server_args.enable_custom_logit_processor = True
super().__init__(server_args=server_args, **kwargs)
_patch_tokenizer_manager(self.tokenizer_manager)
def generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None,
return_hidden_states: bool = False,
stream: bool = False,
) -> Union[Dict, Iterator[Dict]]:
"""
The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
Please refer to `GenerateReqInput` for the documentation.
"""
modalities_list = []
# EDIT
if isinstance(image_data, list):
for _ in range(len(image_data)):
modalities_list.append(["image"])
elif image_data is not None:
modalities_list.append("image")
# ADD
if custom_logit_processor is None:
custom_logit_processor = Mineru2LogitProcessor().to_str()
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
sampling_params=sampling_params,
image_data=image_data,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
token_ids_logprob=token_ids_logprob,
lora_path=lora_path,
modalities=modalities_list,
custom_logit_processor=custom_logit_processor,
return_hidden_states=return_hidden_states,
stream=stream,
)
generator = _generate_request(self.tokenizer_manager, obj, None)
if stream:
def generator_wrapper():
while True:
try:
chunk = run_async(generator.__anext__())
yield chunk
except StopAsyncIteration:
break
return generator_wrapper()
else:
ret = run_async(generator.__anext__())
return ret
async def async_generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None,
return_hidden_states: bool = False,
stream: bool = False,
) -> Union[Dict, AsyncIterator[Dict], Iterator[Dict]]:
"""
The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
Please refer to `GenerateReqInput` for the documentation.
"""
modalities_list = []
# EDIT
if isinstance(image_data, list):
for _ in range(len(image_data)):
modalities_list.append(["image"])
elif image_data is not None:
modalities_list.append("image")
# ADD
if custom_logit_processor is None:
custom_logit_processor = Mineru2LogitProcessor().to_str()
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
sampling_params=sampling_params,
image_data=image_data,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
token_ids_logprob=token_ids_logprob,
lora_path=lora_path,
modalities=modalities_list,
custom_logit_processor=custom_logit_processor,
return_hidden_states=return_hidden_states,
stream=stream,
)
generator = _generate_request(self.tokenizer_manager, obj, None)
if stream is True:
return generator
else:
return await generator.__anext__()
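# A minimal batched-generation sketch (hedged: the model path, prompts and base64 images are
# placeholders; the argument names follow GenerateReqInput as noted in the docstrings above):
#
#   engine = BatchEngine(ServerArgs(model_path="path/to/mineru2-checkpoint"))
#   outputs = engine.generate(
#       prompt=["<page 1 prompt>", "<page 2 prompt>"],
#       image_data=[page1_b64, page2_b64],  # one image per prompt
#       sampling_params={"temperature": 0.0, "max_new_tokens": 4096},
#   )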
def _auto_create_handle_loop(self: TokenizerManager):
"""
Patch the original `auto_create_handle_loop()` method to reset `no_create_loop`
when the event loop changes.
"""
try:
curr_handle_loop = asyncio.get_running_loop()
except RuntimeError:
curr_handle_loop = None
last_handle_loop = getattr(self, "_last_handle_loop", None)
if last_handle_loop != curr_handle_loop:
self.no_create_loop = False
setattr(self, "_last_handle_loop", curr_handle_loop)
return TokenizerManager.auto_create_handle_loop(self)
def _patch_tokenizer_manager(self: TokenizerManager):
self.auto_create_handle_loop = MethodType(_auto_create_handle_loop, self)
async def _one_request(
self: TokenizerManager,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request],
created_time: Optional[float],
):
tokenized_obj = await self._tokenize_one_request(obj)
self._send_one_request(obj, tokenized_obj, created_time)
async for out in self._wait_one_response(obj, request):
yield out
async def _handle_batch_request(
self: TokenizerManager,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
created_time: Optional[float] = None,
):
batch_size = obj.batch_size
generators = []
rids = []
if getattr(obj, "parallel_sample_num", 1) != 1:
raise Exception("parallel_sample_num != 1 is not supported in this patched code.")
# Send all requests
for i in range(batch_size):
tmp_obj = obj[i]
generators.append(_one_request(self, tmp_obj, request, created_time))
rids.append(tmp_obj.rid)
# Wait for all requests
is_stream = hasattr(obj, "stream") and obj.stream
if not is_stream:
outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
yield outputs
else:
rid_to_index = {rid: i for i, rid in enumerate(rids)}
task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
while task_map:
done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
for task in done:
gen = task_map.pop(task)
try:
result = task.result()
result["index"] = rid_to_index[result["meta_info"]["id"]]
yield result
new_task = asyncio.create_task(gen.__anext__())
task_map[new_task] = gen
except StopAsyncIteration:
pass
async def _generate_request(
self: TokenizerManager,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
):
created_time = time.time()
self.auto_create_handle_loop()
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
raise ValueError(
"This model does not appear to be an embedding model by default. "
"Please add `--is-embedding` when launching the server or try another model."
)
obj.normalize_batch_and_arguments()
if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata
logger.info(f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}")
async with self.model_update_lock.reader_lock:
is_single = obj.is_single
if is_single:
tokenized_obj = await self._tokenize_one_request(obj)
self._send_one_request(obj, tokenized_obj, created_time)
async for response in self._wait_one_response(obj, request):
yield response
else:
async for response in _handle_batch_request(self, obj, request, created_time):
yield response
import ast
import asyncio
import re
from typing import List, Optional, Union
import numpy as np
try:
# sglang==0.4.5.post3
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor as BaseProcessor,
)
get_global_processor = None
except ImportError:
# sglang==0.4.4.post1
from sglang.srt.managers.image_processors.base_image_processor import (
BaseImageProcessor as BaseProcessor,
get_global_processor,
)
from sglang.srt.mm_utils import divide_to_patches, expand2square, select_best_resolution
from sglang.srt.utils import load_image, logger
from sglang.utils import get_exception_traceback
from .model import Mineru2QwenForCausalLM
# image_best_res is only resized (not padded).
def process_anyres_image(image, processor, grid_pinpoints):
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
patch_size = processor.crop_size["height"]
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
grid_pinpoints = [
(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
]
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints)
best_resolution = select_best_resolution(image.size, possible_resolutions)
image_best_res = image.resize(best_resolution) # <<<<<<< Here changed
patches = divide_to_patches(image_best_res, processor.crop_size["height"])
image_original_resize = image.resize((processor.crop_size["height"], processor.crop_size["height"]))
image_patches = [image_original_resize] + patches
image_patches = [processor.preprocess(image_patch)["pixel_values"][0] for image_patch in image_patches]
return np.stack(image_patches, axis=0)
class Mineru2ImageProcessor(BaseProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(
image_data: Union[str, bytes],
image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[str] = None,
image_processor=None,
):
if image_processor is None:
assert get_global_processor is not None
image_processor = get_global_processor().image_processor
try:
image, image_size = load_image(image_data)
if image_size is not None:
# It is a video with multiple images
image_hash = hash(image_data)
pixel_values = image_processor(image)["pixel_values"]
pixel_values = np.stack(pixel_values, axis=0)
return pixel_values, image_hash, image_size
else:
# It is an image
image_hash = hash(image_data)
if image_aspect_ratio == "pad":
image = expand2square(
image,
tuple(int(x * 255) for x in image_processor.image_mean),
)
pixel_values = image_processor(image.convert("RGB"))["pixel_values"][0]
elif image_aspect_ratio == "anyres" or (image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio):
pixel_values = process_anyres_image(image, image_processor, image_grid_pinpoints)
else:
pixel_values = image_processor(image)["pixel_values"][0]
return pixel_values, image_hash, image.size
except Exception:
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
async def _process_single_image(self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str):
if hasattr(self, "cpu_executor"):
executor = self.cpu_executor
else:
executor = self.executor
if get_global_processor is not None:
image_processor = None # save ipc cost
else:
image_processor = self._processor.image_processor
if executor is not None:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
executor,
Mineru2ImageProcessor._process_single_image_task,
image_data,
aspect_ratio,
grid_pinpoints,
image_processor,
)
else:
return self._process_single_image_task(
image_data,
aspect_ratio,
grid_pinpoints,
image_processor,
)
# sglang==0.4.4.post1
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
*args,
**kwargs,
):
if not image_data:
return None
modalities = request_obj.modalities or ["image"]
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", "")
grid_pinpoints = (
self.hf_config.image_grid_pinpoints
if hasattr(self.hf_config, "image_grid_pinpoints") and "anyres" in aspect_ratio
else None
)
if isinstance(image_data, str):
image_data = [image_data]
if isinstance(image_data, list) and len(image_data) > 0:
if "multi-images" in modalities or "video" in modalities:
# Multiple images
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
pixel_values, image_hashes, image_sizes = [], [], []
res = []
for img_data in image_data:
res.append(self._process_single_image(img_data, aspect_ratio, grid_pinpoints))
res = await asyncio.gather(*res)
for pixel_v, image_h, image_s in res:
pixel_values.append(pixel_v)
image_hashes.append(image_h)
image_sizes.append(image_s)
if isinstance(pixel_values[0], np.ndarray):
pixel_values = np.stack(pixel_values, axis=0)
else:
# A single image
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
image_hashes = [image_hash]
image_sizes = [image_size]
else:
raise ValueError(f"Invalid image data: {image_data}")
return {
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"image_sizes": image_sizes,
"modalities": request_obj.modalities or ["image"],
}
# sglang==0.4.5.post3
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
*args,
**kwargs,
):
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
result = await self.process_images_async(image_data, input_text, request_obj, *args, **kwargs)
if result is None:
return None
modality = Modality.IMAGE
if isinstance(request_obj.modalities, list):
if request_obj.modalities[0] == "multi-images":
modality = Modality.MULTI_IMAGES
elif request_obj.modalities[0] == "video":
modality = Modality.VIDEO
return {
"mm_items": [
MultimodalDataItem(
pixel_values=result["pixel_values"],
image_sizes=result["image_sizes"],
modality=modality,
)
],
}
ImageProcessorMapping = {Mineru2QwenForCausalLM: Mineru2ImageProcessor}
from typing import List
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
class Mineru2LogitProcessor(CustomLogitProcessor):
"""
Stateless logit processor for Mineru2.
(base-class: sglang.srt.sampling.custom_logit_processor.CustomLogitProcessor)
This processor applies token-level constraints to prevent repetition during generation.
It supports two main constraints:
- no_repeat_ngram_size (int):
Prevents repeating the same n-gram of specified size in the output.
Inspired by Hugging Face's NoRepeatNGramLogitsProcessor.
This implementation is slower because it lacks specialized optimizations (a worked example follows the class).
- no_repeat_token_count (int):
(Placeholder for future logic)
Intended to prevent repeating the same token multiple times.
Not yet implemented in this version.
"""
def __init__(self) -> None:
super().__init__()
self._generated_ngrams = {} # Cache of generated n-grams by request ID
self._time = {} # Timestamp of the last update for each request
self._gen_step = 0 # Global generation step counter
def __call__(self, logits, batch_info: List[dict]):
"""
Applies repetition constraints to the logits before sampling tokens.
Args:
logits (FloatTensor): A tensor of shape (batch_size, vocab_size) containing raw token logits.
batch_info (List[dict]): A list of metadata dicts for each sample in the batch. Each dict must include:
- "__req__": Request object containing request ID and output_ids.
- "no_repeat_ngram_size": Size of n-gram to avoid repeating.
Returns:
FloatTensor: The modified logits tensor with banned token logits set to -inf.
"""
from sglang.srt.managers.schedule_batch import Req
self._gen_step += 1 # Update global generation step
for idx, info in enumerate(batch_info):
if not isinstance(info, dict) or "__req__" not in info:
continue
req: Req = info["__req__"]
rid = req.rid
output_ids = req.output_ids
ngram_size = info.get("no_repeat_ngram_size", 0)
# Skip if there are not enough tokens to form an n-gram
if ngram_size <= 0 or len(output_ids) < ngram_size:
continue
# Record the current step for cache cleanup tracking
self._time[rid] = self._gen_step
# Initialize n-gram cache for this request if it doesn't exist
if rid not in self._generated_ngrams:
self._generated_ngrams[rid] = {}
# Get the n-gram prefix (all but the last token)
prev_ngram = tuple(output_ids[-ngram_size:-1])
last_token = output_ids[-1]
# Store this n-gram occurrence
self._generated_ngrams[rid][prev_ngram] = self._generated_ngrams[rid].get(prev_ngram, []) + [last_token]
# Get the next-token candidates to ban based on current prefix
current_prefix = tuple(output_ids[-ngram_size + 1 :])
banned_tokens = self._generated_ngrams[rid].get(current_prefix, [])
# Set the logits of banned tokens to negative infinity
for token in banned_tokens:
logits[idx][token] = -float("inf")
# Clean up cache for expired requests
expired_rids = [rid for rid, last_used in self._time.items() if last_used < self._gen_step]
for rid in expired_rids:
self._generated_ngrams.pop(rid, None)
self._time.pop(rid, None)
return logits
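# Worked example (illustrative token ids): with no_repeat_ngram_size=3, once a request has emitted
# ... 5, 7, 9 ... the prefix (5, 7) -> [9] is cached for that request; whenever its last two output
# tokens are again (5, 7), logit 9 is set to -inf, so the 3-gram (5, 7, 9) cannot recur.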
import math
import re
from typing import Iterable, List, Optional, Tuple
import numpy as np
import torch
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape, # unpad_image, unpad_image_shape
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.utils import add_prefix
from torch import nn
from transformers import (
CLIPVisionConfig,
CLIPVisionModel,
SiglipVisionConfig,
SiglipVisionModel,
)
from ..vlm_hf_model.configuration_mineru2 import Mineru2QwenConfig
from ..vlm_hf_model.modeling_mineru2 import build_vision_projector
def flatten_nested_list(nested_list):
if isinstance(nested_list, list):
return [item for sublist in nested_list for item in flatten_nested_list(sublist)]
else:
return [nested_list]
def downgrade_modality(modality):
modality_str = str(modality)
if "MULTI_IMAGES" in modality_str:
return "multi-images"
if "IMAGE" in modality_str:
return "image"
if "VIDEO" in modality_str:
return "video"
if "AUDIO" in modality_str:
return "audio"
raise ValueError(f"Unexpected modality: {modality_str}")
class Mineru2QwenForCausalLM(nn.Module):
def __init__(
self,
config: Mineru2QwenConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
if getattr(self.config, "projector_hidden_act", None) is None:
self.config.projector_hidden_act = "gelu"
if getattr(self.config, "image_token_index", None) is None:
self.config.image_token_index = 151646
# load vision tower
mm_vision_tower = self.config.mm_vision_tower
if "clip" in mm_vision_tower:
vision_config = CLIPVisionConfig.from_pretrained(mm_vision_tower)
self.vision_tower = CLIPVisionModel(vision_config) # type: ignore
elif "siglip" in mm_vision_tower:
vision_config = SiglipVisionConfig.from_pretrained(mm_vision_tower)
self.vision_tower = SiglipVisionModel(vision_config) # type: ignore
# Siglip needs all feature tokens
self.config.mm_vision_select_feature = "full"
else:
raise ValueError(f"Unexpected mm_vision_tower: {mm_vision_tower}")
### EDIT: change projector
# the name `projector` contains `proj`, a substring commonly used in attention-layer names, which can cause bugs in quantization.
self.multi_modal_mlp = build_vision_projector(config)
self.language_model = Qwen2ForCausalLM(
config,
quant_config=quant_config,
prefix=add_prefix("language_model", prefix),
)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
self.language_model.model.image_newline = nn.Parameter(torch.empty(config.hidden_size))
language_model_device = next(self.language_model.parameters()).device
self.vision_tower = self.vision_tower.to(language_model_device)
self.vision_tower.eval()
self.vision_feature_layer = self.config.mm_vision_select_layer
self.vision_feature_select_strategy = self.config.mm_vision_select_feature
self.image_size = self.vision_tower.config.image_size
self.patch_size = self.vision_tower.config.patch_size
self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
if self.vision_feature_select_strategy in ("patch", "full"):
pass
elif self.vision_feature_select_strategy == "cls_patch":
self.image_feature_len += 1
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")
def pad_input_ids(self, input_ids: List[int], image_inputs):
if hasattr(image_inputs, "mm_items"): # MultimodalInputs
# sglang==0.4.5.post3
image_sizes = flatten_nested_list([item.image_sizes for item in image_inputs.mm_items])
pad_values = [item.pad_value for item in image_inputs.mm_items]
else: # ImageInputs
# sglang==0.4.4.post1
image_sizes = image_inputs.image_sizes
pad_values = image_inputs.pad_values
# hardcode for spatial_unpad + anyres
# if image_inputs.modalities is not None and (
# "multi-images" in image_inputs.modalities or "video" in image_inputs.modalities
# ):
# image_aspect_ratio = "pad"
# else:
# image_aspect_ratio = "anyres"
offset_list = []
image_inputs.image_pad_len = []
for image_idx, image_s in enumerate(image_sizes):
if len(image_sizes) > 16:
# 2x2 pooling with stride 2
new_image_feature_len = math.ceil(self.image_size / self.patch_size / 2) ** 2
else:
new_image_feature_len = self.image_feature_len # multiimage
height = width = self.num_patches_per_side
if "anyres" in self.config.image_aspect_ratio:
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_s,
self.image_grid_pinpoints,
self.vision_tower.config.image_size,
)
h = num_patch_height * height
w = num_patch_width * width
### EDIT: remove `unpad_image_shape`
# new_h, new_w = unpad_image_shape(h, w, image_s)
new_h, new_w = h, w
if "anyres_max" in self.config.image_aspect_ratio:
matched_anyres_max_num_patches = re.match(r".*anyres_max_(\d+)", self.config.image_aspect_ratio)
if matched_anyres_max_num_patches:
max_num_patches = int(matched_anyres_max_num_patches.group(1))
times = math.sqrt(new_h * new_w / (max_num_patches * self.image_feature_len))
if times > 1.1:
new_h = int(new_h // times)
new_w = int(new_w // times)
new_image_feature_len += new_h * (new_w + 1)
try:
offset = input_ids.index(self.config.image_token_index)
except ValueError:
offset = 0
# old_len + pad_len - 1, because we need to remove image_token_id
input_ids = input_ids[:offset] + [pad_values[image_idx]] * new_image_feature_len + input_ids[offset + 1 :]
offset_list.append(offset)
image_inputs.image_pad_len.append(new_image_feature_len)
image_inputs.image_offsets = offset_list
return input_ids
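# Worked sketch (illustrative ids): with one image token at offset 3 and new_image_feature_len = 7,
# input_ids [a, b, c, IMG, d] becomes [a, b, c, p, p, p, p, p, p, p, d], where p is that image's
# pad_value; image_offsets and image_pad_len record where vision embeddings are written back later.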
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
pixel_values = pixel_values.to(device=self.vision_tower.device, dtype=self.vision_tower.dtype)
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
# NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden states.
selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
if self.vision_feature_select_strategy in ["default", "patch"]:
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(f"Unexpected select feature strategy: {self.vision_feature_select_strategy}")
image_features = self.multi_modal_mlp(selected_image_feature)
return image_features
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if hasattr(forward_batch, "mm_inputs"):
# sglang==0.4.5.post3
image_inputs = forward_batch.mm_inputs
is_sglang_mm_inputs = True
else:
# sglang==0.4.4.post1
image_inputs = forward_batch.image_inputs
is_sglang_mm_inputs = False
if image_inputs is None:
image_inputs = []
if forward_batch.forward_mode.is_extend():
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# Their values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
# Embed text inputs
input_embeds = self.language_model.model.embed_tokens(input_ids)
# Got List[List[str]]; extend it to List[str].
# The length of the list should equal the batch size.
modalities_list = []
max_image_offset = []
for im in image_inputs:
if im:
if hasattr(im, "mm_items"):
# sglang==0.4.5.post3
modalities_list.extend([downgrade_modality(item.modality) for item in im.mm_items])
elif im.modalities is not None:
# sglang==0.4.4.post1
modalities_list.extend(im.modalities)
if im and im.image_offsets:
max_image_offset.append(np.max(np.array(im.image_offsets) + np.array(im.image_pad_len)))
else:
max_image_offset.append(-1)
start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
need_vision = start_positions <= np.array(max_image_offset)
if need_vision.any():
bs = forward_batch.batch_size
if is_sglang_mm_inputs:
# sglang==0.4.5.post3
pixel_values = flatten_nested_list(
[[item.pixel_values for item in image_inputs[i].mm_items] for i in range(bs) if need_vision[i]]
) # image_inputs[batch_idx].mm_items[item_idx].pixel_values is Tensor
image_sizes = [
flatten_nested_list([item.image_sizes for item in image_inputs[i].mm_items])
for i in range(bs)
if need_vision[i]
] # image_inputs[batch_idx].mm_items[item_idx].image_sizes should be a tuple, but is a list of tuples for now.
else:
# sglang==0.4.4.post1
pixel_values = [image_inputs[i].pixel_values for i in range(bs) if need_vision[i]]
image_sizes = [image_inputs[i].image_sizes for i in range(bs) if need_vision[i]]
########## Encode Image ########
if pixel_values[0].ndim == 4:
# llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
np.concatenate(pixel_values, axis=0)
# ndim=4
concat_images = torch.tensor(
np.concatenate(pixel_values, axis=0),
device=self.vision_tower.device,
)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in pixel_values]
image_features = torch.split(image_features, split_sizes, dim=0)
# hd image_features: BS, num_patch, 576, 4096
else:
# normal pixel: BS, C=3, H=336, W=336
pixel_values = torch.tensor(np.array(pixel_values), device=self.vision_tower.device)
image_features = self.encode_images(pixel_values)
# image_features: BS, 576, 4096
if self.mm_patch_merge_type.startswith("spatial"):
new_image_features = []
height = width = self.num_patches_per_side
for image_idx, image_feature in enumerate(image_features):
if modalities_list[image_idx] == "image":
image_aspect_ratio = self.config.image_aspect_ratio # single image
elif modalities_list[image_idx] == "multi-images" or modalities_list[image_idx] == "video":
image_aspect_ratio = "pad" # multi image
# image_aspect_ratio = (
# "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
# )
if (
image_feature.shape[0] > 1
and "anyres" in image_aspect_ratio
and modalities_list[image_idx] == "image"
):
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
assert height * width == base_image_feature.shape[0]
if "anyres_max" in image_aspect_ratio:
matched_anyres_max_num_patches = re.match(r".*anyres_max_(\d+)", image_aspect_ratio)
if matched_anyres_max_num_patches:
max_num_patches = int(matched_anyres_max_num_patches.group(1))
if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
vision_tower_image_size = self.image_size
try:
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx][0],
self.config.image_grid_pinpoints,
vision_tower_image_size,
)
except Exception as e:
print(f"Error: {e}")
num_patch_width, num_patch_height = 2, 2
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
else:
image_feature = image_feature.view(2, 2, height, width, -1)
if "unpad" in self.mm_patch_merge_type:
unit = image_feature.shape[2]
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
### EDIT: remove `unpad_image`
# image_feature = unpad_image(image_feature, image_sizes[image_idx][0])
if "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
c, h, w = image_feature.shape
times = math.sqrt(h * w / (max_num_patches * unit**2))
if times > 1.1:
image_feature = image_feature[None]
image_feature = nn.functional.interpolate(
image_feature,
[int(h // times), int(w // times)],
mode="bilinear",
)[0]
image_feature = torch.cat(
(
image_feature,
self.language_model.model.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1
),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
else:
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
image_feature = image_feature.flatten(0, 3)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
image_feature = image_feature.unsqueeze(0)
else:
if modalities_list[image_idx] == "video": # video
# 2x2 pooling
num_of_frames = image_feature.shape[0]
image_feature = image_feature.view(num_of_frames, height, width, -1)
image_feature = image_feature.permute(0, 3, 1, 2).contiguous() # N, C, H, W
height, width = image_feature.shape[2:]
scaled_shape = [
math.ceil(height / 2),
math.ceil(width / 2),
]
image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode="bilinear")
image_feature = image_feature.flatten(2).transpose(1, 2).contiguous() # N, C, H*W
if "unpad" in self.mm_patch_merge_type:
image_feature = torch.cat(
(
image_feature,
# Expand to (bs, 1, hidden_dim) and concat at the end of the image tokens
self.language_model.model.image_newline[None, None].expand(
image_feature.shape[0],
1,
image_feature.shape[-1],
),
),
dim=1,
)
new_image_features.append(image_feature)
image_features = new_image_features
# Fill in the placeholder for the image
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
pt = 0
for i in range(bs):
if not need_vision[i]:
continue
start_idx = extend_start_loc_cpu[i]
seq_len = extend_seq_lens[i]
prefix_len = prefix_lens_cpu[i]
# Multiple images
for image_idx, image_offset in enumerate(image_inputs[i].image_offsets):
if image_offset + image_inputs[i].image_pad_len[image_idx] <= prefix_len:
continue
if image_offset >= prefix_len + seq_len:
break
tmp_image_feature = image_features[pt][image_idx]
pad_len = tmp_image_feature.shape[0]
input_offset = image_offset - prefix_len
left_idx = start_idx + input_offset
right_idx = left_idx + pad_len
assert right_idx > start_idx
if input_offset < 0:
left_idx = start_idx
tmp_image_feature = tmp_image_feature[-input_offset:]
if right_idx > start_idx + seq_len:
tmp_image_feature = tmp_image_feature[: start_idx + seq_len - right_idx]
right_idx = start_idx + seq_len
try:
input_embeds[left_idx:right_idx] = tmp_image_feature
except RuntimeError as e:
print(f"RuntimeError in image encoding: {e}")
print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
print(f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}")
pt += 1
return self.language_model(input_ids, positions, forward_batch, input_embeds=input_embeds)
elif forward_batch.forward_mode.is_decode():
return self.language_model(input_ids, positions, forward_batch)
else:
raise ValueError(f"Unexpected forward mode: {forward_batch.forward_mode}")
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
projector_weights = {
"model.mm_projector": "multi_modal_mlp",
"model.vision_tower.vision_tower": "vision_tower",
# Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
"model.image_newline": "language_model.model.image_newline",
}
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "projector" in name or "vision_tower" in name or "image_newline" in name:
for weight_name, param_name in projector_weights.items():
if weight_name in name:
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
else:
self.language_model.load_weights([(name, loaded_weight)])
@property
def num_patches_per_side(self):
return self.image_size // self.patch_size
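# Illustrative note (hypothetical numbers, not taken from this file): if the vision tower
# uses image_size=384 and patch_size=14, num_patches_per_side is 384 // 14 = 27, so one
# view contributes a 27 x 27 grid of patch features before spatial merging.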
EntryClass = [Mineru2QwenForCausalLM]
import os
import sys
from fastapi import Request
from sglang.srt.entrypoints.http_server import app, generate_request, launch_server
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_process_tree
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
from .logit_processor import Mineru2LogitProcessor
_custom_logit_processor_str = Mineru2LogitProcessor().to_str()
# remove the existing /generate route
for route in app.routes[:]:
if hasattr(route, "path") and getattr(route, "path") == "/generate":
app.routes.remove(route)
# add the custom /generate route
@app.api_route("/generate", methods=["POST", "PUT"])
async def custom_generate_request(obj: GenerateReqInput, request: Request):
if obj.custom_logit_processor is None:
obj.custom_logit_processor = _custom_logit_processor_str
return await generate_request(obj, request)
def main():
server_args = prepare_server_args(sys.argv[1:])
if server_args.chat_template is None:
server_args.chat_template = "chatml"
server_args.enable_custom_logit_processor = True
if server_args.model_path is None:
server_args.model_path = auto_download_and_get_model_root_path("/","vlm")
try:
launch_server(server_args)
finally:
kill_process_tree(os.getpid(), include_parent=False)
if __name__ == "__main__":
main()
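# A minimal launch sketch (assuming this module is saved as server.py; flags such as
# --port and --model-path are sglang's standard server arguments parsed by
# prepare_server_args):
#
#   python server.py --port 30000
#
# When --model-path is omitted, the model root is resolved via
# auto_download_and_get_model_root_path("/", "vlm"), and every /generate request without
# its own custom_logit_processor gets Mineru2LogitProcessor injected.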
# Copyright (c) Opendatalab. All rights reserved.
from magic_pdf.config.ocr_content_type import BlockType
from magic_pdf.libs.boxbase import (
# Copyright (c) Opendatalab. All rights reserved.
from mineru.utils.boxbase import (
calculate_iou,
calculate_overlap_area_in_bbox1_area_ratio,
calculate_vertical_projection_overlap_ratio,
get_minbox_if_overlap_by_ratio
)
from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox_for_block
from mineru.utils.enum_class import BlockType
def add_bboxes(blocks, block_type, bboxes):
for block in blocks:
x0, y0, x1, y1 = block['bbox']
if block_type in [
BlockType.ImageBody,
BlockType.ImageCaption,
BlockType.ImageFootnote,
BlockType.TableBody,
BlockType.TableCaption,
BlockType.TableFootnote,
]:
bboxes.append(
[
x0,
y0,
x1,
y1,
None,
None,
None,
block_type,
None,
None,
None,
None,
block['score'],
block['group_id'],
]
)
def process_groups(groups, body_key, caption_key, footnote_key):
body_blocks = []
caption_blocks = []
footnote_blocks = []
maybe_text_image_blocks = []
for i, group in enumerate(groups):
if body_key == 'image_body' and len(group[caption_key]) == 0 and len(group[footnote_key]) == 0:
# If there is no caption or footnote, the group_id does not need to be added to image_body
group[body_key]['group_id'] = i
maybe_text_image_blocks.append(group[body_key])
continue
else:
bboxes.append(
[
x0,
y0,
x1,
y1,
None,
None,
None,
block_type,
None,
None,
None,
None,
block['score'],
]
)
def ocr_prepare_bboxes_for_layout_split_v2(
group[body_key]['group_id'] = i
body_blocks.append(group[body_key])
for caption_block in group[caption_key]:
caption_block['group_id'] = i
caption_blocks.append(caption_block)
for footnote_block in group[footnote_key]:
footnote_block['group_id'] = i
footnote_blocks.append(footnote_block)
return body_blocks, caption_blocks, footnote_blocks, maybe_text_image_blocks
def prepare_block_bboxes(
img_body_blocks,
img_caption_blocks,
img_footnote_blocks,
......@@ -73,15 +47,15 @@ def ocr_prepare_bboxes_for_layout_split_v2(
):
all_bboxes = []
add_bboxes(img_body_blocks, BlockType.ImageBody, all_bboxes)
add_bboxes(img_caption_blocks, BlockType.ImageCaption, all_bboxes)
add_bboxes(img_footnote_blocks, BlockType.ImageFootnote, all_bboxes)
add_bboxes(table_body_blocks, BlockType.TableBody, all_bboxes)
add_bboxes(table_caption_blocks, BlockType.TableCaption, all_bboxes)
add_bboxes(table_footnote_blocks, BlockType.TableFootnote, all_bboxes)
add_bboxes(text_blocks, BlockType.Text, all_bboxes)
add_bboxes(title_blocks, BlockType.Title, all_bboxes)
add_bboxes(interline_equation_blocks, BlockType.InterlineEquation, all_bboxes)
add_bboxes(img_body_blocks, BlockType.IMAGE_BODY, all_bboxes)
add_bboxes(img_caption_blocks, BlockType.IMAGE_CAPTION, all_bboxes)
add_bboxes(img_footnote_blocks, BlockType.IMAGE_CAPTION, all_bboxes)
add_bboxes(table_body_blocks, BlockType.TABLE_BODY, all_bboxes)
add_bboxes(table_caption_blocks, BlockType.TABLE_CAPTION, all_bboxes)
add_bboxes(table_footnote_blocks, BlockType.TABLE_FOOTNOTE, all_bboxes)
add_bboxes(text_blocks, BlockType.TEXT, all_bboxes)
add_bboxes(title_blocks, BlockType.TITLE, all_bboxes)
add_bboxes(interline_equation_blocks, BlockType.INTERLINE_EQUATION, all_bboxes)
"""block嵌套问题解决"""
"""文本框与标题框重叠,优先信任文本框"""
......@@ -97,7 +71,7 @@ def ocr_prepare_bboxes_for_layout_split_v2(
"""discarded_blocks"""
all_discarded_blocks = []
add_bboxes(discarded_blocks, BlockType.Discarded, all_discarded_blocks)
add_bboxes(discarded_blocks, BlockType.DISCARDED, all_discarded_blocks)
"""footnote识别:宽度超过1/3页面宽度的,高度超过10的,处于页面下半30%区域的"""
footnote_blocks = []
......@@ -122,63 +96,31 @@ def ocr_prepare_bboxes_for_layout_split_v2(
return all_bboxes, all_discarded_blocks, footnote_blocks
def find_blocks_under_footnote(all_bboxes, footnote_blocks):
need_remove_blocks = []
for block in all_bboxes:
block_x0, block_y0, block_x1, block_y1 = block[:4]
for footnote_bbox in footnote_blocks:
footnote_x0, footnote_y0, footnote_x1, footnote_y1 = footnote_bbox
# If the footnote's vertical projection covers at least 80% of the block's vertical projection and the block's y0 is greater than or equal to the footnote's y1
if (
block_y0 >= footnote_y1
and calculate_vertical_projection_overlap_ratio(
(block_x0, block_y0, block_x1, block_y1), footnote_bbox
)
>= 0.8
):
if block not in need_remove_blocks:
need_remove_blocks.append(block)
break
return need_remove_blocks
def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes):
# First collect all text and interline-equation blocks
text_blocks = []
for block in all_bboxes:
if block[7] == BlockType.Text:
text_blocks.append(block)
interline_equation_blocks = []
for block in all_bboxes:
if block[7] == BlockType.InterlineEquation:
interline_equation_blocks.append(block)
need_remove = []
for interline_equation_block in interline_equation_blocks:
for text_block in text_blocks:
interline_equation_block_bbox = interline_equation_block[:4]
text_block_bbox = text_block[:4]
if calculate_iou(interline_equation_block_bbox, text_block_bbox) > 0.8:
if text_block not in need_remove:
need_remove.append(text_block)
if len(need_remove) > 0:
for block in need_remove:
all_bboxes.remove(block)
return all_bboxes
def add_bboxes(blocks, block_type, bboxes):
for block in blocks:
x0, y0, x1, y1 = block['bbox']
if block_type in [
BlockType.IMAGE_BODY,
BlockType.IMAGE_CAPTION,
BlockType.IMAGE_FOOTNOTE,
BlockType.TABLE_BODY,
BlockType.TABLE_CAPTION,
BlockType.TABLE_FOOTNOTE,
]:
bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block['score'], block['group_id']])
else:
bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block['score']])
def fix_text_overlap_title_blocks(all_bboxes):
# First collect all text and title blocks
text_blocks = []
for block in all_bboxes:
if block[7] == BlockType.Text:
if block[7] == BlockType.TEXT:
text_blocks.append(block)
title_blocks = []
for block in all_bboxes:
if block[7] == BlockType.Title:
if block[7] == BlockType.TITLE:
title_blocks.append(block)
need_remove = []
......@@ -219,6 +161,54 @@ def remove_need_drop_blocks(all_bboxes, discarded_blocks):
return all_bboxes
def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes):
# First collect all text and interline-equation blocks
text_blocks = []
for block in all_bboxes:
if block[7] == BlockType.TEXT:
text_blocks.append(block)
interline_equation_blocks = []
for block in all_bboxes:
if block[7] == BlockType.INTERLINE_EQUATION:
interline_equation_blocks.append(block)
need_remove = []
for interline_equation_block in interline_equation_blocks:
for text_block in text_blocks:
interline_equation_block_bbox = interline_equation_block[:4]
text_block_bbox = text_block[:4]
if calculate_iou(interline_equation_block_bbox, text_block_bbox) > 0.8:
if text_block not in need_remove:
need_remove.append(text_block)
if len(need_remove) > 0:
for block in need_remove:
all_bboxes.remove(block)
return all_bboxes
def find_blocks_under_footnote(all_bboxes, footnote_blocks):
need_remove_blocks = []
for block in all_bboxes:
block_x0, block_y0, block_x1, block_y1 = block[:4]
for footnote_bbox in footnote_blocks:
footnote_x0, footnote_y0, footnote_x1, footnote_y1 = footnote_bbox
# If the footnote's vertical projection covers at least 80% of the block's vertical projection and the block's y0 is greater than or equal to the footnote's y1
if (
block_y0 >= footnote_y1
and calculate_vertical_projection_overlap_ratio(
(block_x0, block_y0, block_x1, block_y1), footnote_bbox
)
>= 0.8
):
if block not in need_remove_blocks:
need_remove_blocks.append(block)
break
return need_remove_blocks
def remove_overlaps_min_blocks(all_bboxes):
# Overlapping blocks: the smaller one cannot simply be deleted; it needs to be merged with the larger one into a bigger block.
# Remove the smaller ones among the overlapping blocks
......@@ -254,4 +244,4 @@ def remove_overlaps_min_blocks(all_bboxes):
for block in need_remove:
all_bboxes.remove(block)
return all_bboxes
return all_bboxes
\ No newline at end of file
# Copyright (c) Opendatalab. All rights reserved.
import copy
import os
import statistics
import warnings
from typing import List
import torch
from loguru import logger
from mineru.utils.config_reader import get_device
from mineru.utils.enum_class import BlockType, ModelPath
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
def sort_blocks_by_bbox(blocks, page_w, page_h, footnote_blocks):
"""获取所有line并计算正文line的高度"""
line_height = get_line_height(blocks)
"""获取所有line并对line排序"""
sorted_bboxes = sort_lines_by_model(blocks, page_w, page_h, line_height, footnote_blocks)
"""根据line的中位数算block的序列关系"""
blocks = cal_block_index(blocks, sorted_bboxes)
"""将image和table的block还原回group形式参与后续流程"""
blocks = revert_group_blocks(blocks)
"""重排block"""
sorted_blocks = sorted(blocks, key=lambda b: b['index'])
"""block内重排(img和table的block内多个caption或footnote的排序)"""
for block in sorted_blocks:
if block['type'] in [BlockType.IMAGE, BlockType.TABLE]:
block['blocks'] = sorted(block['blocks'], key=lambda b: b['index'])
return sorted_blocks
def get_line_height(blocks):
page_line_height_list = []
for block in blocks:
if block['type'] in [
BlockType.TEXT, BlockType.TITLE,
BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE,
BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE
]:
for line in block['lines']:
bbox = line['bbox']
page_line_height_list.append(int(bbox[3] - bbox[1]))
if len(page_line_height_list) > 0:
return statistics.median(page_line_height_list)
else:
return 10
def sort_lines_by_model(fix_blocks, page_w, page_h, line_height, footnote_blocks):
page_line_list = []
def add_lines_to_block(b):
line_bboxes = insert_lines_into_block(b['bbox'], line_height, page_w, page_h)
b['lines'] = []
for line_bbox in line_bboxes:
b['lines'].append({'bbox': line_bbox, 'spans': []})
page_line_list.extend(line_bboxes)
for block in fix_blocks:
if block['type'] in [
BlockType.TEXT, BlockType.TITLE,
BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE,
BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE
]:
if len(block['lines']) == 0:
add_lines_to_block(block)
elif block['type'] in [BlockType.TITLE] and len(block['lines']) == 1 and (block['bbox'][3] - block['bbox'][1]) > line_height * 2:
block['real_lines'] = copy.deepcopy(block['lines'])
add_lines_to_block(block)
else:
for line in block['lines']:
bbox = line['bbox']
page_line_list.append(bbox)
elif block['type'] in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY, BlockType.INTERLINE_EQUATION]:
block['real_lines'] = copy.deepcopy(block['lines'])
add_lines_to_block(block)
for block in footnote_blocks:
footnote_block = {'bbox': block[:4]}
add_lines_to_block(footnote_block)
if len(page_line_list) > 200: # layoutreader supports at most 512 lines
return None
# Sort with layoutreader
x_scale = 1000.0 / page_w
y_scale = 1000.0 / page_h
boxes = []
# logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_line_list)}")
for left, top, right, bottom in page_line_list:
if left < 0:
logger.warning(
f'left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
left = 0
if right > page_w:
logger.warning(
f'right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
right = page_w
if top < 0:
logger.warning(
f'top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
top = 0
if bottom > page_h:
logger.warning(
f'bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
bottom = page_h
left = round(left * x_scale)
top = round(top * y_scale)
right = round(right * x_scale)
bottom = round(bottom * y_scale)
assert (
1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
), f'Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}' # noqa: E126, E121
boxes.append([left, top, right, bottom])
model_manager = ModelSingleton()
model = model_manager.get_model('layoutreader')
with torch.no_grad():
orders = do_predict(boxes, model)
sorted_bboxes = [page_line_list[i] for i in orders]
return sorted_bboxes
def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
# block_bbox is a tuple (x0, y0, x1, y1), where (x0, y0) is the bottom-left corner and (x1, y1) is the top-right corner
x0, y0, x1, y1 = block_bbox
block_height = y1 - y0
block_width = x1 - x0
# If the block is shorter than two body-text lines, return the block bbox directly
if line_height * 2 < block_height:
if (
block_height > page_h * 0.25 and page_w * 0.5 > block_width > page_w * 0.25
): # possibly a two-column layout, so split into finer lines
lines = int(block_height / line_height)
else:
# If the block is wider than 0.4 of the page width, split it into 3 lines (a complex layout; figures should not be sliced too finely)
if block_width > page_w * 0.4:
lines = 3
elif block_width > page_w * 0.25: # possibly a three-column layout, also split finely
lines = int(block_height / line_height)
else: # check the aspect ratio
if block_height / block_width > 1.2: # tall and narrow blocks are not split
return [[x0, y0, x1, y1]]
else: # otherwise split into two lines
lines = 2
line_height = (y1 - y0) / lines
# Determine the y position from which to start drawing lines
current_y = y0
# Store the line positions [(x0, y), ...]
lines_positions = []
for i in range(lines):
lines_positions.append([x0, current_y, x1, current_y + line_height])
current_y += line_height
return lines_positions
else:
return [[x0, y0, x1, y1]]
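# Illustrative check with made-up numbers: a block (100, 100, 400, 400) on a 1000 x 1000
# page with line_height=15 is 300 tall and 300 wide, falls into the "possible two-column"
# branch (taller than 0.25 * page_h, width between 0.25 and 0.5 of page_w), and is split
# into int(300 / 15) = 20 virtual line bboxes of equal height.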
def model_init(model_name: str):
from transformers import LayoutLMv3ForTokenClassification
device_name = get_device()
bf_16_support = False
if device_name.startswith("cuda"):
bf_16_support = torch.cuda.is_bf16_supported()
elif device_name.startswith("mps"):
bf_16_support = True
device = torch.device(device_name)
if model_name == 'layoutreader':
# Check whether the local layoutreader model directory exists
layoutreader_model_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.layout_reader), ModelPath.layout_reader)
if os.path.exists(layoutreader_model_dir):
model = LayoutLMv3ForTokenClassification.from_pretrained(
layoutreader_model_dir
)
else:
logger.warning(
'local layoutreader model does not exist, using online model from huggingface'
)
model = LayoutLMv3ForTokenClassification.from_pretrained(
'hantian/layoutreader'
)
if bf_16_support:
model.to(device).eval().bfloat16()
else:
model.to(device).eval()
else:
logger.error('model name not allowed')
exit(1)
return model
class ModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(self, model_name: str):
if model_name not in self._models:
self._models[model_name] = model_init(model_name=model_name)
return self._models[model_name]
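# Typical use: the singleton caches models by name, so repeated calls return the same instance.
# model = ModelSingleton().get_model('layoutreader')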
def do_predict(boxes: List[List[int]], model) -> List[int]:
from mineru.model.reading_order.layout_reader import (
boxes2inputs, parse_logits, prepare_inputs)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
inputs = boxes2inputs(boxes)
inputs = prepare_inputs(inputs, model)
logits = model(**inputs).logits.cpu().squeeze(0)
return parse_logits(logits, len(boxes))
def cal_block_index(fix_blocks, sorted_bboxes):
if sorted_bboxes is not None:
# Sort with layoutreader
for block in fix_blocks:
line_index_list = []
if len(block['lines']) == 0:
block['index'] = sorted_bboxes.index(block['bbox'])
else:
for line in block['lines']:
line['index'] = sorted_bboxes.index(line['bbox'])
line_index_list.append(line['index'])
median_value = statistics.median(line_index_list)
block['index'] = median_value
# Remove the virtual line info from image/table body blocks and backfill it with the real_lines info
if block['type'] in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY, BlockType.TITLE, BlockType.INTERLINE_EQUATION]:
if 'real_lines' in block:
block['virtual_lines'] = copy.deepcopy(block['lines'])
block['lines'] = copy.deepcopy(block['real_lines'])
del block['real_lines']
else:
# Sort with xycut
block_bboxes = []
for block in fix_blocks:
# If any value of block['bbox'] is less than 0, clamp it to 0
block['bbox'] = [max(0, x) for x in block['bbox']]
block_bboxes.append(block['bbox'])
# Remove the virtual line info from image/table body blocks and backfill it with the real_lines info
if block['type'] in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY, BlockType.TITLE, BlockType.INTERLINE_EQUATION]:
if 'real_lines' in block:
block['virtual_lines'] = copy.deepcopy(block['lines'])
block['lines'] = copy.deepcopy(block['real_lines'])
del block['real_lines']
import numpy as np
from mineru.model.reading_order.xycut import recursive_xy_cut
random_boxes = np.array(block_bboxes)
np.random.shuffle(random_boxes)
res = []
recursive_xy_cut(np.asarray(random_boxes).astype(int), np.arange(len(block_bboxes)), res)
assert len(res) == len(block_bboxes)
sorted_boxes = random_boxes[np.array(res)].tolist()
for i, block in enumerate(fix_blocks):
block['index'] = sorted_boxes.index(block['bbox'])
# Generate line indices
sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])
line_index = 1
for block in sorted_blocks:
for line in block['lines']:
line['index'] = line_index
line_index += 1
return fix_blocks
def revert_group_blocks(blocks):
image_groups = {}
table_groups = {}
new_blocks = []
for block in blocks:
if block['type'] in [BlockType.IMAGE_BODY, BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE]:
group_id = block['group_id']
if group_id not in image_groups:
image_groups[group_id] = []
image_groups[group_id].append(block)
elif block['type'] in [BlockType.TABLE_BODY, BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE]:
group_id = block['group_id']
if group_id not in table_groups:
table_groups[group_id] = []
table_groups[group_id].append(block)
else:
new_blocks.append(block)
for group_id, blocks in image_groups.items():
new_blocks.append(process_block_list(blocks, BlockType.IMAGE_BODY, BlockType.IMAGE))
for group_id, blocks in table_groups.items():
new_blocks.append(process_block_list(blocks, BlockType.TABLE_BODY, BlockType.TABLE))
return new_blocks
def process_block_list(blocks, body_type, block_type):
indices = [block['index'] for block in blocks]
median_index = statistics.median(indices)
body_bbox = next((block['bbox'] for block in blocks if block.get('type') == body_type), [])
return {
'type': block_type,
'bbox': body_bbox,
'blocks': blocks,
'index': median_index,
}
\ No newline at end of file
import math
def _is_in_or_part_overlap(box1, box2) -> bool:
"""两个bbox是否有部分重叠或者包含."""
if box1 is None or box2 is None:
return False
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
return not (x1_1 < x0_2 or # box1 is to the left of box2
x0_1 > x1_2 or # box1 is to the right of box2
y1_1 < y0_2 or # box1 is above box2
y0_1 > y1_2) # box1 is below box2
def _is_in_or_part_overlap_with_area_ratio(box1,
box2,
area_ratio_threshold=0.6):
"""判断box1是否在box2里面,或者box1和box2有部分重叠,且重叠面积占box1的比例超过area_ratio_threshold."""
if box1 is None or box2 is None:
return False
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
if not _is_in_or_part_overlap(box1, box2):
return False
# Compute the overlap area
x_left = max(x0_1, x0_2)
y_top = max(y0_1, y0_2)
x_right = min(x1_1, x1_2)
y_bottom = min(y1_1, y1_2)
overlap_area = (x_right - x_left) * (y_bottom - y_top)
# Compute the area of box1
box1_area = (x1_1 - x0_1) * (y1_1 - y0_1)
return overlap_area / box1_area > area_ratio_threshold
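# Illustrative check with made-up boxes: box1 = (0, 0, 10, 10) and box2 = (5, 0, 20, 10)
# overlap over half of box1 (50 / 100 = 0.5), which is below the default 0.6 threshold,
# so the function returns False.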
def _is_in(box1, box2) -> bool:
def is_in(box1, box2) -> bool:
"""box1是否完全在box2里面."""
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
return (x0_1 >= x0_2 and # box1's left edge is not outside box2's left edge
y0_1 >= y0_2 and # box1's top edge is not outside box2's top edge
x1_1 <= x1_2 and # box1's right edge is not outside box2's right edge
y1_1 <= y1_2) # box1's bottom edge is not outside box2's bottom edge
def _is_part_overlap(box1, box2) -> bool:
"""两个bbox是否有部分重叠,但不完全包含."""
if box1 is None or box2 is None:
return False
return _is_in_or_part_overlap(box1, box2) and not _is_in(box1, box2)
def _left_intersect(left_box, right_box):
"""检查两个box的左边界是否有交集,也就是left_box的右边界是否在right_box的左边界内."""
if left_box is None or right_box is None:
return False
x0_1, y0_1, x1_1, y1_1 = left_box
x0_2, y0_2, x1_2, y1_2 = right_box
return (
x0_1 >= x0_2 # box1's left edge is not outside box2's left edge
and y0_1 >= y0_2 # box1's top edge is not outside box2's top edge
and x1_1 <= x1_2 # box1's right edge is not outside box2's right edge
and y1_1 <= y1_2
) # box1's bottom edge is not outside box2's bottom edge
return x1_1 > x0_2 and x0_1 < x0_2 and (y0_1 <= y0_2 <= y1_1
or y0_1 <= y1_2 <= y1_1)
def bbox_relative_pos(bbox1, bbox2):
"""判断两个矩形框的相对位置关系.
def _right_intersect(left_box, right_box):
"""检查box是否在右侧边界有交集,也就是left_box的左边界是否在right_box的右边界内."""
if left_box is None or right_box is None:
return False
x0_1, y0_1, x1_1, y1_1 = left_box
x0_2, y0_2, x1_2, y1_2 = right_box
return x0_1 < x1_2 and x1_1 > x1_2 and (y0_1 <= y0_2 <= y1_1
or y0_1 <= y1_2 <= y1_1)
def _is_vertical_full_overlap(box1, box2, x_torlence=2):
"""x方向上:要么box1包含box2, 要么box2包含box1。不能部分包含 y方向上:box1和box2有重叠."""
# 解析box的坐标
x11, y11, x12, y12 = box1 # 左上角和右下角的坐标 (x1, y1, x2, y2)
x21, y21, x22, y22 = box2
# 在x轴方向上,box1是否包含box2 或 box2包含box1
contains_in_x = (x11 - x_torlence <= x21 and x12 + x_torlence >= x22) or (
x21 - x_torlence <= x11 and x22 + x_torlence >= x12)
# 在y轴方向上,box1和box2是否有重叠
overlap_in_y = not (y12 < y21 or y11 > y22)
Args:
bbox1: a 4-tuple with the first rectangle's top-left and bottom-right coordinates, in the form (x1, y1, x1b, y1b)
bbox2: a 4-tuple with the second rectangle's top-left and bottom-right coordinates, in the form (x2, y2, x2b, y2b)
return contains_in_x and overlap_in_y
Returns:
A 4-tuple (left, right, bottom, top) describing the position of rectangle 1 relative to rectangle 2,
where left means rectangle 1 is to the left of rectangle 2, right means it is to the right of rectangle 2,
bottom means rectangle 1 is below rectangle 2, and top means rectangle 1 is above rectangle 2
"""
x1, y1, x1b, y1b = bbox1
x2, y2, x2b, y2b = bbox2
left = x2b < x1
right = x1b < x2
bottom = y2b < y1
top = y1b < y2
return left, right, bottom, top
def _is_bottom_full_overlap(box1, box2, y_tolerance=2):
"""检查box1下方和box2的上方有轻微的重叠,轻微程度收到y_tolerance的限制 这个函数和_is_vertical-
full_overlap的区别是,这个函数允许box1和box2在x方向上有轻微的重叠,允许一定的模糊度."""
if box1 is None or box2 is None:
return False
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
tolerance_margin = 2
is_xdir_full_overlap = (
(x0_1 - tolerance_margin <= x0_2 <= x1_1 + tolerance_margin
and x0_1 - tolerance_margin <= x1_2 <= x1_1 + tolerance_margin)
or (x0_2 - tolerance_margin <= x0_1 <= x1_2 + tolerance_margin
and x0_2 - tolerance_margin <= x1_1 <= x1_2 + tolerance_margin))
def bbox_distance(bbox1, bbox2):
"""计算两个矩形框的距离。
return y0_2 < y1_1 and 0 < (y1_1 -
y0_2) < y_tolerance and is_xdir_full_overlap
Args:
bbox1 (tuple): coordinates of the first rectangle, in the form (x1, y1, x2, y2), where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner.
bbox2 (tuple): coordinates of the second rectangle, in the form (x1, y1, x2, y2), where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner.
Returns:
float: the distance between the two rectangles.
"""
def _is_left_overlap(
box1,
box2,
):
"""检查box1的左侧是否和box2有重叠 在Y方向上可以是部分重叠或者是完全重叠。不分box1和box2的上下关系,也就是无论box1在box2下
方还是box2在box1下方,都可以检测到重叠。 X方向上."""
def dist(point1, point2):
return math.sqrt((point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2)
def __overlap_y(Ay1, Ay2, By1, By2):
return max(0, min(Ay2, By2) - max(Ay1, By1))
x1, y1, x1b, y1b = bbox1
x2, y2, x2b, y2b = bbox2
if box1 is None or box2 is None:
return False
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
if top and left:
return dist((x1, y1b), (x2b, y2))
elif left and bottom:
return dist((x1, y1), (x2b, y2b))
elif bottom and right:
return dist((x1b, y1), (x2, y2b))
elif right and top:
return dist((x1b, y1b), (x2, y2))
elif left:
return x1 - x2b
elif right:
return x2 - x1b
elif bottom:
return y1 - y2b
elif top:
return y2 - y1b
return 0.0
y_overlap_len = __overlap_y(y0_1, y1_1, y0_2, y1_2)
ratio_1 = 1.0 * y_overlap_len / (y1_1 - y0_1) if y1_1 - y0_1 != 0 else 0
ratio_2 = 1.0 * y_overlap_len / (y1_2 - y0_2) if y1_2 - y0_2 != 0 else 0
vertical_overlap_cond = ratio_1 >= 0.5 or ratio_2 >= 0.5
# vertical_overlap_cond = y0_1<=y0_2<=y1_1 or y0_1<=y1_2<=y1_1 or y0_2<=y0_1<=y1_2 or y0_2<=y1_1<=y1_2
return x0_1 <= x0_2 <= x1_1 and vertical_overlap_cond
def get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio):
"""通过calculate_overlap_area_2_minbox_area_ratio计算两个bbox重叠的面积占最小面积的box的比例
如果比例大于ratio,则返回小的那个bbox, 否则返回None."""
x1_min, y1_min, x1_max, y1_max = bbox1
x2_min, y2_min, x2_max, y2_max = bbox2
area1 = (x1_max - x1_min) * (y1_max - y1_min)
area2 = (x2_max - x2_min) * (y2_max - y2_min)
overlap_ratio = calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2)
if overlap_ratio > ratio:
if area1 <= area2:
return bbox1
else:
return bbox2
else:
return None
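# Illustrative check with made-up boxes: a 10 x 10 box fully inside a 100 x 100 box gives
# an overlap-to-min-area ratio of 1.0 > 0.8, so the smaller box is returned.
# get_minbox_if_overlap_by_ratio([0, 0, 100, 100], [10, 10, 20, 20], 0.8) -> [10, 10, 20, 20]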
def __is_overlaps_y_exceeds_threshold(bbox1,
bbox2,
overlap_ratio_threshold=0.8):
"""检查两个bbox在y轴上是否有重叠,并且该重叠区域的高度占两个bbox高度更低的那个超过80%"""
_, y0_1, _, y1_1 = bbox1
_, y0_2, _, y1_2 = bbox2
def calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2):
"""计算box1和box2的重叠面积占最小面积的box的比例."""
# Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0])
y_top = max(bbox1[1], bbox2[1])
x_right = min(bbox1[2], bbox2[2])
y_bottom = min(bbox1[3], bbox2[3])
overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2))
height1, height2 = y1_1 - y0_1, y1_2 - y0_2
# max_height = max(height1, height2)
min_height = min(height1, height2)
if x_right < x_left or y_bottom < y_top:
return 0.0
return (overlap / min_height) > overlap_ratio_threshold
# The area of overlap area
intersection_area = (x_right - x_left) * (y_bottom - y_top)
min_box_area = min([(bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]),
(bbox2[3] - bbox2[1]) * (bbox2[2] - bbox2[0])])
if min_box_area == 0:
return 0
else:
return intersection_area / min_box_area
def calculate_iou(bbox1, bbox2):
......@@ -195,27 +148,6 @@ def calculate_iou(bbox1, bbox2):
return iou
def calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2):
"""计算box1和box2的重叠面积占最小面积的box的比例."""
# Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0])
y_top = max(bbox1[1], bbox2[1])
x_right = min(bbox1[2], bbox2[2])
y_bottom = min(bbox1[3], bbox2[3])
if x_right < x_left or y_bottom < y_top:
return 0.0
# The area of overlap area
intersection_area = (x_right - x_left) * (y_bottom - y_top)
min_box_area = min([(bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]),
(bbox2[3] - bbox2[1]) * (bbox2[2] - bbox2[0])])
if min_box_area == 0:
return 0
else:
return intersection_area / min_box_area
def calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2):
"""计算box1和box2的重叠面积占bbox1的比例."""
# Determine the coordinates of the intersection rectangle
......@@ -236,220 +168,6 @@ def calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2):
return intersection_area / bbox1_area
def get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio):
"""通过calculate_overlap_area_2_minbox_area_ratio计算两个bbox重叠的面积占最小面积的box的比例
如果比例大于ratio,则返回小的那个bbox, 否则返回None."""
x1_min, y1_min, x1_max, y1_max = bbox1
x2_min, y2_min, x2_max, y2_max = bbox2
area1 = (x1_max - x1_min) * (y1_max - y1_min)
area2 = (x2_max - x2_min) * (y2_max - y2_min)
overlap_ratio = calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2)
if overlap_ratio > ratio:
if area1 <= area2:
return bbox1
else:
return bbox2
else:
return None
def get_bbox_in_boundary(bboxes: list, boundary: tuple) -> list:
x0, y0, x1, y1 = boundary
new_boxes = [
box for box in bboxes
if box[0] >= x0 and box[1] >= y0 and box[2] <= x1 and box[3] <= y1
]
return new_boxes
def is_vbox_on_side(bbox, width, height, side_threshold=0.2):
"""判断一个bbox是否在pdf页面的边缘."""
x0, x1 = bbox[0], bbox[2]
if x1 <= width * side_threshold or x0 >= width * (1 - side_threshold):
return True
return False
def find_top_nearest_text_bbox(pymu_blocks, obj_bbox):
tolerance_margin = 4
top_boxes = [
box for box in pymu_blocks
if obj_bbox[1] - box['bbox'][3] >= -tolerance_margin
and not _is_in(box['bbox'], obj_bbox)
]
# Then keep the boxes that overlap in the X direction
top_boxes = [
box for box in top_boxes if any([
obj_bbox[0] - tolerance_margin <= box['bbox'][0] <= obj_bbox[2] +
tolerance_margin, obj_bbox[0] -
tolerance_margin <= box['bbox'][2] <= obj_bbox[2] +
tolerance_margin, box['bbox'][0] -
tolerance_margin <= obj_bbox[0] <= box['bbox'][2] +
tolerance_margin, box['bbox'][0] -
tolerance_margin <= obj_bbox[2] <= box['bbox'][2] +
tolerance_margin
])
]
# Then take the one with the largest y1
if len(top_boxes) > 0:
top_boxes.sort(key=lambda x: x['bbox'][3], reverse=True)
return top_boxes[0]
else:
return None
def find_bottom_nearest_text_bbox(pymu_blocks, obj_bbox):
bottom_boxes = [
box for box in pymu_blocks if box['bbox'][1] -
obj_bbox[3] >= -2 and not _is_in(box['bbox'], obj_bbox)
]
# Then keep the boxes that overlap in the X direction
bottom_boxes = [
box for box in bottom_boxes if any([
obj_bbox[0] - 2 <= box['bbox'][0] <= obj_bbox[2] + 2, obj_bbox[0] -
2 <= box['bbox'][2] <= obj_bbox[2] + 2, box['bbox'][0] -
2 <= obj_bbox[0] <= box['bbox'][2] + 2, box['bbox'][0] -
2 <= obj_bbox[2] <= box['bbox'][2] + 2
])
]
# Then take the one with the smallest y0
if len(bottom_boxes) > 0:
bottom_boxes.sort(key=lambda x: x['bbox'][1], reverse=False)
return bottom_boxes[0]
else:
return None
def find_left_nearest_text_bbox(pymu_blocks, obj_bbox):
"""寻找左侧最近的文本block."""
left_boxes = [
box for box in pymu_blocks if obj_bbox[0] -
box['bbox'][2] >= -2 and not _is_in(box['bbox'], obj_bbox)
]
# Then keep the boxes that overlap in the Y direction
left_boxes = [
box for box in left_boxes if any([
obj_bbox[1] - 2 <= box['bbox'][1] <= obj_bbox[3] + 2, obj_bbox[1] -
2 <= box['bbox'][3] <= obj_bbox[3] + 2, box['bbox'][1] -
2 <= obj_bbox[1] <= box['bbox'][3] + 2, box['bbox'][1] -
2 <= obj_bbox[3] <= box['bbox'][3] + 2
])
]
# Then take the one with the largest x1
if len(left_boxes) > 0:
left_boxes.sort(key=lambda x: x['bbox'][2], reverse=True)
return left_boxes[0]
else:
return None
def find_right_nearest_text_bbox(pymu_blocks, obj_bbox):
"""寻找右侧最近的文本block."""
right_boxes = [
box for box in pymu_blocks if box['bbox'][0] -
obj_bbox[2] >= -2 and not _is_in(box['bbox'], obj_bbox)
]
# Then keep the boxes that overlap in the Y direction
right_boxes = [
box for box in right_boxes if any([
obj_bbox[1] - 2 <= box['bbox'][1] <= obj_bbox[3] + 2, obj_bbox[1] -
2 <= box['bbox'][3] <= obj_bbox[3] + 2, box['bbox'][1] -
2 <= obj_bbox[1] <= box['bbox'][3] + 2, box['bbox'][1] -
2 <= obj_bbox[3] <= box['bbox'][3] + 2
])
]
# Then take the one with the smallest x0
if len(right_boxes) > 0:
right_boxes.sort(key=lambda x: x['bbox'][0], reverse=False)
return right_boxes[0]
else:
return None
def bbox_relative_pos(bbox1, bbox2):
"""判断两个矩形框的相对位置关系.
Args:
bbox1: 一个四元组,表示第一个矩形框的左上角和右下角的坐标,格式为(x1, y1, x1b, y1b)
bbox2: 一个四元组,表示第二个矩形框的左上角和右下角的坐标,格式为(x2, y2, x2b, y2b)
Returns:
一个四元组,表示矩形框1相对于矩形框2的位置关系,格式为(left, right, bottom, top)
其中,left表示矩形框1是否在矩形框2的左侧,right表示矩形框1是否在矩形框2的右侧,
bottom表示矩形框1是否在矩形框2的下方,top表示矩形框1是否在矩形框2的上方
"""
x1, y1, x1b, y1b = bbox1
x2, y2, x2b, y2b = bbox2
left = x2b < x1
right = x1b < x2
bottom = y2b < y1
top = y1b < y2
return left, right, bottom, top
def bbox_distance(bbox1, bbox2):
"""计算两个矩形框的距离。
Args:
bbox1 (tuple): 第一个矩形框的坐标,格式为 (x1, y1, x2, y2),其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。
bbox2 (tuple): 第二个矩形框的坐标,格式为 (x1, y1, x2, y2),其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。
Returns:
float: 矩形框之间的距离。
"""
def dist(point1, point2):
return math.sqrt((point1[0] - point2[0])**2 +
(point1[1] - point2[1])**2)
x1, y1, x1b, y1b = bbox1
x2, y2, x2b, y2b = bbox2
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
if top and left:
return dist((x1, y1b), (x2b, y2))
elif left and bottom:
return dist((x1, y1), (x2b, y2b))
elif bottom and right:
return dist((x1b, y1), (x2, y2b))
elif right and top:
return dist((x1b, y1b), (x2, y2))
elif left:
return x1 - x2b
elif right:
return x2 - x1b
elif bottom:
return y1 - y2b
elif top:
return y2 - y1b
return 0.0
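# Illustrative checks with made-up boxes: horizontally separated boxes return the
# horizontal gap, while diagonally separated boxes return the corner-to-corner distance.
# bbox_distance((0, 0, 10, 10), (20, 0, 30, 10)) -> 10.0
# bbox_distance((0, 0, 10, 10), (20, 20, 30, 30)) -> math.sqrt(200) (about 14.14)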
def box_area(bbox):
return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
def get_overlap_area(bbox1, bbox2):
"""计算box1和box2的重叠面积占bbox1的比例."""
# Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0])
y_top = max(bbox1[1], bbox2[1])
x_right = min(bbox1[2], bbox2[2])
y_bottom = min(bbox1[3], bbox2[3])
if x_right < x_left or y_bottom < y_top:
return 0.0
# The area of overlap area
return (x_right - x_left) * (y_bottom - y_top)
def calculate_vertical_projection_overlap_ratio(block1, block2):
"""
Calculate the proportion of the x-axis covered by the vertical projection of two blocks.
......@@ -482,4 +200,4 @@ def calculate_vertical_projection_overlap_ratio(block1, block2):
# Proportion of the x-axis covered by the intersection
# logger.info(f"intersection_length: {intersection_length}, block1_length: {block1_length}")
return intersection_length / block1_length
return intersection_length / block1_length
\ No newline at end of file
"""根据bucket的名字返回对应的s3 AK, SK,endpoint三元组."""
# Copyright (c) Opendatalab. All rights reserved.
import json
import os
import torch
from loguru import logger
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.libs.commons import parse_bucket_key
# Config file name constant
CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'magic-pdf.json')
CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'mineru.json')
def read_config():
......@@ -20,11 +17,12 @@ def read_config():
config_file = os.path.join(home_dir, CONFIG_FILE_NAME)
if not os.path.exists(config_file):
raise FileNotFoundError(f'{config_file} not found')
with open(config_file, 'r', encoding='utf-8') as f:
config = json.load(f)
return config
logger.warning(f'{config_file} not found, using default configuration')
return None
else:
with open(config_file, 'r', encoding='utf-8') as f:
config = json.load(f)
return config
def get_s3_config(bucket_name: str):
......@@ -55,85 +53,73 @@ def get_bucket_name(path):
return bucket
def get_local_models_dir():
config = read_config()
models_dir = config.get('models-dir')
if models_dir is None:
logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
return '/tmp/models'
else:
return models_dir
def get_local_layoutreader_model_dir():
config = read_config()
layoutreader_model_dir = config.get('layoutreader-model-dir')
if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
home_dir = os.path.expanduser('~')
layoutreader_at_modelscope_dir_path = os.path.join(home_dir, '.cache/modelscope/hub/ppaanngggg/layoutreader')
logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
return layoutreader_at_modelscope_dir_path
else:
return layoutreader_model_dir
def parse_bucket_key(s3_full_path: str):
"""
Input: s3://bucket/path/to/my/file.txt
Output: bucket, path/to/my/file.txt
"""
s3_full_path = s3_full_path.strip()
if s3_full_path.startswith("s3://"):
s3_full_path = s3_full_path[5:]
if s3_full_path.startswith("/"):
s3_full_path = s3_full_path[1:]
bucket, key = s3_full_path.split("/", 1)
return bucket, key
def get_device():
config = read_config()
device = config.get('device-mode')
if device is None:
logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
return 'cpu'
device_mode = os.getenv('MINERU_DEVICE_MODE', None)
if device_mode is not None:
return device_mode
else:
return device
if torch.cuda.is_available():
return "cuda"
if torch.backends.mps.is_available():
return "mps"
return "cpu"
def get_table_recog_config():
config = read_config()
table_config = config.get('table-config')
if table_config is None:
logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
return json.loads(f'{{"model": "{MODEL_NAME.RAPID_TABLE}","enable": false, "max_time": 400}}')
table_enable = os.getenv('MINERU_TABLE_ENABLE', None)
if table_enable is not None:
return json.loads(f'{{"enable": {table_enable}}}')
else:
return table_config
# logger.warning(f"not found 'MINERU_TABLE_ENABLE' in environment variable, use 'true' as default.")
return json.loads(f'{{"enable": true}}')
def get_layout_config():
config = read_config()
layout_config = config.get('layout-config')
if layout_config is None:
logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
return json.loads(f'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}')
def get_formula_config():
formula_enable = os.getenv('MINERU_FORMULA_ENABLE', None)
if formula_enable is not None:
return json.loads(f'{{"enable": {formula_enable}}}')
else:
return layout_config
# logger.warning(f"not found 'MINERU_FORMULA_ENABLE' in environment variable, use 'true' as default.")
return json.loads(f'{{"enable": true}}')
def get_formula_config():
def get_latex_delimiter_config():
config = read_config()
formula_config = config.get('formula-config')
if formula_config is None:
logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
return json.loads(f'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}')
latex_delimiter_config = config.get('latex-delimiter-config')
if latex_delimiter_config is None:
# logger.warning(f"'latex-delimiter-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
return None
else:
return formula_config
return latex_delimiter_config
def get_llm_aided_config():
config = read_config()
llm_aided_config = config.get('llm-aided-config')
if llm_aided_config is None:
logger.warning(f"'llm-aided-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
# logger.warning(f"'llm-aided-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
return None
else:
return llm_aided_config
def get_latex_delimiter_config():
config = read_config()
latex_delimiter_config = config.get('latex-delimiter-config')
if latex_delimiter_config is None:
logger.warning(f"'latex-delimiter-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
return None
else:
return latex_delimiter_config
if __name__ == '__main__':
ak, sk, endpoint = get_s3_config('llm-raw')
def get_local_models_dir():
config = read_config()
models_dir = config.get('models-dir')
if models_dir is None:
logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use None as default")
return models_dir
\ No newline at end of file
from loguru import logger
from .pdf_image_tools import cut_image
def cut_image_and_table(span, page_pil_img, page_img_md5, page_id, image_writer, scale=2):
def return_path(path_type):
return f"{path_type}/{page_img_md5}"
span_type = span["type"]
if not check_img_bbox(span["bbox"]) or not image_writer:
span["image_path"] = ""
else:
span["image_path"] = cut_image(
span["bbox"], page_id, page_pil_img, return_path=return_path(span_type), image_writer=image_writer, scale=scale
)
return span
def check_img_bbox(bbox) -> bool:
if any([bbox[0] >= bbox[2], bbox[1] >= bbox[3]]):
logger.warning(f"image_bboxes: 错误的box, {bbox}")
return False
return True
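# Minimal usage sketch (hypothetical span and writer objects): crop a table span out of the
# rendered page image; image_path is set to "" when the bbox is degenerate or no writer is given.
# span = {"type": "table", "bbox": [50, 60, 300, 200]}
# span = cut_image_and_table(span, page_pil_img, page_img_md5, page_id=0, image_writer=writer)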