# Commit 0fc1daac, authored by luopl — "Initial commit" (no parents).
# (Repository-browser header kept here as a comment so the file stays parseable.)
import cv2
from loguru import logger
from tqdm import tqdm
from collections import defaultdict
import numpy as np
from .model_init import AtomModelSingleton
from ...utils.config_reader import get_formula_enable, get_table_enable
from ...utils.model_utils import crop_img, get_res_list_from_layout_res
from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
YOLO_LAYOUT_BASE_BATCH_SIZE = 8
MFD_BASE_BATCH_SIZE = 1
MFR_BASE_BATCH_SIZE = 16
class BatchAnalyze:
def __init__(self, model_manager, batch_ratio: int, formula_enable, table_enable, enable_ocr_det_batch: bool = True):
self.batch_ratio = batch_ratio
self.formula_enable = get_formula_enable(formula_enable)
self.table_enable = get_table_enable(table_enable)
self.model_manager = model_manager
self.enable_ocr_det_batch = enable_ocr_det_batch
def __call__(self, images_with_extra_info: list) -> list:
if len(images_with_extra_info) == 0:
return []
images_layout_res = []
self.model = self.model_manager.get_model(
lang=None,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
atom_model_manager = AtomModelSingleton()
images = [image for image, _, _ in images_with_extra_info]
# doclayout_yolo
layout_images = []
for image_index, image in enumerate(images):
layout_images.append(image)
images_layout_res += self.model.layout_model.batch_predict(
layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
)
if self.formula_enable:
# 公式检测
images_mfd_res = self.model.mfd_model.batch_predict(
images, MFD_BASE_BATCH_SIZE
)
# 公式识别
images_formula_list = self.model.mfr_model.batch_predict(
images_mfd_res,
images,
batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
)
mfr_count = 0
for image_index in range(len(images)):
images_layout_res[image_index] += images_formula_list[image_index]
mfr_count += len(images_formula_list[image_index])
# 清理显存
# clean_vram(self.model.device, vram_threshold=8)
ocr_res_list_all_page = []
table_res_list_all_page = []
for index in range(len(images)):
_, ocr_enable, _lang = images_with_extra_info[index]
layout_res = images_layout_res[index]
pil_img = images[index]
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
get_res_list_from_layout_res(layout_res)
)
ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
'lang':_lang,
'ocr_enable':ocr_enable,
'pil_img':pil_img,
'single_page_mfdetrec_res':single_page_mfdetrec_res,
'layout_res':layout_res,
})
for table_res in table_res_list:
table_img, _ = crop_img(table_res, pil_img)
table_res_list_all_page.append({'table_res':table_res,
'lang':_lang,
'table_img':table_img,
})
# OCR检测处理
if self.enable_ocr_det_batch:
# 批处理模式 - 按语言和分辨率分组
# 收集所有需要OCR检测的裁剪图像
all_cropped_images_info = []
for ocr_res_list_dict in ocr_res_list_all_page:
_lang = ocr_res_list_dict['lang']
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# BGR转换
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
all_cropped_images_info.append((
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
))
# 按语言分组
lang_groups = defaultdict(list)
for crop_info in all_cropped_images_info:
lang = crop_info[5]
lang_groups[lang].append(crop_info)
# 对每种语言按分辨率分组并批处理
for lang, lang_crop_list in lang_groups.items():
if not lang_crop_list:
continue
# logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")
# 获取OCR模型
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
det_db_box_thresh=0.3,
lang=lang
)
# 按分辨率分组并同时完成padding
resolution_groups = defaultdict(list)
for crop_info in lang_crop_list:
cropped_img = crop_info[0]
h, w = cropped_img.shape[:2]
# 使用更大的分组容差,减少分组数量
# 将尺寸标准化到32的倍数
normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数
normalized_w = ((w + 32) // 32) * 32
group_key = (normalized_h, normalized_w)
resolution_groups[group_key].append(crop_info)
# 对每个分辨率组进行批处理
for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
# 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
max_h = max(crop_info[0].shape[0] for crop_info in group_crops)
max_w = max(crop_info[0].shape[1] for crop_info in group_crops)
target_h = ((max_h + 32 - 1) // 32) * 32
target_w = ((max_w + 32 - 1) // 32) * 32
# 对所有图像进行padding到统一尺寸
batch_images = []
for crop_info in group_crops:
img = crop_info[0]
h, w = img.shape[:2]
# 创建目标尺寸的白色背景
padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
# 将原图像粘贴到左上角
padded_img[:h, :w] = img
batch_images.append(padded_img)
# 批处理检测
batch_size = min(len(batch_images), self.batch_ratio * 16) # 增加批处理大小
# logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)
# 处理批处理结果
for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
if dt_boxes is not None and len(dt_boxes) > 0:
# 直接应用原始OCR流程中的关键处理步骤
from mineru.utils.ocr_utils import (
merge_det_boxes, update_det_boxes, sorted_boxes
)
# 1. 排序检测框
if len(dt_boxes) > 0:
dt_boxes_sorted = sorted_boxes(dt_boxes)
else:
dt_boxes_sorted = []
# 2. 合并相邻检测框
if dt_boxes_sorted:
dt_boxes_merged = merge_det_boxes(dt_boxes_sorted)
else:
dt_boxes_merged = []
# 3. 根据公式位置更新检测框(关键步骤!)
if dt_boxes_merged and adjusted_mfdetrec_res:
dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
else:
dt_boxes_final = dt_boxes_merged
# 构造OCR结果格式
ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]
if ocr_res:
ocr_result_list = get_ocr_result_list(
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
)
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
else:
# 原始单张处理模式
for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
# Process each area that requires OCR processing
_lang = ocr_res_list_dict['lang']
# Get OCR results for this language's images
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=_lang
)
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# OCR-det
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
ocr_res = ocr_model.ocr(
new_image, mfd_res=adjusted_mfdetrec_res, rec=False
)[0]
# Integration results
if ocr_res:
ocr_result_list = get_ocr_result_list(
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],new_image, _lang
)
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
# 表格识别 table recognition
if self.table_enable:
for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
_lang = table_res_dict['lang']
table_model = atom_model_manager.get_atom_model(
atom_model_name='table',
lang=_lang,
)
html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict['table_img'])
# 判断是否返回正常
if html_code:
expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
if expected_ending:
table_res_dict['table_res']['html'] = html_code
else:
logger.warning(
'table recognition processing fails, not found expected HTML table end'
)
else:
logger.warning(
'table recognition processing fails, not get html return'
)
# Create dictionaries to store items by language
need_ocr_lists_by_lang = {} # Dict of lists for each language
img_crop_lists_by_lang = {} # Dict of lists for each language
for layout_res in images_layout_res:
for layout_res_item in layout_res:
if layout_res_item['category_id'] in [15]:
if 'np_img' in layout_res_item and 'lang' in layout_res_item:
lang = layout_res_item['lang']
# Initialize lists for this language if not exist
if lang not in need_ocr_lists_by_lang:
need_ocr_lists_by_lang[lang] = []
img_crop_lists_by_lang[lang] = []
# Add to the appropriate language-specific lists
need_ocr_lists_by_lang[lang].append(layout_res_item)
img_crop_lists_by_lang[lang].append(layout_res_item['np_img'])
# Remove the fields after adding to lists
layout_res_item.pop('np_img')
layout_res_item.pop('lang')
if len(img_crop_lists_by_lang) > 0:
# Process OCR by language
total_processed = 0
# Process each language separately
for lang, img_crop_list in img_crop_lists_by_lang.items():
if len(img_crop_list) > 0:
# Get OCR results for this language's images
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
det_db_box_thresh=0.3,
lang=lang
)
ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
# Verify we have matching counts
assert len(ocr_res_list) == len(
need_ocr_lists_by_lang[lang]), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_lists_by_lang[lang])} for lang: {lang}'
# Process OCR results for this language
for index, layout_res_item in enumerate(need_ocr_lists_by_lang[lang]):
ocr_text, ocr_score = ocr_res_list[index]
layout_res_item['text'] = ocr_text
layout_res_item['score'] = float(f"{ocr_score:.3f}")
if ocr_score < OcrConfidence.min_confidence:
layout_res_item['category_id'] = 16
total_processed += len(img_crop_list)
return images_layout_res
import os
import torch
from loguru import logger
from .model_list import AtomicModel
from ...model.layout.doclayout_yolo import DocLayoutYOLOModel
from ...model.mfd.yolo_v8 import YOLOv8MFDModel
from ...model.mfr.unimernet.Unimernet import UnimernetModel
from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
from ...model.table.rapid_table import RapidTableModel
from ...utils.enum_class import ModelPath
from ...utils.models_download_utils import auto_download_and_get_model_root_path
def table_model_init(lang=None):
atom_model_manager = AtomModelSingleton()
ocr_engine = atom_model_manager.get_atom_model(
atom_model_name='ocr',
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
lang=lang
)
table_model = RapidTableModel(ocr_engine)
return table_model
def mfd_model_init(weight, device='cpu'):
if str(device).startswith('npu'):
device = torch.device(device)
mfd_model = YOLOv8MFDModel(weight, device)
return mfd_model
def mfr_model_init(weight_dir, device='cpu'):
mfr_model = UnimernetModel(weight_dir, device)
return mfr_model
def doclayout_yolo_model_init(weight, device='cpu'):
if str(device).startswith('npu'):
device = torch.device(device)
model = DocLayoutYOLOModel(weight, device)
return model
def ocr_model_init(det_db_box_thresh=0.3,
lang=None,
use_dilation=True,
det_db_unclip_ratio=1.8,
):
if lang is not None and lang != '':
model = PytorchPaddleOCR(
det_db_box_thresh=det_db_box_thresh,
lang=lang,
use_dilation=use_dilation,
det_db_unclip_ratio=det_db_unclip_ratio,
)
else:
model = PytorchPaddleOCR(
det_db_box_thresh=det_db_box_thresh,
use_dilation=use_dilation,
det_db_unclip_ratio=det_db_unclip_ratio,
)
return model
class AtomModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_atom_model(self, atom_model_name: str, **kwargs):
lang = kwargs.get('lang', None)
table_model_name = kwargs.get('table_model_name', None)
if atom_model_name in [AtomicModel.OCR]:
key = (atom_model_name, lang)
elif atom_model_name in [AtomicModel.Table]:
key = (atom_model_name, table_model_name, lang)
else:
key = atom_model_name
if key not in self._models:
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
return self._models[key]
def atom_model_init(model_name: str, **kwargs):
atom_model = None
if model_name == AtomicModel.Layout:
atom_model = doclayout_yolo_model_init(
kwargs.get('doclayout_yolo_weights'),
kwargs.get('device')
)
elif model_name == AtomicModel.MFD:
atom_model = mfd_model_init(
kwargs.get('mfd_weights'),
kwargs.get('device')
)
elif model_name == AtomicModel.MFR:
atom_model = mfr_model_init(
kwargs.get('mfr_weight_dir'),
kwargs.get('device')
)
elif model_name == AtomicModel.OCR:
atom_model = ocr_model_init(
kwargs.get('det_db_box_thresh'),
kwargs.get('lang'),
)
elif model_name == AtomicModel.Table:
atom_model = table_model_init(
kwargs.get('lang'),
)
else:
logger.error('model name not allow')
exit(1)
if atom_model is None:
logger.error('model init failed')
exit(1)
else:
return atom_model
class MineruPipelineModel:
def __init__(self, **kwargs):
self.formula_config = kwargs.get('formula_config')
self.apply_formula = self.formula_config.get('enable', True)
self.table_config = kwargs.get('table_config')
self.apply_table = self.table_config.get('enable', True)
self.lang = kwargs.get('lang', None)
self.device = kwargs.get('device', 'cpu')
logger.info(
'DocAnalysis init, this may take some times......'
)
atom_model_manager = AtomModelSingleton()
if self.apply_formula:
# 初始化公式检测模型
self.mfd_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFD,
mfd_weights=str(
os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd), ModelPath.yolo_v8_mfd)
),
device=self.device,
)
# 初始化公式解析模型
mfr_weight_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.unimernet_small), ModelPath.unimernet_small)
self.mfr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFR,
mfr_weight_dir=mfr_weight_dir,
device=self.device,
)
# 初始化layout模型
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
doclayout_yolo_weights=str(
os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
),
device=self.device,
)
# 初始化ocr
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
det_db_box_thresh=0.3,
lang=self.lang
)
# init table model
if self.apply_table:
self.table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Table,
lang=self.lang,
)
logger.info('DocAnalysis init done!')
\ No newline at end of file
# Copyright (c) Opendatalab. All rights reserved.
import time
from loguru import logger
from tqdm import tqdm
from mineru.utils.config_reader import get_device, get_llm_aided_config, get_formula_enable
from mineru.backend.pipeline.model_init import AtomModelSingleton
from mineru.backend.pipeline.para_split import para_split
from mineru.utils.block_pre_proc import prepare_block_bboxes, process_groups
from mineru.utils.block_sort import sort_blocks_by_bbox
from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.enum_class import ContentType
from mineru.utils.llm_aided import llm_aided_title
from mineru.utils.model_utils import clean_memory
from mineru.backend.pipeline.pipeline_magic_model import MagicModel
from mineru.utils.ocr_utils import OcrConfidence
from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans
from mineru.utils.span_pre_proc import remove_outside_spans, remove_overlaps_low_confidence_spans, \
remove_overlaps_min_spans, txt_spans_extract
from mineru.version import __version__
from mineru.utils.hash_utils import str_md5
def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, ocr_enable=False, formula_enabled=True):
scale = image_dict["scale"]
page_pil_img = image_dict["img_pil"]
page_img_md5 = str_md5(image_dict["img_base64"])
page_w, page_h = map(int, page.get_size())
magic_model = MagicModel(page_model_info, scale)
"""从magic_model对象中获取后面会用到的区块信息"""
discarded_blocks = magic_model.get_discarded()
text_blocks = magic_model.get_text_blocks()
title_blocks = magic_model.get_title_blocks()
inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations()
img_groups = magic_model.get_imgs()
table_groups = magic_model.get_tables()
"""对image和table的区块分组"""
img_body_blocks, img_caption_blocks, img_footnote_blocks, maybe_text_image_blocks = process_groups(
img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
)
table_body_blocks, table_caption_blocks, table_footnote_blocks, _ = process_groups(
table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
)
"""获取所有的spans信息"""
spans = magic_model.get_all_spans()
"""某些图可能是文本块,通过简单的规则判断一下"""
if len(maybe_text_image_blocks) > 0:
for block in maybe_text_image_blocks:
span_in_block_list = []
for span in spans:
if span['type'] == 'text' and calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block['bbox']) > 0.7:
span_in_block_list.append(span)
if len(span_in_block_list) > 0:
# span_in_block_list中所有bbox的面积之和
spans_area = sum((span['bbox'][2] - span['bbox'][0]) * (span['bbox'][3] - span['bbox'][1]) for span in span_in_block_list)
# 求ocr_res_area和res的面积的比值
block_area = (block['bbox'][2] - block['bbox'][0]) * (block['bbox'][3] - block['bbox'][1])
if block_area > 0:
ratio = spans_area / block_area
if ratio > 0.25 and ocr_enable:
# 移除block的group_id
block.pop('group_id', None)
# 符合文本图的条件就把块加入到文本块列表中
text_blocks.append(block)
else:
# 如果不符合文本图的条件,就把块加回到图片块列表中
img_body_blocks.append(block)
else:
img_body_blocks.append(block)
"""将所有区块的bbox整理到一起"""
if formula_enabled:
interline_equation_blocks = []
if len(interline_equation_blocks) > 0:
for block in interline_equation_blocks:
spans.append({
"type": ContentType.INTERLINE_EQUATION,
'score': block['score'],
"bbox": block['bbox'],
})
all_bboxes, all_discarded_blocks, footnote_blocks = prepare_block_bboxes(
img_body_blocks, img_caption_blocks, img_footnote_blocks,
table_body_blocks, table_caption_blocks, table_footnote_blocks,
discarded_blocks,
text_blocks,
title_blocks,
interline_equation_blocks,
page_w,
page_h,
)
else:
all_bboxes, all_discarded_blocks, footnote_blocks = prepare_block_bboxes(
img_body_blocks, img_caption_blocks, img_footnote_blocks,
table_body_blocks, table_caption_blocks, table_footnote_blocks,
discarded_blocks,
text_blocks,
title_blocks,
interline_equations,
page_w,
page_h,
)
"""在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span"""
"""顺便删除大水印并保留abandon的span"""
spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks)
"""删除重叠spans中置信度较低的那些"""
spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
"""删除重叠spans中较小的那些"""
spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
"""根据parse_mode,构造spans,主要是文本类的字符填充"""
if ocr_enable:
pass
else:
"""使用新版本的混合ocr方案."""
spans = txt_spans_extract(page, spans, page_pil_img, scale, all_bboxes, all_discarded_blocks)
"""先处理不需要排版的discarded_blocks"""
discarded_block_with_spans, spans = fill_spans_in_blocks(
all_discarded_blocks, spans, 0.4
)
fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
"""如果当前页面没有有效的bbox则跳过"""
if len(all_bboxes) == 0:
return None
"""对image/table/interline_equation截图"""
for span in spans:
if span['type'] in [ContentType.IMAGE, ContentType.TABLE, ContentType.INTERLINE_EQUATION]:
span = cut_image_and_table(
span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale
)
"""span填充进block"""
block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
"""对block进行fix操作"""
fix_blocks = fix_block_spans(block_with_spans)
"""同一行被断开的titile合并"""
# merge_title_blocks(fix_blocks)
"""对block进行排序"""
sorted_blocks = sort_blocks_by_bbox(fix_blocks, page_w, page_h, footnote_blocks)
"""构造page_info"""
page_info = make_page_info_dict(sorted_blocks, page_index, page_w, page_h, fix_discarded_blocks)
return page_info
def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=None, ocr_enable=False, formula_enabled=True):
middle_json = {"pdf_info": [], "_backend":"pipeline", "_version_name": __version__}
formula_enabled = get_formula_enable(formula_enabled)
for page_index, page_model_info in tqdm(enumerate(model_list), total=len(model_list), desc="Processing pages"):
page = pdf_doc[page_index]
image_dict = images_list[page_index]
page_info = page_model_info_to_page_info(
page_model_info, image_dict, page, image_writer, page_index, ocr_enable=ocr_enable, formula_enabled=formula_enabled
)
if page_info is None:
page_w, page_h = map(int, page.get_size())
page_info = make_page_info_dict([], page_index, page_w, page_h, [])
middle_json["pdf_info"].append(page_info)
"""后置ocr处理"""
need_ocr_list = []
img_crop_list = []
text_block_list = []
for page_info in middle_json["pdf_info"]:
for block in page_info['preproc_blocks']:
if block['type'] in ['table', 'image']:
for sub_block in block['blocks']:
if sub_block['type'] in ['image_caption', 'image_footnote', 'table_caption', 'table_footnote']:
text_block_list.append(sub_block)
elif block['type'] in ['text', 'title']:
text_block_list.append(block)
for block in page_info['discarded_blocks']:
text_block_list.append(block)
for block in text_block_list:
for line in block['lines']:
for span in line['spans']:
if 'np_img' in span:
need_ocr_list.append(span)
img_crop_list.append(span['np_img'])
span.pop('np_img')
if len(img_crop_list) > 0:
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang
)
ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
assert len(ocr_res_list) == len(
need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
for index, span in enumerate(need_ocr_list):
ocr_text, ocr_score = ocr_res_list[index]
if ocr_score > OcrConfidence.min_confidence:
span['content'] = ocr_text
span['score'] = float(f"{ocr_score:.3f}")
else:
span['content'] = ''
span['score'] = 0.0
"""分段"""
para_split(middle_json["pdf_info"])
"""llm优化"""
llm_aided_config = get_llm_aided_config()
if llm_aided_config is not None:
"""标题优化"""
title_aided_config = llm_aided_config.get('title_aided', None)
if title_aided_config is not None:
if title_aided_config.get('enable', False):
llm_aided_title_start_time = time.time()
llm_aided_title(middle_json["pdf_info"], title_aided_config)
logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
"""清理内存"""
pdf_doc.close()
clean_memory(get_device())
return middle_json
def make_page_info_dict(blocks, page_id, page_w, page_h, discarded_blocks):
return_dict = {
'preproc_blocks': blocks,
'page_idx': page_id,
'page_size': [page_w, page_h],
'discarded_blocks': discarded_blocks,
}
return return_dict
\ No newline at end of file
class AtomicModel:
Layout = "layout"
MFD = "mfd"
MFR = "mfr"
OCR = "ocr"
Table = "table"
import copy
from loguru import logger
from mineru.utils.enum_class import ContentType, BlockType, SplitFlag
from mineru.utils.language import detect_lang
LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';')
LIST_END_FLAG = ('.', '。', ';', ';')
class ListLineTag:
IS_LIST_START_LINE = 'is_list_start_line'
IS_LIST_END_LINE = 'is_list_end_line'
def __process_blocks(blocks):
# 对所有block预处理
# 1.通过title和interline_equation将block分组
# 2.bbox边界根据line信息重置
result = []
current_group = []
for i in range(len(blocks)):
current_block = blocks[i]
# 如果当前块是 text 类型
if current_block['type'] == 'text':
current_block['bbox_fs'] = copy.deepcopy(current_block['bbox'])
if 'lines' in current_block and len(current_block['lines']) > 0:
current_block['bbox_fs'] = [
min([line['bbox'][0] for line in current_block['lines']]),
min([line['bbox'][1] for line in current_block['lines']]),
max([line['bbox'][2] for line in current_block['lines']]),
max([line['bbox'][3] for line in current_block['lines']]),
]
current_group.append(current_block)
# 检查下一个块是否存在
if i + 1 < len(blocks):
next_block = blocks[i + 1]
# 如果下一个块不是 text 类型且是 title 或 interline_equation 类型
if next_block['type'] in ['title', 'interline_equation']:
result.append(current_group)
current_group = []
# 处理最后一个 group
if current_group:
result.append(current_group)
return result
def __is_list_or_index_block(block):
# 一个block如果是list block 应该同时满足以下特征
# 1.block内有多个line 2.block 内有多个line左侧顶格写 3.block内有多个line 右侧不顶格(狗牙状)
# 1.block内有多个line 2.block 内有多个line左侧顶格写 3.多个line以endflag结尾
# 1.block内有多个line 2.block 内有多个line左侧顶格写 3.block内有多个line 左侧不顶格
# index block 是一种特殊的list block
# 一个block如果是index block 应该同时满足以下特征
# 1.block内有多个line 2.block 内有多个line两侧均顶格写 3.line的开头或者结尾均为数字
if len(block['lines']) >= 2:
first_line = block['lines'][0]
line_height = first_line['bbox'][3] - first_line['bbox'][1]
block_weight = block['bbox_fs'][2] - block['bbox_fs'][0]
block_height = block['bbox_fs'][3] - block['bbox_fs'][1]
page_weight, page_height = block['page_size']
left_close_num = 0
left_not_close_num = 0
right_not_close_num = 0
right_close_num = 0
lines_text_list = []
center_close_num = 0
external_sides_not_close_num = 0
multiple_para_flag = False
last_line = block['lines'][-1]
if page_weight == 0:
block_weight_radio = 0
else:
block_weight_radio = block_weight / page_weight
# logger.info(f"block_weight_radio: {block_weight_radio}")
# 如果首行左边不顶格而右边顶格,末行左边顶格而右边不顶格 (第一行可能可以右边不顶格)
if (
first_line['bbox'][0] - block['bbox_fs'][0] > line_height / 2
and abs(last_line['bbox'][0] - block['bbox_fs'][0]) < line_height / 2
and block['bbox_fs'][2] - last_line['bbox'][2] > line_height
):
multiple_para_flag = True
block_text = ''
for line in block['lines']:
line_text = ''
for span in line['spans']:
span_type = span['type']
if span_type == ContentType.TEXT:
line_text += span['content'].strip()
# 添加所有文本,包括空行,保持与block['lines']长度一致
lines_text_list.append(line_text)
block_text = ''.join(lines_text_list)
block_lang = detect_lang(block_text)
# logger.info(f"block_lang: {block_lang}")
for line in block['lines']:
line_mid_x = (line['bbox'][0] + line['bbox'][2]) / 2
block_mid_x = (block['bbox_fs'][0] + block['bbox_fs'][2]) / 2
if (
line['bbox'][0] - block['bbox_fs'][0] > 0.7 * line_height
and block['bbox_fs'][2] - line['bbox'][2] > 0.7 * line_height
):
external_sides_not_close_num += 1
if abs(line_mid_x - block_mid_x) < line_height / 2:
center_close_num += 1
# 计算line左侧顶格数量是否大于2,是否顶格用abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height/2 来判断
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
left_close_num += 1
elif line['bbox'][0] - block['bbox_fs'][0] > line_height:
left_not_close_num += 1
# 计算右侧是否顶格
if abs(block['bbox_fs'][2] - line['bbox'][2]) < line_height:
right_close_num += 1
else:
# 类中文没有超长单词的情况,可以用统一的阈值
if block_lang in ['zh', 'ja', 'ko']:
closed_area = 0.26 * block_weight
else:
# 右侧不顶格情况下是否有一段距离,拍脑袋用0.3block宽度做阈值
# block宽的阈值可以小些,block窄的阈值要大
if block_weight_radio >= 0.5:
closed_area = 0.26 * block_weight
else:
closed_area = 0.36 * block_weight
if block['bbox_fs'][2] - line['bbox'][2] > closed_area:
right_not_close_num += 1
# 判断lines_text_list中的元素是否有超过80%都以LIST_END_FLAG结尾
line_end_flag = False
# 判断lines_text_list中的元素是否有超过80%都以数字开头或都以数字结尾
line_num_flag = False
num_start_count = 0
num_end_count = 0
flag_end_count = 0
if len(lines_text_list) > 0:
for line_text in lines_text_list:
if len(line_text) > 0:
if line_text[-1] in LIST_END_FLAG:
flag_end_count += 1
if line_text[0].isdigit():
num_start_count += 1
if line_text[-1].isdigit():
num_end_count += 1
if (
num_start_count / len(lines_text_list) >= 0.8
or num_end_count / len(lines_text_list) >= 0.8
):
line_num_flag = True
if flag_end_count / len(lines_text_list) >= 0.8:
line_end_flag = True
# 有的目录右侧不贴边, 目前认为左边或者右边有一边全贴边,且符合数字规则极为index
if (
left_close_num / len(block['lines']) >= 0.8
or right_close_num / len(block['lines']) >= 0.8
) and line_num_flag:
for line in block['lines']:
line[ListLineTag.IS_LIST_START_LINE] = True
return BlockType.INDEX
# 全部line都居中的特殊list识别,每行都需要换行,特征是多行,且大多数行都前后not_close,每line中点x坐标接近
# 补充条件block的长宽比有要求
elif (
external_sides_not_close_num >= 2
and center_close_num == len(block['lines'])
and external_sides_not_close_num / len(block['lines']) >= 0.5
and block_height / block_weight > 0.4
):
for line in block['lines']:
line[ListLineTag.IS_LIST_START_LINE] = True
return BlockType.LIST
elif (
left_close_num >= 2
and (right_not_close_num >= 2 or line_end_flag or left_not_close_num >= 2)
and not multiple_para_flag
# and block_weight_radio > 0.27
):
# 处理一种特殊的没有缩进的list,所有行都贴左边,通过右边的空隙判断是否是item尾
if left_close_num / len(block['lines']) > 0.8:
# 这种是每个item只有一行,且左边都贴边的短item list
if flag_end_count == 0 and right_close_num / len(block['lines']) < 0.5:
for line in block['lines']:
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
line[ListLineTag.IS_LIST_START_LINE] = True
# 这种是大部分line item 都有结束标识符的情况,按结束标识符区分不同item
elif line_end_flag:
for i, line in enumerate(block['lines']):
if (
len(lines_text_list[i]) > 0
and lines_text_list[i][-1] in LIST_END_FLAG
):
line[ListLineTag.IS_LIST_END_LINE] = True
if i + 1 < len(block['lines']):
block['lines'][i + 1][
ListLineTag.IS_LIST_START_LINE
] = True
# line item基本没有结束标识符,而且也没有缩进,按右侧空隙判断哪些是item end
else:
line_start_flag = False
for i, line in enumerate(block['lines']):
if line_start_flag:
line[ListLineTag.IS_LIST_START_LINE] = True
line_start_flag = False
if (
abs(block['bbox_fs'][2] - line['bbox'][2])
> 0.1 * block_weight
):
line[ListLineTag.IS_LIST_END_LINE] = True
line_start_flag = True
# 一种有缩进的特殊有序list,start line 左侧不贴边且以数字开头,end line 以 IS_LIST_END_FLAG 结尾且数量和start line 一致
elif num_start_count >= 2 and num_start_count == flag_end_count:
for i, line in enumerate(block['lines']):
if len(lines_text_list[i]) > 0:
if lines_text_list[i][0].isdigit():
line[ListLineTag.IS_LIST_START_LINE] = True
if lines_text_list[i][-1] in LIST_END_FLAG:
line[ListLineTag.IS_LIST_END_LINE] = True
else:
# 正常有缩进的list处理
for line in block['lines']:
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
line[ListLineTag.IS_LIST_START_LINE] = True
if abs(block['bbox_fs'][2] - line['bbox'][2]) > line_height:
line[ListLineTag.IS_LIST_END_LINE] = True
return BlockType.LIST
else:
return BlockType.TEXT
else:
return BlockType.TEXT
def __merge_2_text_blocks(block1, block2):
if len(block1['lines']) > 0:
first_line = block1['lines'][0]
line_height = first_line['bbox'][3] - first_line['bbox'][1]
block1_weight = block1['bbox'][2] - block1['bbox'][0]
block2_weight = block2['bbox'][2] - block2['bbox'][0]
min_block_weight = min(block1_weight, block2_weight)
if abs(block1['bbox_fs'][0] - first_line['bbox'][0]) < line_height / 2:
last_line = block2['lines'][-1]
if len(last_line['spans']) > 0:
last_span = last_line['spans'][-1]
line_height = last_line['bbox'][3] - last_line['bbox'][1]
if len(first_line['spans']) > 0:
first_span = first_line['spans'][0]
if len(first_span['content']) > 0:
span_start_with_num = first_span['content'][0].isdigit()
span_start_with_big_char = first_span['content'][0].isupper()
if (
# 上一个block的最后一个line的右边界和block的右边界差距不超过line_height
abs(block2['bbox_fs'][2] - last_line['bbox'][2]) < line_height
# 上一个block的最后一个span不是以特定符号结尾
and not last_span['content'].endswith(LINE_STOP_FLAG)
# 两个block宽度差距超过2倍也不合并
and abs(block1_weight - block2_weight) < min_block_weight
# 下一个block的第一个字符是数字
and not span_start_with_num
# 下一个block的第一个字符是大写字母
and not span_start_with_big_char
):
if block1['page_num'] != block2['page_num']:
for line in block1['lines']:
for span in line['spans']:
span[SplitFlag.CROSS_PAGE] = True
block2['lines'].extend(block1['lines'])
block1['lines'] = []
block1[SplitFlag.LINES_DELETED] = True
return block1, block2
def __merge_2_list_blocks(block1, block2):
if block1['page_num'] != block2['page_num']:
for line in block1['lines']:
for span in line['spans']:
span[SplitFlag.CROSS_PAGE] = True
block2['lines'].extend(block1['lines'])
block1['lines'] = []
block1[SplitFlag.LINES_DELETED] = True
return block1, block2
def __is_list_group(text_blocks_group):
# list group的特征是一个group内的所有block都满足以下条件
# 1.每个block都不超过3行 2. 每个block 的左边界都比较接近(逻辑简单点先不加这个规则)
for block in text_blocks_group:
if len(block['lines']) > 3:
return False
return True
def __para_merge_page(blocks):
    """Classify and merge the text blocks of one page in place.

    Each group produced by __process_blocks is first labelled
    (text / list / index), then adjacent blocks of the same kind are merged
    bottom-up: text into text (unless the whole group is a list group),
    list into list, index into index.
    """
    page_text_blocks_groups = __process_blocks(blocks)
    for text_blocks_group in page_text_blocks_groups:
        if len(text_blocks_group) == 0:
            continue
        # Classify every block before any merging, since merging decisions
        # depend on the types of both neighbours.
        for block in text_blocks_group:
            block['type'] = __is_list_or_index_block(block)
            # logger.info(f"{block['type']}:{block}")
        if len(text_blocks_group) > 1:
            # Decide once whether the whole group behaves like a list.
            is_list_group = __is_list_group(text_blocks_group)
            # Walk backwards so each block merges into its predecessor.
            # (range stops at 1, so prev_block at i - 1 always exists; the
            # previous explicit `if i - 1 >= 0` bounds check was redundant.)
            for i in range(len(text_blocks_group) - 1, 0, -1):
                current_block = text_blocks_group[i]
                prev_block = text_blocks_group[i - 1]
                if (
                    # Use BlockType.TEXT for consistency with the other
                    # branches (previously a bare 'text' literal).
                    current_block['type'] == BlockType.TEXT
                    and prev_block['type'] == BlockType.TEXT
                    and not is_list_group
                ):
                    __merge_2_text_blocks(current_block, prev_block)
                elif (
                    current_block['type'] == BlockType.LIST
                    and prev_block['type'] == BlockType.LIST
                ) or (
                    current_block['type'] == BlockType.INDEX
                    and prev_block['type'] == BlockType.INDEX
                ):
                    __merge_2_list_blocks(current_block, prev_block)
def para_split(page_info_list):
    """Run paragraph merging across all pages of a document.

    Deep-copies each page's preproc_blocks (tagging them with page_num and
    page_size), merges paragraphs over the combined list, then writes the
    merged blocks back into each page's 'para_blocks'.
    """
    all_blocks = []
    for page_info in page_info_list:
        page_blocks = copy.deepcopy(page_info['preproc_blocks'])
        for blk in page_blocks:
            blk['page_num'] = page_info['page_idx']
            blk['page_size'] = page_info['page_size']
        all_blocks.extend(page_blocks)

    __para_merge_page(all_blocks)

    # Regroup the (possibly merged) blocks back onto their source pages.
    for page_info in page_info_list:
        page_info['para_blocks'] = [
            blk for blk in all_blocks if blk['page_num'] == page_info['page_idx']
        ]
if __name__ == '__main__':
    # Minimal manual smoke test: run the block grouping on empty input
    # and print the resulting groups.
    input_blocks = []
    # Invoke the grouping function
    groups = __process_blocks(input_blocks)
    for group_index, group in enumerate(groups):
        print(f'Group {group_index}: {group}')
import os
import time
from typing import List, Tuple
import PIL.Image
from loguru import logger
from .model_init import MineruPipelineModel
from mineru.utils.config_reader import get_device
from ...utils.pdf_classify import classify
from ...utils.pdf_image_tools import load_images_from_pdf
from ...utils.model_utils import get_vram, clean_memory
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS to fall back to CPU for unsupported ops
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations' update check
class ModelSingleton:
    """Process-wide singleton that caches pipeline models per configuration.

    Models are keyed by (lang, formula_enable, table_enable) so each distinct
    configuration is initialized exactly once.
    """

    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        # Lazily create the single shared instance on first construction.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        lang=None,
        formula_enable=None,
        table_enable=None,
    ):
        """Return the cached model for this configuration, building it on first use."""
        cache_key = (lang, formula_enable, table_enable)
        try:
            return self._models[cache_key]
        except KeyError:
            model = custom_model_init(
                lang=lang,
                formula_enable=formula_enable,
                table_enable=table_enable,
            )
            self._models[cache_key] = model
            return model
def custom_model_init(
    lang=None,
    formula_enable=True,
    table_enable=True,
):
    """Build a MineruPipelineModel for the given language/feature flags.

    The compute device is read from the MinerU configuration; init time is
    logged for diagnostics.
    """
    model_init_start = time.time()
    # Device comes from the config file / environment.
    device = get_device()
    custom_model = MineruPipelineModel(
        device=device,
        table_config={'enable': table_enable},
        formula_config={'enable': formula_enable},
        lang=lang,
    )
    model_init_cost = time.time() - model_init_start
    logger.info(f'model init cost: {model_init_cost}')
    return custom_model
def doc_analyze(
    pdf_bytes_list,
    lang_list,
    parse_method: str = 'auto',
    formula_enable=True,
    table_enable=True,
):
    """Run layout / formula / table analysis over a list of PDFs.

    Args:
        pdf_bytes_list: raw bytes of each PDF.
        lang_list: OCR language per PDF, parallel to pdf_bytes_list.
        parse_method: 'auto' (classify each PDF), 'ocr' (force OCR) or any
            other value to disable OCR.
        formula_enable: enable formula detection/recognition.
        table_enable: enable table recognition.

    Returns:
        (infer_results, all_image_lists, all_pdf_docs, lang_list,
        ocr_enabled_list) where infer_results groups per-page dicts by PDF.

    Note:
        Raising MINERU_MIN_BATCH_INFERENCE_SIZE (default 100) can improve
        throughput at the cost of higher VRAM usage.
    """
    min_batch_inference_size = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))

    # Flatten every page of every PDF into one work list.
    # Entries are 5-tuples: (pdf_idx, page_idx, PIL image, ocr_enable, lang).
    # (A previous comment claimed 7 fields; only these 5 are stored.)
    all_pages_info = []
    all_image_lists = []
    all_pdf_docs = []
    ocr_enabled_list = []
    for pdf_idx, pdf_bytes in enumerate(pdf_bytes_list):
        # OCR is on when forced, or when 'auto' classification says so.
        # classify() is only invoked in 'auto' mode (short-circuit).
        _ocr_enable = parse_method == 'ocr' or (
            parse_method == 'auto' and classify(pdf_bytes) == 'ocr'
        )
        ocr_enabled_list.append(_ocr_enable)
        _lang = lang_list[pdf_idx]
        images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
        all_image_lists.append(images_list)
        all_pdf_docs.append(pdf_doc)
        for page_idx, img_dict in enumerate(images_list):
            all_pages_info.append(
                (pdf_idx, page_idx, img_dict['img_pil'], _ocr_enable, _lang)
            )

    # Split pages into fixed-size inference batches.
    images_with_extra_info = [(img, ocr, lang) for _, _, img, ocr, lang in all_pages_info]
    batch_size = min_batch_inference_size
    batch_images = [
        images_with_extra_info[i:i + batch_size]
        for i in range(0, len(images_with_extra_info), batch_size)
    ]

    results = []
    processed_images_count = 0
    for index, batch_image in enumerate(batch_images):
        processed_images_count += len(batch_image)
        logger.info(
            f'Batch {index + 1}/{len(batch_images)}: '
            f'{processed_images_count} pages/{len(images_with_extra_info)} pages'
        )
        results.extend(batch_image_analyze(batch_image, formula_enable, table_enable))

    # Regroup the flat per-page results back into their source PDFs.
    infer_results = [[] for _ in pdf_bytes_list]
    for page_info, result in zip(all_pages_info, results):
        pdf_idx, page_idx, pil_img, _, _ = page_info
        page_info_dict = {'page_no': page_idx, 'width': pil_img.width, 'height': pil_img.height}
        infer_results[pdf_idx].append({'layout_dets': result, 'page_info': page_info_dict})
    return infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list
def batch_image_analyze(
    images_with_extra_info: List[Tuple[PIL.Image.Image, bool, str]],
    formula_enable=True,
    table_enable=True):
    """Analyze one batch of page images with the shared pipeline model.

    Picks a batch_ratio from the available GPU/NPU VRAM, runs BatchAnalyze
    over the pages, frees accelerator memory, and returns the per-page
    layout results.
    """
    from .batch_analyze import BatchAnalyze

    model_manager = ModelSingleton()
    device = get_device()
    device_str = str(device)

    if device_str.startswith('npu'):
        try:
            import torch_npu
            if torch_npu.npu.is_available():
                torch_npu.npu.set_compile_mode(jit_compile=False)
        except Exception as e:
            raise RuntimeError(
                "NPU is selected as device, but torch_npu is not available. "
                "Please ensure that the torch_npu package is installed correctly."
            ) from e

    batch_ratio = 1
    if device_str.startswith(('npu', 'cuda')):
        vram = get_vram(device)
        if vram is not None:
            gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
            # Map available VRAM (GB) to a batch multiplier: first matching
            # threshold wins; below 6 GB the default of 1 stands.
            for min_gb, ratio in ((16, 16), (12, 8), (8, 4), (6, 2)):
                if gpu_memory >= min_gb:
                    batch_ratio = ratio
                    break
            logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
        else:
            # VRAM could not be determined; keep the default multiplier.
            logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')

    batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable)
    results = batch_model(images_with_extra_info)
    clean_memory(get_device())
    return results
\ No newline at end of file
from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, is_in
from mineru.utils.enum_class import CategoryId, ContentType
class MagicModel:
    """Cleans one page's raw layout-model detections and exposes typed getters.

    Every getter returns an empty list when the page has no matching elements.
    """

    def __init__(self, page_model_info: dict, scale: float):
        self.__page_model_info = page_model_info
        self.__scale = scale
        # Add bbox info to every detection (de-scale, poly -> bbox).
        self.__fix_axis()
        # Drop detections with very low confidence (score <= 0.05).
        self.__fix_by_remove_low_confidence()
        # Among highly-overlapping pairs (IoU > 0.9), drop the lower score.
        self.__fix_by_remove_high_iou_and_low_confidence()
        self.__fix_footnote()

    def __fix_axis(self):
        """Convert each detection's scaled 8-value poly into an unscaled bbox.

        Detections whose resulting bbox has non-positive width or height are
        removed.
        """
        need_remove_list = []
        layout_dets = self.__page_model_info['layout_dets']
        for layout_det in layout_dets:
            x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
            bbox = [
                int(x0 / self.__scale),
                int(y0 / self.__scale),
                int(x1 / self.__scale),
                int(y1 / self.__scale),
            ]
            layout_det['bbox'] = bbox
            # Remove spans whose height or width collapses to <= 0.
            if bbox[2] - bbox[0] <= 0 or bbox[3] - bbox[1] <= 0:
                need_remove_list.append(layout_det)
        for need_remove in need_remove_list:
            layout_dets.remove(need_remove)

    def __fix_by_remove_low_confidence(self):
        """Remove detections scored at or below 0.05."""
        need_remove_list = []
        layout_dets = self.__page_model_info['layout_dets']
        for layout_det in layout_dets:
            if layout_det['score'] <= 0.05:
                need_remove_list.append(layout_det)
        for need_remove in need_remove_list:
            layout_dets.remove(need_remove)

    def __fix_by_remove_high_iou_and_low_confidence(self):
        """For pairs of layout detections with IoU > 0.9, drop the lower score.

        Only applies to layout categories 0-9; other categories are ignored.
        """
        need_remove_list = []
        layout_dets = self.__page_model_info['layout_dets']
        layout_categories = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        for layout_det1 in layout_dets:
            for layout_det2 in layout_dets:
                if layout_det1 == layout_det2:
                    continue
                if (
                    layout_det1['category_id'] in layout_categories
                    and layout_det2['category_id'] in layout_categories
                ):
                    if (
                        calculate_iou(layout_det1['bbox'], layout_det2['bbox'])
                        > 0.9
                    ):
                        if layout_det1['score'] < layout_det2['score']:
                            layout_det_need_remove = layout_det1
                        else:
                            layout_det_need_remove = layout_det2
                        if layout_det_need_remove not in need_remove_list:
                            need_remove_list.append(layout_det_need_remove)
        for need_remove in need_remove_list:
            layout_dets.remove(need_remove)

    def __fix_footnote(self):
        """Re-label footnotes that are closer to a figure than to any table.

        Category ids: 3 = figure, 5 = table, 7 = footnote.
        """
        footnotes = []
        figures = []
        tables = []
        for obj in self.__page_model_info['layout_dets']:
            if obj['category_id'] == 7:
                footnotes.append(obj)
            elif obj['category_id'] == 3:
                figures.append(obj)
            elif obj['category_id'] == 5:
                tables.append(obj)
        # Nothing to do without both footnotes and figures.
        # BUGFIX: this was `continue`, which is a SyntaxError outside a loop;
        # `return` is the intended early exit.
        if len(footnotes) * len(figures) == 0:
            return
        dis_figure_footnote = {}
        dis_table_footnote = {}

        for i in range(len(footnotes)):
            for j in range(len(figures)):
                # Skip pairs that are diagonal neighbours (more than one
                # relative-position flag set).
                pos_flag_count = sum(
                    list(
                        map(
                            lambda x: 1 if x else 0,
                            bbox_relative_pos(
                                footnotes[i]['bbox'], figures[j]['bbox']
                            ),
                        )
                    )
                )
                if pos_flag_count > 1:
                    continue
                dis_figure_footnote[i] = min(
                    self._bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
                    dis_figure_footnote.get(i, float('inf')),
                )
        for i in range(len(footnotes)):
            for j in range(len(tables)):
                pos_flag_count = sum(
                    list(
                        map(
                            lambda x: 1 if x else 0,
                            bbox_relative_pos(
                                footnotes[i]['bbox'], tables[j]['bbox']
                            ),
                        )
                    )
                )
                if pos_flag_count > 1:
                    continue
                dis_table_footnote[i] = min(
                    self._bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
                    dis_table_footnote.get(i, float('inf')),
                )
        for i in range(len(footnotes)):
            if i not in dis_figure_footnote:
                continue
            if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
                footnotes[i]['category_id'] = CategoryId.ImageFootnote

    def _bbox_distance(self, bbox1, bbox2):
        """Distance between two boxes, or inf when they cannot be related.

        Returns inf when the boxes are diagonal neighbours, or when bbox2's
        extent along the shared axis exceeds bbox1's by more than 30%.
        """
        left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
        flags = [left, right, bottom, top]
        count = sum([1 if v else 0 for v in flags])
        if count > 1:
            return float('inf')
        # Compare extents along the axis perpendicular to their adjacency.
        if left or right:
            l1 = bbox1[3] - bbox1[1]
            l2 = bbox2[3] - bbox2[1]
        else:
            l1 = bbox1[2] - bbox1[0]
            l2 = bbox2[2] - bbox2[0]
        if l2 > l1 and (l2 - l1) / l1 > 0.3:
            return float('inf')
        return bbox_distance(bbox1, bbox2)

    def __reduct_overlap(self, bboxes):
        """Drop any box fully contained inside another box of the same list."""
        N = len(bboxes)
        keep = [True] * N
        for i in range(N):
            for j in range(N):
                if i == j:
                    continue
                if is_in(bboxes[i]['bbox'], bboxes[j]['bbox']):
                    keep[i] = False
        return [bboxes[i] for i in range(N) if keep[i]]

    def __tie_up_category_by_distance_v3(
        self,
        subject_category_id: int,
        object_category_id: int,
    ):
        """Pair objects (e.g. captions) with their nearest subject (e.g. figure).

        Greedy nearest-first matching in reading order; leftover objects are
        attached to their closest subject, and subjects with no objects get an
        empty list. Returns a list of {'sub_bbox', 'obj_bboxes', 'sub_idx'}.
        """
        subjects = self.__reduct_overlap(
            list(
                map(
                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
                    filter(
                        lambda x: x['category_id'] == subject_category_id,
                        self.__page_model_info['layout_dets'],
                    ),
                )
            )
        )
        objects = self.__reduct_overlap(
            list(
                map(
                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
                    filter(
                        lambda x: x['category_id'] == object_category_id,
                        self.__page_model_info['layout_dets'],
                    ),
                )
            )
        )

        ret = []
        N, M = len(subjects), len(objects)
        # Sort both sets by distance from the page origin (reading order proxy).
        subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
        objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)

        # Object indices are offset so both kinds share one index space.
        OBJ_IDX_OFFSET = 10000
        SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1

        all_boxes_with_idx = [
            (i, SUB_BIT_KIND, sub['bbox'][0], sub['bbox'][1])
            for i, sub in enumerate(subjects)
        ] + [
            (i + OBJ_IDX_OFFSET, OBJ_BIT_KIND, obj['bbox'][0], obj['bbox'][1])
            for i, obj in enumerate(objects)
        ]
        seen_idx = set()
        seen_sub_idx = set()

        while N > len(seen_sub_idx):
            candidates = []
            for idx, kind, x0, y0 in all_boxes_with_idx:
                if idx in seen_idx:
                    continue
                candidates.append((idx, kind, x0, y0))
            if len(candidates) == 0:
                break
            # Pick the box closest to the top-left of the remaining boxes,
            # then re-sort the rest by distance to that box.
            left_x = min([v[2] for v in candidates])
            top_y = min([v[3] for v in candidates])
            candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)
            fst_idx, fst_kind, left_x, top_y = candidates[0]
            candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)
            nxt = None
            for i in range(1, len(candidates)):
                # Nearest box of the *other* kind.
                if candidates[i][1] ^ fst_kind == 1:
                    nxt = candidates[i]
                    break
            if nxt is None:
                break
            if fst_kind == SUB_BIT_KIND:
                sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
            else:
                sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET

            # Reject the pair when another unmatched subject is much closer
            # (3x) to this object.
            pair_dis = bbox_distance(subjects[sub_idx]['bbox'], objects[obj_idx]['bbox'])
            nearest_dis = float('inf')
            for i in range(N):
                if i in seen_idx or i == sub_idx:
                    continue
                nearest_dis = min(nearest_dis, bbox_distance(subjects[i]['bbox'], objects[obj_idx]['bbox']))
            if pair_dis >= 3 * nearest_dis:
                seen_idx.add(sub_idx)
                continue

            seen_idx.add(sub_idx)
            seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
            seen_sub_idx.add(sub_idx)

            ret.append(
                {
                    'sub_bbox': {
                        'bbox': subjects[sub_idx]['bbox'],
                        'score': subjects[sub_idx]['score'],
                    },
                    'obj_bboxes': [
                        {'score': objects[obj_idx]['score'], 'bbox': objects[obj_idx]['bbox']}
                    ],
                    'sub_idx': sub_idx,
                }
            )

        # Attach every leftover object to its nearest subject.
        for i in range(len(objects)):
            j = i + OBJ_IDX_OFFSET
            if j in seen_idx:
                continue
            seen_idx.add(j)
            nearest_dis, nearest_sub_idx = float('inf'), -1
            for k in range(len(subjects)):
                dis = bbox_distance(objects[i]['bbox'], subjects[k]['bbox'])
                if dis < nearest_dis:
                    nearest_dis = dis
                    nearest_sub_idx = k
            for k in range(len(subjects)):
                if k != nearest_sub_idx:
                    continue
                if k in seen_sub_idx:
                    # Subject already has a record: extend its object list.
                    for kk in range(len(ret)):
                        if ret[kk]['sub_idx'] == k:
                            ret[kk]['obj_bboxes'].append({'score': objects[i]['score'], 'bbox': objects[i]['bbox']})
                            break
                else:
                    ret.append(
                        {
                            'sub_bbox': {
                                'bbox': subjects[k]['bbox'],
                                'score': subjects[k]['score'],
                            },
                            'obj_bboxes': [
                                {'score': objects[i]['score'], 'bbox': objects[i]['bbox']}
                            ],
                            'sub_idx': k,
                        }
                    )
                    seen_sub_idx.add(k)
                    seen_idx.add(k)

        # Subjects that never received an object still get a record.
        for i in range(len(subjects)):
            if i in seen_sub_idx:
                continue
            ret.append(
                {
                    'sub_bbox': {
                        'bbox': subjects[i]['bbox'],
                        'score': subjects[i]['score'],
                    },
                    'obj_bboxes': [],
                    'sub_idx': i,
                }
            )

        return ret

    def get_imgs(self):
        """Return figures with their caption and footnote lists."""
        with_captions = self.__tie_up_category_by_distance_v3(3, 4)
        with_footnotes = self.__tie_up_category_by_distance_v3(
            3, CategoryId.ImageFootnote
        )
        ret = []
        for v in with_captions:
            record = {
                'image_body': v['sub_bbox'],
                'image_caption_list': v['obj_bboxes'],
            }
            filter_idx = v['sub_idx']
            d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
            record['image_footnote_list'] = d['obj_bboxes']
            ret.append(record)
        return ret

    def get_tables(self) -> list:
        """Return tables with their caption and footnote lists."""
        with_captions = self.__tie_up_category_by_distance_v3(5, 6)
        with_footnotes = self.__tie_up_category_by_distance_v3(5, 7)
        ret = []
        for v in with_captions:
            record = {
                'table_body': v['sub_bbox'],
                'table_caption_list': v['obj_bboxes'],
            }
            filter_idx = v['sub_idx']
            d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
            record['table_footnote_list'] = d['obj_bboxes']
            ret.append(record)
        return ret

    def get_equations(self) -> tuple[list, list, list]:
        """Return (inline, interline, interline-layout) equation blocks.

        Inline and interline equations carry recognized latex; the layout
        variant has coordinates only.
        """
        inline_equations = self.__get_blocks_by_type(
            CategoryId.InlineEquation, ['latex']
        )
        interline_equations = self.__get_blocks_by_type(
            CategoryId.InterlineEquation_YOLO, ['latex']
        )
        interline_equations_blocks = self.__get_blocks_by_type(
            CategoryId.InterlineEquation_Layout
        )
        return inline_equations, interline_equations, interline_equations_blocks

    def get_discarded(self) -> list:
        """Return discarded/abandoned regions (coordinates only)."""
        return self.__get_blocks_by_type(CategoryId.Abandon)

    def get_text_blocks(self) -> list:
        """Return plain text regions (coordinates only, no text content)."""
        return self.__get_blocks_by_type(CategoryId.Text)

    def get_title_blocks(self) -> list:
        """Return title regions (coordinates only, no text content)."""
        return self.__get_blocks_by_type(CategoryId.Title)

    def get_all_spans(self) -> list:
        """Return all span-like detections, de-duplicated.

        Span categories: 3 image, 5 table, 13 inline equation,
        14 interline equation, 15 OCR text.
        """
        def remove_duplicate_spans(spans):
            new_spans = []
            for span in spans:
                if not any(span == existing_span for existing_span in new_spans):
                    new_spans.append(span)
            return new_spans

        all_spans = []
        layout_dets = self.__page_model_info['layout_dets']
        allow_category_id_list = [3, 5, 13, 14, 15]
        for layout_det in layout_dets:
            category_id = layout_det['category_id']
            if category_id in allow_category_id_list:
                span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
                if category_id == 3:
                    span['type'] = ContentType.IMAGE
                elif category_id == 5:
                    # Attach the table model's output when present.
                    latex = layout_det.get('latex', None)
                    html = layout_det.get('html', None)
                    if latex:
                        span['latex'] = latex
                    elif html:
                        span['html'] = html
                    span['type'] = ContentType.TABLE
                elif category_id == 13:
                    span['content'] = layout_det['latex']
                    span['type'] = ContentType.INLINE_EQUATION
                elif category_id == 14:
                    span['content'] = layout_det['latex']
                    span['type'] = ContentType.INTERLINE_EQUATION
                elif category_id == 15:
                    span['content'] = layout_det['text']
                    span['type'] = ContentType.TEXT
                all_spans.append(span)
        return remove_duplicate_spans(all_spans)

    def __get_blocks_by_type(
        self, category_type: int, extra_col=None
    ) -> list:
        """Collect detections of one category as {'bbox', 'score', *extra_col}."""
        if extra_col is None:
            extra_col = []
        blocks = []
        layout_dets = self.__page_model_info.get('layout_dets', [])
        for item in layout_dets:
            category_id = item.get('category_id', -1)
            bbox = item.get('bbox', None)
            if category_id == category_type:
                block = {
                    'bbox': bbox,
                    'score': item.get('score'),
                }
                for col in extra_col:
                    block[col] = item.get(col, None)
                blocks.append(block)
        return blocks
import re
from loguru import logger
from mineru.utils.config_reader import get_latex_delimiter_config
from mineru.backend.pipeline.para_split import ListLineTag
from mineru.utils.enum_class import BlockType, ContentType, MakeMode
from mineru.utils.language import detect_lang
def __is_hyphen_at_line_end(line):
    """Check if a line ends with one or more letters followed by a hyphen.

    Args:
        line (str): The line of text to check.

    Returns:
        bool: True if the line ends with one or more letters followed by a
        hyphen (optionally trailed by whitespace), False otherwise.
    """
    match = re.search(r'[A-Za-z]+-\s*$', line)
    return match is not None
def make_blocks_to_markdown(paras_of_layout,
                            mode,
                            img_buket_path='',
                            ):
    """Render one page's paragraph blocks into a list of markdown strings.

    Args:
        paras_of_layout: paragraph blocks of a single page.
        mode: MakeMode.MM_MD (markdown with images/tables) or MakeMode.NLP_MD
            (text-only; image/table blocks are skipped).
        img_buket_path: prefix prepended to image paths in markdown links.

    Returns:
        list: one markdown string per non-empty paragraph.
    """
    page_markdown = []
    for para_block in paras_of_layout:
        para_text = ''
        para_type = para_block['type']
        if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
            para_text = merge_para_with_text(para_block)
        elif para_type == BlockType.TITLE:
            title_level = get_title_level(para_block)
            para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
        elif para_type == BlockType.INTERLINE_EQUATION:
            if len(para_block['lines']) == 0 or len(para_block['lines'][0]['spans']) == 0:
                continue
            if para_block['lines'][0]['spans'][0].get('content', ''):
                para_text = merge_para_with_text(para_block)
            else:
                # No recognized latex: embed the equation image instead.
                para_text += f"![]({img_buket_path}/{para_block['lines'][0]['spans'][0]['image_path']})"
        elif para_type == BlockType.IMAGE:
            if mode == MakeMode.NLP_MD:
                continue
            elif mode == MakeMode.MM_MD:
                # Check whether this image block carries a footnote.
                has_image_footnote = any(block['type'] == BlockType.IMAGE_FOOTNOTE for block in para_block['blocks'])
                # With a footnote, emit caption -> body -> footnote.
                if has_image_footnote:
                    for block in para_block['blocks']:  # 1st: image captions
                        if block['type'] == BlockType.IMAGE_CAPTION:
                            para_text += merge_para_with_text(block) + ' \n'
                    for block in para_block['blocks']:  # 2nd: image body
                        if block['type'] == BlockType.IMAGE_BODY:
                            for line in block['lines']:
                                for span in line['spans']:
                                    if span['type'] == ContentType.IMAGE:
                                        if span.get('image_path', ''):
                                            para_text += f"![]({img_buket_path}/{span['image_path']})"
                    for block in para_block['blocks']:  # 3rd: image footnotes
                        if block['type'] == BlockType.IMAGE_FOOTNOTE:
                            para_text += ' \n' + merge_para_with_text(block)
                else:
                    # Without a footnote, emit body -> caption.
                    for block in para_block['blocks']:  # 1st: image body
                        if block['type'] == BlockType.IMAGE_BODY:
                            for line in block['lines']:
                                for span in line['spans']:
                                    if span['type'] == ContentType.IMAGE:
                                        if span.get('image_path', ''):
                                            para_text += f"![]({img_buket_path}/{span['image_path']})"
                    for block in para_block['blocks']:  # 2nd: image captions
                        if block['type'] == BlockType.IMAGE_CAPTION:
                            para_text += ' \n' + merge_para_with_text(block)
        elif para_type == BlockType.TABLE:
            if mode == MakeMode.NLP_MD:
                continue
            elif mode == MakeMode.MM_MD:
                for block in para_block['blocks']:  # 1st: table captions
                    if block['type'] == BlockType.TABLE_CAPTION:
                        para_text += merge_para_with_text(block) + ' \n'
                for block in para_block['blocks']:  # 2nd: table body
                    if block['type'] == BlockType.TABLE_BODY:
                        for line in block['lines']:
                            for span in line['spans']:
                                if span['type'] == ContentType.TABLE:
                                    # if processed by table model
                                    if span.get('html', ''):
                                        para_text += f"\n{span['html']}\n"
                                    elif span.get('image_path', ''):
                                        para_text += f"![]({img_buket_path}/{span['image_path']})"
                for block in para_block['blocks']:  # 3rd: table footnotes
                    if block['type'] == BlockType.TABLE_FOOTNOTE:
                        para_text += '\n' + merge_para_with_text(block) + ' '

        if para_text.strip() == '':
            continue
        else:
            # page_markdown.append(para_text.strip() + ' ')
            page_markdown.append(para_text.strip())

    return page_markdown
def full_to_half(text: str) -> str:
"""Convert full-width characters to half-width characters using code point manipulation.
Args:
text: String containing full-width characters
Returns:
String with full-width characters converted to half-width
"""
result = []
for char in text:
code = ord(char)
# Full-width letters and numbers (FF21-FF3A for A-Z, FF41-FF5A for a-z, FF10-FF19 for 0-9)
if (0xFF21 <= code <= 0xFF3A) or (0xFF41 <= code <= 0xFF5A) or (0xFF10 <= code <= 0xFF19):
result.append(chr(code - 0xFEE0)) # Shift to ASCII range
else:
result.append(char)
return ''.join(result)
# Resolve the latex delimiters once at import time.
latex_delimiters_config = get_latex_delimiter_config()
# Fallback delimiters used when the config provides none.
default_delimiters = {
    'display': {'left': '$$', 'right': '$$'},
    'inline': {'left': '$', 'right': '$'}
}
delimiters = latex_delimiters_config if latex_delimiters_config else default_delimiters
# Module-level delimiter strings consumed by merge_para_with_text.
display_left_delimiter = delimiters['display']['left']
display_right_delimiter = delimiters['display']['right']
inline_left_delimiter = delimiters['inline']['left']
inline_right_delimiter = delimiters['inline']['right']
def merge_para_with_text(para_block):
    """Concatenate a paragraph block's spans into a single text string.

    Applies full-width -> half-width normalization, language-aware spacing
    (no spaces between CJK spans), latex delimiter wrapping for equations,
    list-line soft breaks, and end-of-line hyphen removal for western text.
    """
    # First pass: normalize text spans and accumulate raw text so the
    # paragraph's dominant language can be detected.
    block_text = ''
    for line in para_block['lines']:
        for span in line['spans']:
            if span['type'] in [ContentType.TEXT]:
                span['content'] = full_to_half(span['content'])
                block_text += span['content']
    block_lang = detect_lang(block_text)

    para_text = ''
    for i, line in enumerate(para_block['lines']):
        # A new list item (after the first line) gets a markdown soft break.
        if i >= 1 and line.get(ListLineTag.IS_LIST_START_LINE, False):
            para_text += ' \n'
        for j, span in enumerate(line['spans']):

            span_type = span['type']
            content = ''
            if span_type == ContentType.TEXT:
                content = escape_special_markdown_char(span['content'])
            elif span_type == ContentType.INLINE_EQUATION:
                content = f"{inline_left_delimiter}{span['content']}{inline_right_delimiter}"
            elif span_type == ContentType.INTERLINE_EQUATION:
                content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"

            content = content.strip()
            if content:
                langs = ['zh', 'ja', 'ko']
                # logger.info(f'block_lang: {block_lang}, content: {content}')
                if block_lang in langs:  # In CJK context, line breaks need no space separator — except after an inline equation
                    if j == len(line['spans']) - 1 and span_type not in [ContentType.INLINE_EQUATION]:
                        para_text += content
                    else:
                        para_text += f'{content} '
                else:
                    if span_type in [ContentType.TEXT, ContentType.INLINE_EQUATION]:
                        # If the last span of the line ends with a hyphen, drop the hyphen and add no trailing space.
                        if j == len(line['spans'])-1 and span_type == ContentType.TEXT and __is_hyphen_at_line_end(content):
                            para_text += content[:-1]
                        else:  # In western text, contents are separated by spaces
                            para_text += f'{content} '
                    elif span_type == ContentType.INTERLINE_EQUATION:
                        para_text += content
                    else:
                        continue
    return para_text
def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
    """Convert one paragraph block into a content-list dict.

    Args:
        para_block: a single paragraph block (text/title/equation/image/table).
        img_buket_path: prefix prepended to stored image paths.
        page_idx: page index recorded in the output under 'page_idx'.

    Returns:
        dict or None: the content entry, or None for an interline equation
        block with no lines/spans.
    """
    para_type = para_block['type']
    para_content = {}
    if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
        para_content = {
            'type': 'text',
            'text': merge_para_with_text(para_block),
        }
    elif para_type == BlockType.TITLE:
        para_content = {
            'type': 'text',
            'text': merge_para_with_text(para_block),
        }
        # Only attach a text_level for valid (non-zero) title levels.
        title_level = get_title_level(para_block)
        if title_level != 0:
            para_content['text_level'] = title_level
    elif para_type == BlockType.INTERLINE_EQUATION:
        if len(para_block['lines']) == 0 or len(para_block['lines'][0]['spans']) == 0:
            return None
        para_content = {
            'type': 'equation',
            'img_path': f"{img_buket_path}/{para_block['lines'][0]['spans'][0].get('image_path', '')}",
        }
        # Include latex text only when recognition produced content.
        if para_block['lines'][0]['spans'][0].get('content', ''):
            para_content['text'] = merge_para_with_text(para_block)
            para_content['text_format'] = 'latex'
    elif para_type == BlockType.IMAGE:
        para_content = {'type': 'image', 'img_path': '', 'img_caption': [], 'img_footnote': []}
        for block in para_block['blocks']:
            if block['type'] == BlockType.IMAGE_BODY:
                for line in block['lines']:
                    for span in line['spans']:
                        if span['type'] == ContentType.IMAGE:
                            if span.get('image_path', ''):
                                para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
            if block['type'] == BlockType.IMAGE_CAPTION:
                para_content['img_caption'].append(merge_para_with_text(block))
            if block['type'] == BlockType.IMAGE_FOOTNOTE:
                para_content['img_footnote'].append(merge_para_with_text(block))
    elif para_type == BlockType.TABLE:
        para_content = {'type': 'table', 'img_path': '', 'table_caption': [], 'table_footnote': []}
        for block in para_block['blocks']:
            if block['type'] == BlockType.TABLE_BODY:
                for line in block['lines']:
                    for span in line['spans']:
                        if span['type'] == ContentType.TABLE:
                            # Prefer latex from the table model, else html.
                            if span.get('latex', ''):
                                para_content['table_body'] = f"{span['latex']}"
                            elif span.get('html', ''):
                                para_content['table_body'] = f"{span['html']}"
                            if span.get('image_path', ''):
                                para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
            if block['type'] == BlockType.TABLE_CAPTION:
                para_content['table_caption'].append(merge_para_with_text(block))
            if block['type'] == BlockType.TABLE_FOOTNOTE:
                para_content['table_footnote'].append(merge_para_with_text(block))

    para_content['page_idx'] = page_idx

    return para_content
def union_make(pdf_info_dict: list,
               make_mode: str,
               img_buket_path: str = '',
               ):
    """Assemble per-page para_blocks into the requested output format.

    Returns a single markdown string for MakeMode.MM_MD / MakeMode.NLP_MD,
    a list of content dicts for MakeMode.CONTENT_LIST, and None (with an
    error log) for any other mode.
    """
    markdown_modes = (MakeMode.MM_MD, MakeMode.NLP_MD)
    output_content = []
    for page_info in pdf_info_dict:
        paras_of_layout = page_info.get('para_blocks')
        page_idx = page_info.get('page_idx')
        if not paras_of_layout:
            continue
        if make_mode in markdown_modes:
            output_content.extend(
                make_blocks_to_markdown(paras_of_layout, make_mode, img_buket_path)
            )
        elif make_mode == MakeMode.CONTENT_LIST:
            for para_block in paras_of_layout:
                para_content = make_blocks_to_content_list(para_block, img_buket_path, page_idx)
                if para_content:
                    output_content.append(para_content)

    if make_mode in markdown_modes:
        return '\n\n'.join(output_content)
    if make_mode == MakeMode.CONTENT_LIST:
        return output_content
    logger.error(f"Unsupported make mode: {make_mode}")
    return None
def get_title_level(block):
    """Normalize a title block's 'level' to the range 0..4.

    Missing levels default to 1; levels above 4 are capped at 4; levels
    below 1 collapse to 0 (meaning: no heading level).
    """
    level = block.get('level', 1)
    if level < 1:
        return 0
    return min(level, 4)
def escape_special_markdown_char(content):
    """Escape characters that have special meaning in markdown (*, `, ~, $)."""
    # Single-pass substitution: prefix each special char with a backslash.
    return re.sub(r'([*`~$])', r'\\\1', content)
\ No newline at end of file
# Copyright (c) Opendatalab. All rights reserved.
import asyncio
from abc import ABC, abstractmethod
from typing import AsyncIterable, Iterable, List, Optional, Union
# Default system prompt shared by all predictors.
DEFAULT_SYSTEM_PROMPT = (
    "A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers."
)
# Default user prompt used when a caller provides none.
DEFAULT_USER_PROMPT = "Document Parsing:"
# Default sampling/decoding parameters; each predictor call may override them.
DEFAULT_TEMPERATURE = 0.0
DEFAULT_TOP_P = 0.8
DEFAULT_TOP_K = 20
DEFAULT_REPETITION_PENALTY = 1.0
DEFAULT_PRESENCE_PENALTY = 0.0
DEFAULT_NO_REPEAT_NGRAM_SIZE = 100
DEFAULT_MAX_NEW_TOKENS = 16384
class BasePredictor(ABC):
system_prompt = DEFAULT_SYSTEM_PROMPT
def __init__(
self,
temperature: float = DEFAULT_TEMPERATURE,
top_p: float = DEFAULT_TOP_P,
top_k: int = DEFAULT_TOP_K,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
) -> None:
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.repetition_penalty = repetition_penalty
self.presence_penalty = presence_penalty
self.no_repeat_ngram_size = no_repeat_ngram_size
self.max_new_tokens = max_new_tokens
@abstractmethod
def predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> str: ...
@abstractmethod
def batch_predict(
self,
images: List[str] | List[bytes],
prompts: Union[List[str], str] = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> List[str]: ...
@abstractmethod
def stream_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> Iterable[str]: ...
async def aio_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> str:
return await asyncio.to_thread(
self.predict,
image,
prompt,
temperature,
top_p,
top_k,
repetition_penalty,
presence_penalty,
no_repeat_ngram_size,
max_new_tokens,
)
async def aio_batch_predict(
self,
images: List[str] | List[bytes],
prompts: Union[List[str], str] = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> List[str]:
return await asyncio.to_thread(
self.batch_predict,
images,
prompts,
temperature,
top_p,
top_k,
repetition_penalty,
presence_penalty,
no_repeat_ngram_size,
max_new_tokens,
)
async def aio_stream_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> AsyncIterable[str]:
queue = asyncio.Queue()
loop = asyncio.get_running_loop()
def synced_predict():
for chunk in self.stream_predict(
image=image,
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
):
asyncio.run_coroutine_threadsafe(queue.put(chunk), loop)
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
asyncio.create_task(
asyncio.to_thread(synced_predict),
)
while True:
chunk = await queue.get()
if chunk is None:
return
assert isinstance(chunk, str)
yield chunk
def build_prompt(self, prompt: str) -> str:
if prompt.startswith("<|im_start|>"):
return prompt
if not prompt:
prompt = DEFAULT_USER_PROMPT
return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n"
# Modify here. We add <|box_start|> at the end of the prompt to force the model to generate bounding box.
# if "Document OCR" in prompt:
# return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n<|box_start|>"
# else:
# return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n"
    def close(self) -> None:
        """No-op: this predictor holds no resources that need explicit release."""
        pass
from io import BytesIO
from typing import Iterable, List, Optional, Union
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer, BitsAndBytesConfig
from ...model.vlm_hf_model import Mineru2QwenForCausalLM
from ...model.vlm_hf_model.image_processing_mineru2 import process_images
from .base_predictor import (
DEFAULT_MAX_NEW_TOKENS,
DEFAULT_NO_REPEAT_NGRAM_SIZE,
DEFAULT_PRESENCE_PENALTY,
DEFAULT_REPETITION_PENALTY,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
BasePredictor,
)
from .utils import load_resource
class HuggingfacePredictor(BasePredictor):
    """Predictor backed by a local HuggingFace ``Mineru2QwenForCausalLM`` checkpoint.

    Loads tokenizer, model and the vision tower's image processor at
    construction time; inference runs through ``model.generate``.
    """

    def __init__(
        self,
        model_path: str,
        device_map="auto",
        device="cuda",
        torch_dtype="auto",
        load_in_8bit=False,
        load_in_4bit=False,
        use_flash_attn=False,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
        **kwargs,
    ):
        """Load the model from *model_path*.

        ``load_in_8bit`` takes precedence over ``load_in_4bit``; 4-bit mode
        configures NF4 double quantization via ``BitsAndBytesConfig``.  Extra
        ``kwargs`` are forwarded to ``from_pretrained``.
        """
        super().__init__(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        kwargs = {"device_map": device_map, **kwargs}
        if device != "cuda":
            # Pin the whole model to a single explicit device.
            kwargs["device_map"] = {"": device}
        if load_in_8bit:
            kwargs["load_in_8bit"] = True
        elif load_in_4bit:
            kwargs["load_in_4bit"] = True
            kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
        else:
            kwargs["torch_dtype"] = torch_dtype
        if use_flash_attn:
            kwargs["attn_implementation"] = "flash_attention_2"

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = Mineru2QwenForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            **kwargs,
        )
        setattr(self.model.config, "_name_or_path", model_path)
        self.model.eval()

        vision_tower = self.model.get_model().vision_tower
        if device_map != "auto":
            # Align the vision tower's device/dtype with the language model.
            vision_tower.to(device=device_map, dtype=self.model.dtype)
        self.image_processor = vision_tower.image_processor
        self.eos_token_id = self.model.config.eos_token_id

    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        **kwargs,
    ) -> str:
        """Run single-image inference and return the decoded output text.

        ``image`` may be raw bytes or a URI understood by ``load_resource``.
        Unset sampling options fall back to the instance defaults.
        NOTE(review): ``presence_penalty`` is accepted but never used in this
        method (see the "not supported by hf" remark on ``batch_predict``) —
        confirm that is intentional.
        """
        prompt = self.build_prompt(prompt)
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if no_repeat_ngram_size is None:
            no_repeat_ngram_size = self.no_repeat_ngram_size
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens

        # Sampling only kicks in when both a positive temperature and top_k > 1 are set.
        do_sample = (temperature > 0.0) and (top_k > 1)
        generate_kwargs = {
            "repetition_penalty": repetition_penalty,
            "no_repeat_ngram_size": no_repeat_ngram_size,
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
        }
        if do_sample:
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = top_p
            generate_kwargs["top_k"] = top_k

        if isinstance(image, str):
            image = load_resource(image)

        image_obj = Image.open(BytesIO(image))
        image_tensor = process_images([image_obj], self.image_processor, self.model.config)
        image_tensor = image_tensor[0].unsqueeze(0)
        image_tensor = image_tensor.to(device=self.model.device, dtype=self.model.dtype)
        image_sizes = [[*image_obj.size]]

        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = input_ids.to(device=self.model.device)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                image_sizes=image_sizes,
                use_cache=True,
                **generate_kwargs,
                **kwargs,
            )

        # Remove the last token if it is the eos_token_id
        if len(output_ids[0]) > 0 and output_ids[0, -1] == self.eos_token_id:
            output_ids = output_ids[:, :-1]

        # Keep special tokens: downstream parsing relies on markers like <|box_start|>.
        output = self.tokenizer.batch_decode(
            output_ids,
            skip_special_tokens=False,
        )[0].strip()
        return output

    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,  # not supported by hf
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        **kwargs,
    ) -> List[str]:
        """Sequentially call :meth:`predict` for each (prompt, image) pair.

        A single string prompt is broadcast to all images.  Progress is shown
        with tqdm; there is no true batching here.
        """
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)
        assert len(prompts) == len(images), "Length of prompts and images must match."

        outputs = []
        for prompt, image in tqdm(zip(prompts, images), total=len(images), desc="Predict"):
            output = self.predict(
                image,
                prompt,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                presence_penalty=presence_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                max_new_tokens=max_new_tokens,
                **kwargs,
            )
            outputs.append(output)
        return outputs

    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]:
        """Streaming inference is not implemented for the HF backend."""
        raise NotImplementedError("Streaming is not supported yet.")
# Copyright (c) Opendatalab. All rights reserved.
import time
from loguru import logger
from .base_predictor import (
DEFAULT_MAX_NEW_TOKENS,
DEFAULT_NO_REPEAT_NGRAM_SIZE,
DEFAULT_PRESENCE_PENALTY,
DEFAULT_REPETITION_PENALTY,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
BasePredictor,
)
from .sglang_client_predictor import SglangClientPredictor
# Optional-backend imports: record availability in module-level flags instead of
# failing at import time, so get_predictor() can raise a clear error only when
# the corresponding backend is actually requested.
hf_loaded = False
try:
    from .hf_predictor import HuggingfacePredictor
    hf_loaded = True
except ImportError as e:
    logger.warning("hf is not installed. If you are not using transformers, you can ignore this warning.")

engine_loaded = False
try:
    from sglang.srt.server_args import ServerArgs
    from .sglang_engine_predictor import SglangEnginePredictor
    engine_loaded = True
except Exception as e:
    logger.warning("sglang is not installed. If you are not using sglang, you can ignore this warning.")
def get_predictor(
    backend: str = "sglang-client",
    model_path: str | None = None,
    server_url: str | None = None,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
    top_k: int = DEFAULT_TOP_K,
    repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
    presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
    no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    http_timeout: int = 600,
    **kwargs,
) -> BasePredictor:
    """Factory for the predictor matching *backend*.

    Supported backends: ``transformers`` (needs ``model_path``),
    ``sglang-engine`` (needs ``model_path``) and ``sglang-client``
    (needs ``server_url``).  Construction time is logged.
    """
    start_time = time.time()

    # Sampling defaults shared by every backend constructor.
    sampling_kwargs = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "presence_penalty": presence_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "max_new_tokens": max_new_tokens,
    }

    if backend == "transformers":
        if not model_path:
            raise ValueError("model_path must be provided for transformers backend.")
        if not hf_loaded:
            raise ImportError(
                "transformers is not installed, so huggingface backend cannot be used. "
                "If you need to use huggingface backend, please install transformers first."
            )
        predictor = HuggingfacePredictor(
            model_path=model_path,
            **sampling_kwargs,
            **kwargs,
        )
    elif backend == "sglang-engine":
        if not model_path:
            raise ValueError("model_path must be provided for sglang-engine backend.")
        if not engine_loaded:
            raise ImportError(
                "sglang is not installed, so sglang-engine backend cannot be used. "
                "If you need to use sglang-engine backend for inference, "
                "please install sglang[all]==0.4.7 or a newer version."
            )
        predictor = SglangEnginePredictor(
            server_args=ServerArgs(model_path, **kwargs),
            **sampling_kwargs,
        )
    elif backend == "sglang-client":
        if not server_url:
            raise ValueError("server_url must be provided for sglang-client backend.")
        predictor = SglangClientPredictor(
            server_url=server_url,
            http_timeout=http_timeout,
            **sampling_kwargs,
        )
    else:
        raise ValueError(f"Unsupported backend: {backend}. Supports: transformers, sglang-engine, sglang-client.")

    elapsed = round(time.time() - start_time, 2)
    logger.info(f"get_predictor cost: {elapsed}s")
    return predictor
import asyncio
import json
import re
from base64 import b64encode
from typing import AsyncIterable, Iterable, List, Optional, Set, Tuple, Union
import httpx
from .base_predictor import (
DEFAULT_MAX_NEW_TOKENS,
DEFAULT_NO_REPEAT_NGRAM_SIZE,
DEFAULT_PRESENCE_PENALTY,
DEFAULT_REPETITION_PENALTY,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
BasePredictor,
)
from .utils import aio_load_resource, load_resource
class SglangClientPredictor(BasePredictor):
    """Predictor that talks to a running sglang server over HTTP.

    The constructor validates the server URL, checks server health and
    resolves the served model path; predictions then go through the server's
    ``/generate`` endpoint (sync, async, streaming and batched variants).
    """

    def __init__(
        self,
        server_url: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
        http_timeout: int = 600,
    ) -> None:
        """Store sampling defaults and probe the server before first use."""
        super().__init__(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        self.http_timeout = http_timeout
        base_url = self.get_base_url(server_url)
        self.check_server_health(base_url)  # fail fast if the server is unreachable
        self.model_path = self.get_model_path(base_url)
        self.server_url = f"{base_url}/generate"

    @staticmethod
    def get_base_url(server_url: str) -> str:
        """Extract ``scheme://host[:port]`` from *server_url*; raise on malformed input."""
        matched = re.match(r"^(https?://[^/]+)", server_url)
        if not matched:
            raise ValueError(f"Invalid server URL: {server_url}")
        return matched.group(1)

    def check_server_health(self, base_url: str):
        """Hit ``/health_generate`` and raise RuntimeError unless it returns 200."""
        try:
            response = httpx.get(f"{base_url}/health_generate", timeout=self.http_timeout)
        except httpx.ConnectError:
            raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
        if response.status_code != 200:
            raise RuntimeError(
                f"Server {base_url} is not healthy. Status code: {response.status_code}, response body: {response.text}"
            )

    def get_model_path(self, base_url: str) -> str:
        """Query ``/get_model_info`` and return the served model path."""
        try:
            response = httpx.get(f"{base_url}/get_model_info", timeout=self.http_timeout)
        except httpx.ConnectError:
            raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to get model info from {base_url}. Status code: {response.status_code}, response body: {response.text}"
            )
        return response.json()["model_path"]

    def build_sampling_params(
        self,
        temperature: Optional[float],
        top_p: Optional[float],
        top_k: Optional[int],
        repetition_penalty: Optional[float],
        presence_penalty: Optional[float],
        no_repeat_ngram_size: Optional[int],
        max_new_tokens: Optional[int],
    ) -> dict:
        """Merge per-call overrides with instance defaults into the server's
        sampling-params payload (``None`` means "use the default")."""
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if presence_penalty is None:
            presence_penalty = self.presence_penalty
        if no_repeat_ngram_size is None:
            no_repeat_ngram_size = self.no_repeat_ngram_size
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens

        # see SamplingParams for more details
        return {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
            "presence_penalty": presence_penalty,
            "custom_params": {
                "no_repeat_ngram_size": no_repeat_ngram_size,
            },
            "max_new_tokens": max_new_tokens,
            "skip_special_tokens": False,
        }

    def build_request_body(
        self,
        image: bytes,
        prompt: str,
        sampling_params: dict,
    ) -> dict:
        """Assemble the JSON body for ``/generate`` (image sent base64-encoded)."""
        image_base64 = b64encode(image).decode("utf-8")
        return {
            "text": prompt,
            "image_data": image_base64,
            "sampling_params": sampling_params,
            "modalities": ["image"],
        }

    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Blocking single-image prediction via an HTTP POST to ``/generate``."""
        prompt = self.build_prompt(prompt)
        sampling_params = self.build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        if isinstance(image, str):
            image = load_resource(image)
        request_body = self.build_request_body(image, prompt, sampling_params)
        response = httpx.post(self.server_url, json=request_body, timeout=self.http_timeout)
        response_body = response.json()
        return response_body["text"]

    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        max_concurrency: int = 100,
    ) -> List[str]:
        """Synchronous wrapper over :meth:`aio_batch_predict`.

        NOTE(review): when called from inside an already-running event loop,
        ``loop.run_until_complete`` raises ``RuntimeError`` ("loop already
        running") — the running-loop branch cannot actually work as written;
        confirm the intended usage.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        task = self.aio_batch_predict(
            images=images,
            prompts=prompts,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
            max_concurrency=max_concurrency,
        )

        if loop is not None:
            return loop.run_until_complete(task)
        else:
            return asyncio.run(task)

    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]:
        """Stream the generation as incremental text chunks.

        The server replies with SSE-style ``data:`` lines carrying the full
        text so far; this method yields only the delta since the last chunk
        and stops at ``data: [DONE]``.
        """
        prompt = self.build_prompt(prompt)
        sampling_params = self.build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        if isinstance(image, str):
            image = load_resource(image)
        request_body = self.build_request_body(image, prompt, sampling_params)
        request_body["stream"] = True

        with httpx.stream(
            "POST",
            self.server_url,
            json=request_body,
            timeout=self.http_timeout,
        ) as response:
            pos = 0  # number of characters already yielded
            for chunk in response.iter_lines():
                if not (chunk or "").startswith("data:"):
                    continue
                if chunk == "data: [DONE]":
                    break
                data = json.loads(chunk[5:].strip("\n"))
                chunk_text = data["text"][pos:]
                # meta_info = data["meta_info"]
                pos += len(chunk_text)
                yield chunk_text

    async def aio_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        async_client: Optional[httpx.AsyncClient] = None,
    ) -> str:
        """Async single-image prediction; reuses *async_client* when provided."""
        prompt = self.build_prompt(prompt)
        sampling_params = self.build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        if isinstance(image, str):
            image = await aio_load_resource(image)
        request_body = self.build_request_body(image, prompt, sampling_params)
        if async_client is None:
            # One-shot client for this call only.
            async with httpx.AsyncClient(timeout=self.http_timeout) as client:
                response = await client.post(self.server_url, json=request_body)
                response_body = response.json()
        else:
            response = await async_client.post(self.server_url, json=request_body)
            response_body = response.json()
        return response_body["text"]

    async def aio_batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        max_concurrency: int = 100,
    ) -> List[str]:
        """Predict all images concurrently (bounded by *max_concurrency*),
        returning outputs in input order."""
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)
        assert len(prompts) == len(images), "Length of prompts and images must match."

        semaphore = asyncio.Semaphore(max_concurrency)
        # Pre-sized so each task can write its slot regardless of completion order.
        outputs = [""] * len(images)

        async def predict_with_semaphore(
            idx: int,
            image: str | bytes,
            prompt: str,
            async_client: httpx.AsyncClient,
        ):
            async with semaphore:
                output = await self.aio_predict(
                    image=image,
                    prompt=prompt,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    presence_penalty=presence_penalty,
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    max_new_tokens=max_new_tokens,
                    async_client=async_client,
                )
                outputs[idx] = output

        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
            tasks = []
            for idx, (prompt, image) in enumerate(zip(prompts, images)):
                tasks.append(predict_with_semaphore(idx, image, prompt, client))
            await asyncio.gather(*tasks)

        return outputs

    async def aio_batch_predict_as_iter(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        max_concurrency: int = 100,
    ) -> AsyncIterable[Tuple[int, str]]:
        """Like :meth:`aio_batch_predict`, but yields ``(index, text)`` pairs
        as each prediction completes (completion order, not input order)."""
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)
        assert len(prompts) == len(images), "Length of prompts and images must match."

        semaphore = asyncio.Semaphore(max_concurrency)

        async def predict_with_semaphore(
            idx: int,
            image: str | bytes,
            prompt: str,
            async_client: httpx.AsyncClient,
        ):
            async with semaphore:
                output = await self.aio_predict(
                    image=image,
                    prompt=prompt,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    presence_penalty=presence_penalty,
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    max_new_tokens=max_new_tokens,
                    async_client=async_client,
                )
                return (idx, output)

        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
            pending: Set[asyncio.Task[Tuple[int, str]]] = set()
            for idx, (prompt, image) in enumerate(zip(prompts, images)):
                pending.add(
                    asyncio.create_task(
                        predict_with_semaphore(idx, image, prompt, client),
                    )
                )
            while len(pending) > 0:
                done, pending = await asyncio.wait(
                    pending,
                    return_when=asyncio.FIRST_COMPLETED,
                )
                for task in done:
                    yield task.result()

    async def aio_stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> AsyncIterable[str]:
        """Async counterpart of :meth:`stream_predict`; yields text deltas."""
        prompt = self.build_prompt(prompt)
        sampling_params = self.build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        if isinstance(image, str):
            image = await aio_load_resource(image)
        request_body = self.build_request_body(image, prompt, sampling_params)
        request_body["stream"] = True

        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
            async with client.stream(
                "POST",
                self.server_url,
                json=request_body,
            ) as response:
                pos = 0  # number of characters already yielded
                async for chunk in response.aiter_lines():
                    if not (chunk or "").startswith("data:"):
                        continue
                    if chunk == "data: [DONE]":
                        break
                    data = json.loads(chunk[5:].strip("\n"))
                    chunk_text = data["text"][pos:]
                    # meta_info = data["meta_info"]
                    pos += len(chunk_text)
                    yield chunk_text
from base64 import b64encode
from typing import AsyncIterable, Iterable, List, Optional, Union
from sglang.srt.server_args import ServerArgs
from ...model.vlm_sglang_model.engine import BatchEngine
from .base_predictor import (
DEFAULT_MAX_NEW_TOKENS,
DEFAULT_NO_REPEAT_NGRAM_SIZE,
DEFAULT_PRESENCE_PENALTY,
DEFAULT_REPETITION_PENALTY,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
BasePredictor,
)
class SglangEnginePredictor(BasePredictor):
    """Predictor running an in-process sglang ``BatchEngine``.

    Single-item calls are routed through the batch entry points; streaming is
    not implemented.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    ) -> None:
        """Store sampling defaults and spin up the engine from *server_args*."""
        super().__init__(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        self.engine = BatchEngine(server_args=server_args)

    def load_image_string(self, image: str | bytes) -> str:
        """Normalize an image reference for the engine: bytes are base64-encoded,
        ``file://`` URIs are stripped to plain paths, other strings pass through."""
        if not isinstance(image, (str, bytes)):
            raise ValueError("Image must be a string or bytes.")
        if isinstance(image, bytes):
            return b64encode(image).decode("utf-8")
        if image.startswith("file://"):
            return image[len("file://") :]
        return image

    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Single-image prediction, implemented as a batch of one."""
        return self.batch_predict(
            [image],  # type: ignore
            [prompt],
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )[0]

    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> List[str]:
        """Run the engine over all (prompt, image) pairs and return the texts.

        A single string prompt is broadcast to all images; ``None`` sampling
        options fall back to the instance defaults.
        """
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)
        assert len(prompts) == len(images), "Length of prompts and images must match."

        prompts = [self.build_prompt(prompt) for prompt in prompts]

        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if presence_penalty is None:
            presence_penalty = self.presence_penalty
        if no_repeat_ngram_size is None:
            no_repeat_ngram_size = self.no_repeat_ngram_size
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens

        # see SamplingParams for more details
        sampling_params = {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
            "presence_penalty": presence_penalty,
            "custom_params": {
                "no_repeat_ngram_size": no_repeat_ngram_size,
            },
            "max_new_tokens": max_new_tokens,
            "skip_special_tokens": False,
        }

        image_strings = [self.load_image_string(img) for img in images]

        output = self.engine.generate(
            prompt=prompts,
            image_data=image_strings,
            sampling_params=sampling_params,
        )
        return [item["text"] for item in output]

    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]:
        """Streaming is not implemented for the engine backend."""
        raise NotImplementedError("Streaming is not supported yet.")

    async def aio_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Async single-image prediction, implemented as a batch of one."""
        output = await self.aio_batch_predict(
            [image],  # type: ignore
            [prompt],
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        return output[0]

    async def aio_batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> List[str]:
        """Async variant of :meth:`batch_predict` using ``engine.async_generate``."""
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)
        assert len(prompts) == len(images), "Length of prompts and images must match."

        prompts = [self.build_prompt(prompt) for prompt in prompts]

        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if presence_penalty is None:
            presence_penalty = self.presence_penalty
        if no_repeat_ngram_size is None:
            no_repeat_ngram_size = self.no_repeat_ngram_size
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens

        # see SamplingParams for more details
        sampling_params = {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
            "presence_penalty": presence_penalty,
            "custom_params": {
                "no_repeat_ngram_size": no_repeat_ngram_size,
            },
            "max_new_tokens": max_new_tokens,
            "skip_special_tokens": False,
        }

        image_strings = [self.load_image_string(img) for img in images]

        output = await self.engine.async_generate(
            prompt=prompts,
            image_data=image_strings,
            sampling_params=sampling_params,
        )
        ret = []
        for item in output:  # type: ignore
            ret.append(item["text"])
        return ret

    async def aio_stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> AsyncIterable[str]:
        """Async streaming is not implemented for the engine backend."""
        raise NotImplementedError("Streaming is not supported yet.")

    def close(self):
        """Shut down the underlying sglang engine."""
        self.engine.shutdown()
import re
from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.enum_class import BlockType, ContentType
from mineru.utils.hash_utils import str_md5
from mineru.backend.vlm.vlm_magic_model import MagicModel
from mineru.version import __version__
def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dict:
    """Convert one page's VLM output token string into a middle-json page dict.

    The token string is a sequence of blocks shaped like
    ``<|box_start|>x0 y0 x1 y1<|box_end|><|ref_start|>type<|ref_end|><|md_start|>content<|md_end|>``;
    ``MagicModel`` does the actual parsing.  Image/table/interline-equation
    spans are additionally cropped out of the page image via
    ``cut_image_and_table``.
    """
    scale = image_dict["scale"]
    page_pil_img = image_dict["img_pil"]
    page_img_md5 = str_md5(image_dict["img_base64"])
    width, height = map(int, page.get_size())

    magic_model = MagicModel(token, width, height)

    # Collect the typed blocks parsed out of the token string.
    parsed_blocks = [
        *magic_model.get_image_blocks(),
        *magic_model.get_table_blocks(),
        *magic_model.get_title_blocks(),
        *magic_model.get_text_blocks(),
        *magic_model.get_interline_equation_blocks(),
    ]

    # Crop image/table/interline-equation spans out of the rendered page.
    cut_types = (ContentType.IMAGE, ContentType.TABLE, ContentType.INTERLINE_EQUATION)
    for span in magic_model.get_all_spans():
        if span["type"] in cut_types:
            cut_image_and_table(span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale)

    # Restore reading order via each block's parsed index.
    parsed_blocks.sort(key=lambda blk: blk["index"])

    return {
        "para_blocks": parsed_blocks,
        "discarded_blocks": [],
        "page_size": [width, height],
        "page_idx": page_index,
    }
def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
    """Assemble per-page VLM outputs into the middle-json document structure.

    ``token_list`` and ``images_list`` are parallel, one entry per page of
    *pdf_doc*.  The pdf document is closed once all pages are converted.
    """
    pdf_info = [
        token_to_page_info(token, images_list[page_index], pdf_doc[page_index], image_writer, page_index)
        for page_index, token in enumerate(token_list)
    ]
    # Release the underlying pdf document.
    pdf_doc.close()
    return {"pdf_info": pdf_info, "_backend": "vlm", "_version_name": __version__}
if __name__ == "__main__":
    # Ad-hoc smoke test: a sample page-level token string as produced by the VLM.
    output = r"<|box_start|>088 119 472 571<|box_end|><|ref_start|>image<|ref_end|><|md_start|>![]('img_url')<|md_end|>\n<|box_start|>079 582 482 608<|box_end|><|ref_start|>image_caption<|ref_end|><|md_start|>Fig. 2. (a) Schematic of the change in the FDC over time, and (b) definition of model parameters.<|md_end|>\n<|box_start|>079 624 285 638<|box_end|><|ref_start|>title<|ref_end|><|md_start|># 2.2. Zero flow day analysis<|md_end|>\n<|box_start|>079 656 482 801<|box_end|><|ref_start|>text<|ref_end|><|md_start|>A notable feature of Fig. 1 is the increase in the number of zero flow days. A similar approach to Eq. (2), using an inverse sigmoidal function was employed to assess the impact of afforestation on the number of zero flow days per year \((N_{\mathrm{zero}})\). In this case, the left hand side of Eq. (2) is replaced by \(N_{\mathrm{zero}}\) and \(b\) and \(S\) are constrained to negative as \(N_{\mathrm{zero}}\) decreases as rainfall increases, and increases with plantation growth:<|md_end|>\n<|box_start|>076 813 368 853<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nN_{\mathrm{zero}}=a+b(\Delta P)+\frac{Y}{1+\exp\left(\frac{T-T_{\mathrm{half}}}{S}\right)}\n\]<|md_end|>\n<|box_start|>079 865 482 895<|box_end|><|ref_start|>text<|ref_end|><|md_start|>For the average pre-treatment condition \(\Delta P=0\) and \(T=0\), \(N_{\mathrm{zero}}\) approximately equals \(a\). \(Y\) gives<|md_end|>\n<|box_start|>525 119 926 215<|box_end|><|ref_start|>text<|ref_end|><|md_start|>the magnitude of change in zero flow days due to afforestation, and \(S\) describes the shape of the response. For the average climate condition \(\Delta P=0\), \(a+Y\) becomes the number of zero flow days when the new equilibrium condition under afforestation is reached.<|md_end|>\n<|box_start|>525 240 704 253<|box_end|><|ref_start|>title<|ref_end|><|md_start|># 2.3. Statistical analyses<|md_end|>\n<|box_start|>525 271 926 368<|box_end|><|ref_start|>text<|ref_end|><|md_start|>The coefficient of efficiency \((E)\) (Nash and Sutcliffe, 1970; Chiew and McMahon, 1993; Legates and McCabe, 1999) was used as the 'goodness of fit' measure to evaluate the fit between observed and predicted flow deciles (2) and zero flow days (3). \(E\) is given by:<|md_end|>\n<|box_start|>520 375 735 415<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nE=1.0-\frac{\sum_{i=1}^{N}(O_{i}-P_{i})^{2}}{\sum_{i=1}^{N}(O_{i}-\bar{O})^{2}}\n\]<|md_end|>\n<|box_start|>525 424 926 601<|box_end|><|ref_start|>text<|ref_end|><|md_start|>where \(O\) are observed data, \(P\) are predicted values, and \(\bar{O}\) is the mean for the entire period. \(E\) is unity minus the ratio of the mean square error to the variance in the observed data, and ranges from \(-\infty\) to 1.0. Higher values indicate greater agreement between observed and predicted data as per the coefficient of determination \((r^{2})\). \(E\) is used in preference to \(r^{2}\) in evaluating hydrologic modelling because it is a measure of the deviation from the 1:1 line. As \(E\) is always \(<r^{2}\) we have arbitrarily considered \(E>0.7\) to indicate adequate model fits.<|md_end|>\n<|box_start|>525 603 926 731<|box_end|><|ref_start|>text<|ref_end|><|md_start|>It is important to assess the significance of the model parameters to check the model assumptions that rainfall and forest age are driving changes in the FDC. The model (2) was split into simplified forms, where only the rainfall or time terms were included by setting \(b=0\), as shown in Eq. (5), or \(Y=0\) as shown in Eq. (6). The component models (5) and (6) were then tested against the complete model, (2).<|md_end|>\n<|box_start|>520 739 735 778<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nQ_{\%}=a+\frac{Y}{1+\exp\left(\frac{T-T_{\mathrm{half}}^{\prime}}{S}\right)}\n\]<|md_end|>\n<|box_start|>525 787 553 799<|box_end|><|ref_start|>text<|ref_end|><|md_start|>and<|md_end|>\n<|box_start|>520 807 646 825<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nQ_{\%}=a+b\Delta P\n\]<|md_end|>\n<|box_start|>525 833 926 895<|box_end|><|ref_start|>text<|ref_end|><|md_start|>For both the flow duration curve analysis and zero flow days analysis, a \(t\)-test was then performed to test whether (5) and (6) were significantly different to (2). A critical value of \(t\) exceeding the calculated \(t\)-value<|md_end|><|im_end|>"
    # NOTE(review): token_to_page_info() requires (token, image_dict, page,
    # image_writer, page_index); this call passes only `token` and will raise
    # TypeError at runtime — the demo needs real page/image objects to run.
    p_info = token_to_page_info(output)
    # Dump the parsed blocks as pretty-printed JSON.
    import json
    json_str = json.dumps(p_info, ensure_ascii=False, indent=4)
    print(json_str)
import os
import re
from base64 import b64decode
import httpx
# Network timeout (seconds) for remote fetches; overridable via env var.
_timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
# Extensions treated as local file paths when the URI has no scheme.
_file_exts = (".png", ".jpg", ".jpeg", ".webp", ".gif", ".pdf")
# Matches the "data:<mime>;base64," prefix of a data URI.
_data_uri_regex = re.compile(r"^data:[^;,]+;base64,")


def load_resource(uri: str) -> bytes:
    """Load the raw bytes referenced by *uri*.

    Resolution order: http(s) URL, file:// URI, bare path with a known
    extension, data URI, and finally a plain base64 payload.
    """
    if uri.startswith(("http://", "https://")):
        return httpx.get(uri, timeout=_timeout).content
    if uri.startswith("file://"):
        with open(uri[len("file://"):], "rb") as fh:
            return fh.read()
    if uri.lower().endswith(_file_exts):
        with open(uri, "rb") as fh:
            return fh.read()
    if _data_uri_regex.match(uri):
        return b64decode(uri.split(",")[1])
    return b64decode(uri)
async def aio_load_resource(uri: str) -> bytes:
    """Asynchronously load the raw bytes referenced by *uri*.

    Mirrors :func:`load_resource`; only the http(s) branch is actually
    asynchronous, everything else is resolved synchronously.
    """
    if uri.startswith(("http://", "https://")):
        async with httpx.AsyncClient(timeout=_timeout) as client:
            resp = await client.get(uri)
            return resp.content
    if uri.startswith("file://"):
        with open(uri[len("file://"):], "rb") as fh:
            return fh.read()
    if uri.lower().endswith(_file_exts):
        with open(uri, "rb") as fh:
            return fh.read()
    if _data_uri_regex.match(uri):
        return b64decode(uri.split(",")[1])
    return b64decode(uri)
# Copyright (c) Opendatalab. All rights reserved.
import time
from loguru import logger
from ...data.data_reader_writer import DataWriter
from mineru.utils.pdf_image_tools import load_images_from_pdf
from .base_predictor import BasePredictor
from .predictor import get_predictor
from .token_to_middle_json import result_to_middle_json
from ...utils.models_download_utils import auto_download_and_get_model_root_path
class ModelSingleton:
    """Process-wide singleton caching one predictor per configuration."""

    _instance = None
    # Shared cache: (backend, model_path, server_url) -> predictor.
    _models = {}

    def __new__(cls, *args, **kwargs):
        # Lazy singleton: every construction returns the one shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        backend: str,
        model_path: str | None,
        server_url: str | None,
    ) -> BasePredictor:
        """Return the cached predictor for this trio, creating it on demand."""
        key = (backend, model_path, server_url)
        try:
            return self._models[key]
        except KeyError:
            pass
        # Local backends need a concrete model path; auto-download when missing.
        if backend in ['transformers', 'sglang-engine'] and not model_path:
            model_path = auto_download_and_get_model_root_path("/", "vlm")
        predictor = get_predictor(
            backend=backend,
            model_path=model_path,
            server_url=server_url,
        )
        self._models[key] = predictor
        return predictor
def doc_analyze(
    pdf_bytes,
    image_writer: DataWriter | None,
    predictor: BasePredictor | None = None,
    backend="transformers",
    model_path: str | None = None,
    server_url: str | None = None,
):
    """Run VLM inference over every page of a PDF.

    Returns a tuple ``(middle_json, results)`` — the assembled middle-json
    structure plus the raw per-page inference outputs.
    """
    if predictor is None:
        predictor = ModelSingleton().get_model(backend, model_path, server_url)

    # Render each PDF page to a base64-encoded image for the predictor.
    images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
    page_images = [page["img_base64"] for page in images_list]

    results = predictor.batch_predict(images=page_images)

    middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
    return middle_json, results
async def aio_doc_analyze(
    pdf_bytes,
    image_writer: DataWriter | None,
    predictor: BasePredictor | None = None,
    backend="transformers",
    model_path: str | None = None,
    server_url: str | None = None,
):
    """Async variant of :func:`doc_analyze`; returns only the middle-json.

    Unlike the sync variant it logs load/inference timing; the speed
    computation is guarded so a sub-10ms elapsed time (which rounds to 0.0)
    no longer raises ZeroDivisionError.
    """
    if predictor is None:
        predictor = ModelSingleton().get_model(backend, model_path, server_url)

    load_images_start = time.time()
    images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
    images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
    load_images_time = round(time.time() - load_images_start, 2)
    # Guard against division by zero when the elapsed time rounds to 0.0.
    load_speed = round(len(images_base64_list) / load_images_time, 3) if load_images_time else float("inf")
    logger.info(f"load images cost: {load_images_time}, speed: {load_speed} images/s")

    infer_start = time.time()
    results = await predictor.aio_batch_predict(images=images_base64_list)
    infer_time = round(time.time() - infer_start, 2)
    infer_speed = round(len(results) / infer_time, 3) if infer_time else float("inf")
    logger.info(f"infer finished, cost: {infer_time}, speed: {infer_speed} page/s")

    middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
    return middle_json
import re
from typing import Literal
from loguru import logger
from mineru.utils.boxbase import bbox_distance, is_in
from mineru.utils.enum_class import ContentType, BlockType, SplitFlag
from mineru.backend.vlm.vlm_middle_json_mkcontent import merge_para_with_text
from mineru.utils.format_utils import convert_otsl_to_html
class MagicModel:
    """Parse a VLM output token stream into layout blocks and spans.

    The token string encodes each detected region as
    ``<|box_start|>x1 y1 x2 y2<|box_end|><|ref_start|>TYPE<|ref_end|><|md_start|>CONTENT``
    terminated by ``<|md_end|>`` (or ``<|im_end|>`` for the final block),
    with coordinates normalised to a 0-1000 grid.
    """

    def __init__(self, token: str, width, height):
        # width / height: pixel dimensions of the page image, used to map
        # the normalised 0-1000 coordinates back to pixel space.
        self.token = token
        # Locate every encoded block with a regex; DOTALL lets block
        # content span multiple lines.
        pattern = (
            r"<\|box_start\|>(.*?)<\|box_end\|><\|ref_start\|>(.*?)<\|ref_end\|><\|md_start\|>(.*?)(?:<\|md_end\|>|<\|im_end\|>)"
        )
        block_infos = re.findall(pattern, token, re.DOTALL)
        blocks = []
        self.all_spans = []
        # Parse each captured (bbox, type, content) triple.
        for index, block_info in enumerate(block_infos):
            block_bbox = block_info[0].strip()
            try:
                x1, y1, x2, y2 = map(int, block_bbox.split())
                # Scale normalised (0-1000) coordinates to pixel coordinates.
                x_1, y_1, x_2, y_2 = (
                    int(x1 * width / 1000),
                    int(y1 * height / 1000),
                    int(x2 * width / 1000),
                    int(y2 * height / 1000),
                )
                # Normalise inverted boxes so (x_1, y_1) is the top-left corner.
                if x_2 < x_1:
                    x_1, x_2 = x_2, x_1
                if y_2 < y_1:
                    y_1, y_2 = y_2, y_1
                block_bbox = (x_1, y_1, x_2, y_2)
                block_type = block_info[1].strip()
                block_content = block_info[2].strip()
            except Exception as e:
                # Malformed block (e.g. non-numeric coordinates): log and skip.
                logger.warning(f"Invalid block format: {block_info}, error: {e}")
                continue
            # Map the raw type onto a span content type; image/table/equation
            # blocks are also remapped to their *_BODY block types.
            span_type = "unknown"
            if block_type in [
                "text",
                "title",
                "image_caption",
                "image_footnote",
                "table_caption",
                "table_footnote",
                "list",
                "index",
            ]:
                span_type = ContentType.TEXT
            elif block_type in ["image"]:
                block_type = BlockType.IMAGE_BODY
                span_type = ContentType.IMAGE
            elif block_type in ["table"]:
                block_type = BlockType.TABLE_BODY
                span_type = ContentType.TABLE
            elif block_type in ["equation"]:
                block_type = BlockType.INTERLINE_EQUATION
                span_type = ContentType.INTERLINE_EQUATION
            # NOTE(review): the membership test below assumes
            # ContentType.IMAGE / ContentType.TABLE are the literal strings
            # "image" / "table" — confirm against mineru.utils.enum_class.
            if span_type in ["image", "table"]:
                span = {
                    "bbox": block_bbox,
                    "type": span_type,
                }
                if span_type == ContentType.TABLE:
                    # Tables may arrive in OTSL form; convert such lines to HTML.
                    if "<fcel>" in block_content or "<ecel>" in block_content:
                        lines = block_content.split("\n\n")
                        new_lines = []
                        for line in lines:
                            if "<fcel>" in line or "<ecel>" in line:
                                line = convert_otsl_to_html(line)
                            new_lines.append(line)
                        span["html"] = "\n\n".join(new_lines)
                    else:
                        span["html"] = block_content
            elif span_type in [ContentType.INTERLINE_EQUATION]:
                span = {
                    "bbox": block_bbox,
                    "type": span_type,
                    "content": isolated_formula_clean(block_content),
                }
            else:
                # Text-like block: when inline \( ... \) delimiters are balanced,
                # split the content into alternating text / inline-equation spans.
                if block_content.count("\\(") == block_content.count("\\)") and block_content.count("\\(") > 0:
                    spans = []
                    last_end = 0
                    # Walk every inline formula occurrence.
                    for match in re.finditer(r'\\\((.+?)\\\)', block_content):
                        start, end = match.span()
                        # Text before the formula becomes a plain text span.
                        if start > last_end:
                            text_before = block_content[last_end:start]
                            if text_before.strip():
                                spans.append({
                                    "bbox": block_bbox,
                                    "type": ContentType.TEXT,
                                    "content": text_before
                                })
                        # The formula itself (delimiters stripped) becomes an
                        # inline-equation span.
                        formula = match.group(1)
                        spans.append({
                            "bbox": block_bbox,
                            "type": ContentType.INLINE_EQUATION,
                            "content": formula.strip()
                        })
                        last_end = end
                    # Trailing text after the last formula.
                    if last_end < len(block_content):
                        text_after = block_content[last_end:]
                        if text_after.strip():
                            spans.append({
                                "bbox": block_bbox,
                                "type": ContentType.TEXT,
                                "content": text_after
                            })
                    span = spans
                else:
                    span = {
                        "bbox": block_bbox,
                        "type": span_type,
                        "content": block_content,
                    }
            # A block yields either one span (dict) or several (list);
            # either way they are wrapped into one synthetic line.
            if isinstance(span, dict) and "bbox" in span:
                self.all_spans.append(span)
                line = {
                    "bbox": block_bbox,
                    "spans": [span],
                }
            elif isinstance(span, list):
                self.all_spans.extend(span)
                line = {
                    "bbox": block_bbox,
                    "spans": span,
                }
            else:
                raise ValueError(f"Invalid span type: {span_type}, expected dict or list, got {type(span)}")
            blocks.append(
                {
                    "bbox": block_bbox,
                    "type": block_type,
                    "lines": [line],
                    "index": index,
                }
            )
        # Bucket the parsed blocks by category for the getters below.
        self.image_blocks = []
        self.table_blocks = []
        self.interline_equation_blocks = []
        self.text_blocks = []
        self.title_blocks = []
        for block in blocks:
            if block["type"] in [BlockType.IMAGE_BODY, BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE]:
                self.image_blocks.append(block)
            elif block["type"] in [BlockType.TABLE_BODY, BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE]:
                self.table_blocks.append(block)
            elif block["type"] == BlockType.INTERLINE_EQUATION:
                self.interline_equation_blocks.append(block)
            elif block["type"] == BlockType.TEXT:
                self.text_blocks.append(block)
            elif block["type"] == BlockType.TITLE:
                self.title_blocks.append(block)
            else:
                # Any other type (e.g. list/index variants) is not bucketed.
                continue

    def get_image_blocks(self):
        # Image bodies grouped with their captions and footnotes.
        return fix_two_layer_blocks(self.image_blocks, BlockType.IMAGE)

    def get_table_blocks(self):
        # Table bodies grouped with their captions and footnotes.
        return fix_two_layer_blocks(self.table_blocks, BlockType.TABLE)

    def get_title_blocks(self):
        # Titles annotated with heading levels derived from leading '#'.
        return fix_title_blocks(self.title_blocks)

    def get_text_blocks(self):
        # Text blocks with cross-block continuations merged.
        return fix_text_blocks(self.text_blocks)

    def get_interline_equation_blocks(self):
        return self.interline_equation_blocks

    def get_all_spans(self):
        return self.all_spans
def isolated_formula_clean(txt):
    """Strip display-math delimiters (``\\[`` / ``\\]``) and repair the LaTeX."""
    latex = txt
    if latex.startswith("\\["):
        latex = latex[2:]
    if latex.endswith("\\]"):
        latex = latex[:-2]
    return latex_fix(latex.strip())
def latex_fix(latex):
    """Drop unbalanced ``\\left`` / ``\\right`` delimiters from LaTeX.

    When the counts of ``\\left`` and ``\\right`` tokens disagree, every
    delimiter token (valid pairs like ``\\left\\{ ... \\right\\}`` as well as
    malformed variants like ``\\left{``) is replaced by its bare bracket so
    the formula still renders.  Balanced input is returned untouched.
    """
    # Negative lookahead keeps \lefteqn / \rightarrow etc. from matching.
    left_pattern = re.compile(r'\\left(?![a-zA-Z])')
    right_pattern = re.compile(r'\\right(?![a-zA-Z])')
    if len(left_pattern.findall(latex)) != len(right_pattern.findall(latex)):
        # (pattern, replacement) pairs: valid forms first, then malformed ones.
        substitutions = (
            (r'\\left\\\{', "{"),    # \left\{
            (r"\\left\|", "|"),      # \left|
            (r"\\left\\\|", "|"),    # \left\|
            (r"\\left\(", "("),      # \left(
            (r"\\left\[", "["),      # \left[
            (r"\\right\\\}", "}"),   # \right\}
            (r"\\right\|", "|"),     # \right|
            (r"\\right\\\|", "|"),   # \right\|
            (r"\\right\)", ")"),     # \right)
            (r"\\right\]", "]"),     # \right]
            (r"\\right\.", ""),      # \right.
            (r'\\left\{', "{"),      # malformed \left{
            (r'\\right\}', "}"),     # malformed \right}
            (r'\\left\\\(', "("),    # malformed \left\(
            (r'\\right\\\)', ")"),   # malformed \right\)
            (r'\\left\\\[', "["),    # malformed \left\[
            (r'\\right\\\]', "]"),   # malformed \right\]
        )
        # Two passes, mirroring the original fix-up behaviour.
        for _ in range(2):
            for pat, repl in substitutions:
                latex = re.sub(pat, repl, latex)
    return latex
def __reduct_overlap(bboxes):
    """Drop every bbox that is fully contained inside another bbox."""
    total = len(bboxes)
    keep = [True] * total
    for i in range(total):
        for j in range(total):
            # A box nested inside a *different* box is redundant.
            if i != j and is_in(bboxes[i]["bbox"], bboxes[j]["bbox"]):
                keep[i] = False
    return [box for box, flag in zip(bboxes, keep) if flag]
def __tie_up_category_by_distance_v3(
    blocks: list,
    subject_block_type: str,
    object_block_type: str,
):
    """Pair subject blocks (bodies) with nearby object blocks (captions/footnotes).

    Greedy matching in rough reading order: repeatedly take the remaining box
    closest to the top-left of the unclaimed boxes, pair it with the nearest
    box of the opposite kind, and reject pairs that are at least 3x farther
    apart than an alternative free subject would be.  Unclaimed objects are
    then attached to their nearest subject; subjects left without objects
    get an empty object list.
    """
    # Deduplicate boxes fully contained in another box of the same category.
    subjects = __reduct_overlap(
        list(
            map(
                lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"]},
                filter(
                    lambda x: x["type"] == subject_block_type,
                    blocks,
                ),
            )
        )
    )
    objects = __reduct_overlap(
        list(
            map(
                lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"]},
                filter(
                    lambda x: x["type"] == object_block_type,
                    blocks,
                ),
            )
        )
    )
    ret = []
    N, M = len(subjects), len(objects)
    # Sort both lists by squared distance from the page origin (top-left).
    subjects.sort(key=lambda x: x["bbox"][0] ** 2 + x["bbox"][1] ** 2)
    objects.sort(key=lambda x: x["bbox"][0] ** 2 + x["bbox"][1] ** 2)
    # Object indices are offset so both kinds can share one id space.
    OBJ_IDX_OFFSET = 10000
    SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1
    all_boxes_with_idx = [(i, SUB_BIT_KIND, sub["bbox"][0], sub["bbox"][1]) for i, sub in enumerate(subjects)] + [
        (i + OBJ_IDX_OFFSET, OBJ_BIT_KIND, obj["bbox"][0], obj["bbox"][1]) for i, obj in enumerate(objects)
    ]
    seen_idx = set()
    seen_sub_idx = set()
    # Phase 1: greedy nearest-pair matching until every subject is claimed
    # or no opposite-kind candidate remains.
    while N > len(seen_sub_idx):
        candidates = []
        for idx, kind, x0, y0 in all_boxes_with_idx:
            if idx in seen_idx:
                continue
            candidates.append((idx, kind, x0, y0))
        if len(candidates) == 0:
            break
        # Anchor: the candidate nearest the top-left of the remaining boxes.
        left_x = min([v[2] for v in candidates])
        top_y = min([v[3] for v in candidates])
        candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)
        fst_idx, fst_kind, left_x, top_y = candidates[0]
        # Re-sort by distance to the anchor box itself.
        candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)
        # Find the nearest candidate of the opposite kind (kind bits XOR to 1).
        nxt = None
        for i in range(1, len(candidates)):
            if candidates[i][1] ^ fst_kind == 1:
                nxt = candidates[i]
                break
        if nxt is None:
            break
        if fst_kind == SUB_BIT_KIND:
            sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
        else:
            sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET
        pair_dis = bbox_distance(subjects[sub_idx]["bbox"], objects[obj_idx]["bbox"])
        # Reject the pair if another free subject is much (3x) closer to this
        # object; the anchor subject is then retired unmatched.
        nearest_dis = float("inf")
        for i in range(N):
            if i in seen_idx or i == sub_idx:
                continue
            nearest_dis = min(nearest_dis, bbox_distance(subjects[i]["bbox"], objects[obj_idx]["bbox"]))
        if pair_dis >= 3 * nearest_dis:
            seen_idx.add(sub_idx)
            continue
        seen_idx.add(sub_idx)
        seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
        seen_sub_idx.add(sub_idx)
        ret.append(
            {
                "sub_bbox": {
                    "bbox": subjects[sub_idx]["bbox"],
                    "lines": subjects[sub_idx]["lines"],
                    "index": subjects[sub_idx]["index"],
                },
                "obj_bboxes": [
                    {"bbox": objects[obj_idx]["bbox"], "lines": objects[obj_idx]["lines"], "index": objects[obj_idx]["index"]}
                ],
                "sub_idx": sub_idx,
            }
        )
    # Phase 2: attach each still-unclaimed object to its nearest subject,
    # creating a new record if that subject has none yet.
    for i in range(len(objects)):
        j = i + OBJ_IDX_OFFSET
        if j in seen_idx:
            continue
        seen_idx.add(j)
        nearest_dis, nearest_sub_idx = float("inf"), -1
        for k in range(len(subjects)):
            dis = bbox_distance(objects[i]["bbox"], subjects[k]["bbox"])
            if dis < nearest_dis:
                nearest_dis = dis
                nearest_sub_idx = k
        for k in range(len(subjects)):
            if k != nearest_sub_idx:
                continue
            if k in seen_sub_idx:
                # Subject already has a record: append this object to it.
                for kk in range(len(ret)):
                    if ret[kk]["sub_idx"] == k:
                        ret[kk]["obj_bboxes"].append(
                            {"bbox": objects[i]["bbox"], "lines": objects[i]["lines"], "index": objects[i]["index"]}
                        )
                        break
            else:
                ret.append(
                    {
                        "sub_bbox": {
                            "bbox": subjects[k]["bbox"],
                            "lines": subjects[k]["lines"],
                            "index": subjects[k]["index"],
                        },
                        "obj_bboxes": [
                            {"bbox": objects[i]["bbox"], "lines": objects[i]["lines"], "index": objects[i]["index"]}
                        ],
                        "sub_idx": k,
                    }
                )
                seen_sub_idx.add(k)
                seen_idx.add(k)
    # Phase 3: subjects that never received an object still produce a record.
    for i in range(len(subjects)):
        if i in seen_sub_idx:
            continue
        ret.append(
            {
                "sub_bbox": {
                    "bbox": subjects[i]["bbox"],
                    "lines": subjects[i]["lines"],
                    "index": subjects[i]["index"],
                },
                "obj_bboxes": [],
                "sub_idx": i,
            }
        )
    return ret
def get_type_blocks(blocks, block_type: Literal["image", "table"]):
    """Group each image/table body with its caption and footnote blocks."""
    caption_pairs = __tie_up_category_by_distance_v3(blocks, f"{block_type}_body", f"{block_type}_caption")
    footnote_pairs = __tie_up_category_by_distance_v3(blocks, f"{block_type}_body", f"{block_type}_footnote")
    results = []
    for pair in caption_pairs:
        entry = {
            f"{block_type}_body": pair["sub_bbox"],
            f"{block_type}_caption_list": pair["obj_bboxes"],
        }
        # Locate the footnote record for the same body block.
        sub_idx = pair["sub_idx"]
        footnote_match = next(f for f in footnote_pairs if f["sub_idx"] == sub_idx)
        entry[f"{block_type}_footnote_list"] = footnote_match["obj_bboxes"]
        results.append(entry)
    return results
def fix_two_layer_blocks(blocks, fix_type: Literal["image", "table"]):
    """Wrap body/caption/footnote groups into single two-layer blocks."""
    fixed_blocks = []
    for group in get_type_blocks(blocks, fix_type):
        body = group[f"{fix_type}_body"]
        captions = group[f"{fix_type}_caption_list"]
        footnotes = group[f"{fix_type}_footnote_list"]
        # Tag each sub-block with its concrete role.
        body["type"] = f"{fix_type}_body"
        for cap in captions:
            cap["type"] = f"{fix_type}_caption"
        for note in footnotes:
            note["type"] = f"{fix_type}_footnote"
        # Outer block carries the body's bbox/index; sub-blocks keep order
        # body -> captions -> footnotes.
        fixed_blocks.append({
            "type": fix_type,
            "bbox": body["bbox"],
            "blocks": [body, *captions, *footnotes],
            "index": body["index"],
        })
    return fixed_blocks
def fix_title_blocks(blocks):
    """Derive heading levels from leading '#' marks on title blocks.

    Only the first span of the first line is stripped, since markdown
    hashes can only appear at the very start of a title.
    """
    for block in blocks:
        if block["type"] == BlockType.TITLE:
            title_text = merge_para_with_text(block)
            block['level'] = count_leading_hashes(title_text)
            for line in block['lines']:
                for span in line['spans']:
                    span['content'] = strip_leading_hashes(span['content'])
                    break  # only the first span carries the hashes
                break  # only the first line matters
    return blocks
def count_leading_hashes(text):
    """Return the number of '#' characters at the very start of *text*."""
    count = 0
    for ch in text:
        if ch != '#':
            break
        count += 1
    return count
def strip_leading_hashes(text):
    """Drop a run of leading '#' marks plus the whitespace right after them."""
    if text.startswith('#'):
        return text.lstrip('#').lstrip()
    return text
def fix_text_blocks(blocks):
    """Merge text blocks that the model split across pages or columns.

    A trailing ``<|txt_contd|>`` marker on a block's final span means the
    text continues in the next block: the continuation's lines are folded
    into the current block and the donor block is emptied and flagged.
    """
    i = 0
    while i < len(blocks):
        block = blocks[i]
        lines = block["lines"]
        tail_span = None
        if lines and lines[-1]["spans"]:
            tail_span = lines[-1]["spans"][-1]
        if tail_span and tail_span['content'].endswith('<|txt_contd|>'):
            # Strip the marker regardless of whether a donor is found.
            tail_span['content'] = tail_span['content'][:-len('<|txt_contd|>')]
            # Skip over blocks whose lines were already merged away.
            donor_idx = i + 1
            while donor_idx < len(blocks) and blocks[donor_idx].get(SplitFlag.LINES_DELETED, False):
                donor_idx += 1
            if donor_idx < len(blocks):
                donor = blocks[donor_idx]
                # Fold the donor's lines into the current block, then empty
                # and flag the donor.
                block["lines"].extend(donor["lines"])
                donor["lines"] = []
                donor[SplitFlag.LINES_DELETED] = True
                # Re-check the same block: the merged tail may continue too.
                continue
        i += 1
    return blocks
\ No newline at end of file
from mineru.utils.config_reader import get_latex_delimiter_config
from mineru.utils.enum_class import MakeMode, BlockType, ContentType
# Resolve the LaTeX delimiter configuration once at import time; fall back
# to the standard $/$$ delimiters when no custom config is provided.
latex_delimiters_config = get_latex_delimiter_config()
default_delimiters = {
    'display': {'left': '$$', 'right': '$$'},
    'inline': {'left': '$', 'right': '$'}
}
delimiters = latex_delimiters_config if latex_delimiters_config else default_delimiters
# Cached delimiter strings used when stitching spans into markdown text.
display_left_delimiter = delimiters['display']['left']
display_right_delimiter = delimiters['display']['right']
inline_left_delimiter = delimiters['inline']['left']
inline_right_delimiter = delimiters['inline']['right']
def merge_para_with_text(para_block):
    """Concatenate all spans of a paragraph block into one text string.

    Inline equations are wrapped with the configured inline delimiters,
    interline equations with the display delimiters on their own lines.
    Text and inline-equation spans get a trailing space except after the
    final span of each line.
    """
    parts = []
    for line in para_block['lines']:
        spans = line['spans']
        for j, span in enumerate(spans):
            span_type = span['type']
            if span_type == ContentType.TEXT:
                content = span['content']
            elif span_type == ContentType.INLINE_EQUATION:
                content = f"{inline_left_delimiter}{span['content']}{inline_right_delimiter}"
            elif span_type == ContentType.INTERLINE_EQUATION:
                content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"
            else:
                content = ''
            if not content:
                continue
            if span_type == ContentType.INTERLINE_EQUATION:
                parts.append(content)
            elif j == len(spans) - 1:
                parts.append(content)
            else:
                parts.append(f'{content} ')
    return ''.join(parts)
def mk_blocks_to_markdown(para_blocks, make_mode, img_buket_path=''):
    """Render one page's paragraph blocks to a list of markdown strings.

    make_mode: MakeMode.MM_MD keeps images/tables (as links / inline HTML);
    MakeMode.NLP_MD drops image and table blocks entirely.
    """
    page_markdown = []
    for para_block in para_blocks:
        para_text = ''
        para_type = para_block['type']
        if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX, BlockType.INTERLINE_EQUATION]:
            para_text = merge_para_with_text(para_block)
        elif para_type == BlockType.TITLE:
            # Prefix the title with '#' repeated to its heading level.
            title_level = get_title_level(para_block)
            para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
        elif para_type == BlockType.IMAGE:
            if make_mode == MakeMode.NLP_MD:
                continue
            elif make_mode == MakeMode.MM_MD:
                # Check whether the image carries a footnote.
                has_image_footnote = any(block['type'] == BlockType.IMAGE_FOOTNOTE for block in para_block['blocks'])
                # With a footnote, order is caption -> body -> footnote;
                # otherwise body -> caption.
                if has_image_footnote:
                    for block in para_block['blocks']:  # 1st: image captions
                        if block['type'] == BlockType.IMAGE_CAPTION:
                            para_text += merge_para_with_text(block) + ' \n'
                    for block in para_block['blocks']:  # 2nd: image body
                        if block['type'] == BlockType.IMAGE_BODY:
                            for line in block['lines']:
                                for span in line['spans']:
                                    if span['type'] == ContentType.IMAGE:
                                        if span.get('image_path', ''):
                                            para_text += f"![]({img_buket_path}/{span['image_path']})"
                    for block in para_block['blocks']:  # 3rd: image footnotes
                        if block['type'] == BlockType.IMAGE_FOOTNOTE:
                            para_text += ' \n' + merge_para_with_text(block)
                else:
                    for block in para_block['blocks']:  # 1st: image body
                        if block['type'] == BlockType.IMAGE_BODY:
                            for line in block['lines']:
                                for span in line['spans']:
                                    if span['type'] == ContentType.IMAGE:
                                        if span.get('image_path', ''):
                                            para_text += f"![]({img_buket_path}/{span['image_path']})"
                    for block in para_block['blocks']:  # 2nd: image captions
                        if block['type'] == BlockType.IMAGE_CAPTION:
                            para_text += ' \n' + merge_para_with_text(block)
        elif para_type == BlockType.TABLE:
            if make_mode == MakeMode.NLP_MD:
                continue
            elif make_mode == MakeMode.MM_MD:
                for block in para_block['blocks']:  # 1st: table captions
                    if block['type'] == BlockType.TABLE_CAPTION:
                        para_text += merge_para_with_text(block) + ' \n'
                for block in para_block['blocks']:  # 2nd: table body
                    if block['type'] == BlockType.TABLE_BODY:
                        for line in block['lines']:
                            for span in line['spans']:
                                if span['type'] == ContentType.TABLE:
                                    # Prefer HTML produced by the table model;
                                    # fall back to the cropped image.
                                    if span.get('html', ''):
                                        para_text += f"\n{span['html']}\n"
                                    elif span.get('image_path', ''):
                                        para_text += f"![]({img_buket_path}/{span['image_path']})"
                for block in para_block['blocks']:  # 3rd: table footnotes
                    if block['type'] == BlockType.TABLE_FOOTNOTE:
                        para_text += '\n' + merge_para_with_text(block) + ' '
        # Skip blocks that rendered to nothing.
        if para_text.strip() == '':
            continue
        else:
            page_markdown.append(para_text.strip())
    return page_markdown
def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
    """Convert one paragraph block into a content-list dict.

    The resulting dict always carries ``page_idx``; its other keys depend
    on the block type (text / equation / image / table).
    """
    para_type = para_block['type']
    para_content = {}
    if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
        para_content = {
            'type': 'text',
            'text': merge_para_with_text(para_block),
        }
    elif para_type == BlockType.TITLE:
        para_content = {
            'type': 'text',
            'text': merge_para_with_text(para_block),
        }
        # Heading level 0 means "no explicit level": omit the key entirely.
        title_level = get_title_level(para_block)
        if title_level != 0:
            para_content['text_level'] = title_level
    elif para_type == BlockType.INTERLINE_EQUATION:
        para_content = {
            'type': 'equation',
            'text': merge_para_with_text(para_block),
            'text_format': 'latex',
        }
    elif para_type == BlockType.IMAGE:
        para_content = {'type': 'image', 'img_path': '', 'img_caption': [], 'img_footnote': []}
        for block in para_block['blocks']:
            if block['type'] == BlockType.IMAGE_BODY:
                for line in block['lines']:
                    for span in line['spans']:
                        if span['type'] == ContentType.IMAGE:
                            if span.get('image_path', ''):
                                para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
            if block['type'] == BlockType.IMAGE_CAPTION:
                para_content['img_caption'].append(merge_para_with_text(block))
            if block['type'] == BlockType.IMAGE_FOOTNOTE:
                para_content['img_footnote'].append(merge_para_with_text(block))
    elif para_type == BlockType.TABLE:
        para_content = {'type': 'table', 'img_path': '', 'table_caption': [], 'table_footnote': []}
        for block in para_block['blocks']:
            if block['type'] == BlockType.TABLE_BODY:
                for line in block['lines']:
                    for span in line['spans']:
                        if span['type'] == ContentType.TABLE:
                            # HTML body (from the table model) and fallback
                            # image path can both be present.
                            if span.get('html', ''):
                                para_content['table_body'] = f"{span['html']}"
                            if span.get('image_path', ''):
                                para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
            if block['type'] == BlockType.TABLE_CAPTION:
                para_content['table_caption'].append(merge_para_with_text(block))
            if block['type'] == BlockType.TABLE_FOOTNOTE:
                para_content['table_footnote'].append(merge_para_with_text(block))
    para_content['page_idx'] = page_idx
    return para_content
def union_make(pdf_info_dict: list,
               make_mode: str,
               img_buket_path: str = '',
               ):
    """Assemble per-page block data into markdown text or a content list.

    Returns a markdown string for MM_MD/NLP_MD modes, a list of content
    dicts for CONTENT_LIST mode, and None for any unknown mode.
    """
    markdown_modes = (MakeMode.MM_MD, MakeMode.NLP_MD)
    output_content = []
    for page_info in pdf_info_dict:
        page_blocks = page_info.get('para_blocks')
        if not page_blocks:
            continue
        page_idx = page_info.get('page_idx')
        if make_mode in markdown_modes:
            output_content.extend(mk_blocks_to_markdown(page_blocks, make_mode, img_buket_path))
        elif make_mode == MakeMode.CONTENT_LIST:
            output_content.extend(
                make_blocks_to_content_list(block, img_buket_path, page_idx)
                for block in page_blocks
            )
    if make_mode in markdown_modes:
        return '\n\n'.join(output_content)
    if make_mode == MakeMode.CONTENT_LIST:
        return output_content
    return None
def get_title_level(block):
    """Clamp a title block's heading level into the renderable range.

    Levels above 4 saturate at 4; levels below 1 collapse to 0 (meaning
    "no explicit heading level"); a missing level defaults to 1.
    """
    level = block.get('level', 1)
    if level > 4:
        return 4
    if level < 1:
        return 0
    return level
# Copyright (c) Opendatalab. All rights reserved.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment