import html
import cv2
from loguru import logger
from tqdm import tqdm
from collections import defaultdict
import numpy as np
from .model_init import AtomModelSingleton
from .model_list import AtomicModel
from ...utils.config_reader import get_formula_enable, get_table_enable
from ...utils.model_utils import crop_img, get_res_list_from_layout_res, clean_vram
from ...utils.ocr_utils import merge_det_boxes, update_det_boxes, sorted_boxes
from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence, get_rotate_crop_image
from ...utils.pdf_image_tools import get_crop_np_img
YOLO_LAYOUT_BASE_BATCH_SIZE = 1
MFD_BASE_BATCH_SIZE = 1
MFR_BASE_BATCH_SIZE = 16
OCR_DET_BASE_BATCH_SIZE = 16
TABLE_ORI_CLS_BATCH_SIZE = 16
TABLE_Wired_Wireless_CLS_BATCH_SIZE = 16
class BatchAnalyze:
def __init__(self, model_manager, batch_ratio: int, formula_enable, table_enable, enable_ocr_det_batch: bool = True):
self.batch_ratio = batch_ratio
self.formula_enable = get_formula_enable(formula_enable)
self.table_enable = get_table_enable(table_enable)
self.model_manager = model_manager
self.enable_ocr_det_batch = enable_ocr_det_batch
def __call__(self, images_with_extra_info: list) -> list:
if len(images_with_extra_info) == 0:
return []
images_layout_res = []
self.model = self.model_manager.get_model(
lang=None,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
atom_model_manager = AtomModelSingleton()
pil_images = [image for image, _, _ in images_with_extra_info]
np_images = [np.asarray(image) for image, _, _ in images_with_extra_info]
# doclayout_yolo
images_layout_res += self.model.layout_model.batch_predict(
pil_images, YOLO_LAYOUT_BASE_BATCH_SIZE
)
if self.formula_enable:
# formula detection
images_mfd_res = self.model.mfd_model.batch_predict(
np_images, MFD_BASE_BATCH_SIZE
)
# formula recognition
images_formula_list = self.model.mfr_model.batch_predict(
images_mfd_res,
np_images,
batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
)
mfr_count = 0
for image_index in range(len(np_images)):
images_layout_res[image_index] += images_formula_list[image_index]
mfr_count += len(images_formula_list[image_index])
# free GPU memory
clean_vram(self.model.device, vram_threshold=8)
ocr_res_list_all_page = []
table_res_list_all_page = []
for index in range(len(np_images)):
_, ocr_enable, _lang = images_with_extra_info[index]
layout_res = images_layout_res[index]
np_img = np_images[index]
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
get_res_list_from_layout_res(layout_res)
)
ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
'lang':_lang,
'ocr_enable':ocr_enable,
'np_img':np_img,
'single_page_mfdetrec_res':single_page_mfdetrec_res,
'layout_res':layout_res,
})
for table_res in table_res_list:
def get_crop_table_img(scale):
crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
bbox = (int(crop_xmin / scale), int(crop_ymin / scale), int(crop_xmax / scale), int(crop_ymax / scale))
return get_crop_np_img(bbox, np_img, scale=scale)
wireless_table_img = get_crop_table_img(scale = 1)
wired_table_img = get_crop_table_img(scale = 10/3)
table_res_list_all_page.append({'table_res':table_res,
'lang':_lang,
'table_img':wireless_table_img,
'wired_table_img':wired_table_img,
})
# table recognition
if self.table_enable:
# batch image-orientation classification
img_orientation_cls_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.ImgOrientationCls,
)
try:
img_orientation_cls_model.batch_predict(table_res_list_all_page,
det_batch_size=self.batch_ratio * OCR_DET_BASE_BATCH_SIZE,
batch_size=TABLE_ORI_CLS_BATCH_SIZE)
except Exception as e:
logger.warning(
f"Image orientation classification failed: {e}, using original image"
)
# wired/wireless table classification
table_cls_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.TableCls,
)
try:
table_cls_model.batch_predict(table_res_list_all_page,
batch_size=TABLE_Wired_Wireless_CLS_BATCH_SIZE)
except Exception as e:
logger.warning(
f"Table classification failed: {e}, using default model"
)
# table OCR detection, run sequentially
rec_img_lang_group = defaultdict(list)
det_ocr_engine = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
enable_merge_det_boxes=False,
)
for index, table_res_dict in enumerate(
tqdm(table_res_list_all_page, desc="Table-ocr det")
):
_lang = table_res_dict["lang"]
bgr_image = cv2.cvtColor(table_res_dict["table_img"], cv2.COLOR_RGB2BGR)
ocr_result = det_ocr_engine.ocr(bgr_image, rec=False)[0]
# build the dicts of crops to recognize (cropped_img, dt_box, table_id), grouped by language
for dt_box in ocr_result:
rec_img_lang_group[_lang].append(
{
"cropped_img": get_rotate_crop_image(
bgr_image, np.asarray(dt_box, dtype=np.float32)
),
"dt_box": np.asarray(dt_box, dtype=np.float32),
"table_id": index,
}
)
# OCR recognition, batched per language
for _lang, rec_img_list in rec_img_lang_group.items():
ocr_engine = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
lang=_lang,
enable_merge_det_boxes=False,
)
cropped_img_list = [item["cropped_img"] for item in rec_img_list]
ocr_res_list = ocr_engine.ocr(cropped_img_list, det=False, tqdm_enable=True, tqdm_desc=f"Table-ocr rec {_lang}")[0]
# write recognition results back onto their tables by table_id
for img_dict, ocr_res in zip(rec_img_list, ocr_res_list):
table_res_list_all_page[img_dict["table_id"]].setdefault("ocr_result", []).append(
[img_dict["dt_box"], html.escape(ocr_res[0]), ocr_res[1]]
)
clean_vram(self.model.device, vram_threshold=8)
# run the wireless-table model on all tables first, then re-run the tables classified as wired with the wired-table model
wireless_table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.WirelessTable,
)
wireless_table_model.batch_predict(table_res_list_all_page)
# collect the wired tables for a separate prediction pass
wired_table_res_list = []
for table_res_dict in table_res_list_all_page:
# logger.debug(f"Table classification result: {table_res_dict["table_res"]["cls_label"]} with confidence {table_res_dict["table_res"]["cls_score"]}")
if (
(table_res_dict["table_res"]["cls_label"] == AtomicModel.WirelessTable and table_res_dict["table_res"]["cls_score"] < 0.9)
or table_res_dict["table_res"]["cls_label"] == AtomicModel.WiredTable
):
wired_table_res_list.append(table_res_dict)
del table_res_dict["table_res"]["cls_label"]
del table_res_dict["table_res"]["cls_score"]
if wired_table_res_list:
for table_res_dict in tqdm(
wired_table_res_list, desc="Table-wired Predict"
):
if not table_res_dict.get("ocr_result", None):
continue
wired_table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.WiredTable,
lang=table_res_dict["lang"],
)
table_res_dict["table_res"]["html"] = wired_table_model.predict(
table_res_dict["wired_table_img"],
table_res_dict["ocr_result"],
table_res_dict["table_res"].get("html", None)
)
# clean up table HTML
for table_res_dict in table_res_list_all_page:
html_code = table_res_dict["table_res"].get("html", "") or ""
# keep only the segment from the first <table> to the last </table>, if both tags are present
if "<table>" in html_code and "</table>" in html_code:
start_index = html_code.find("<table>")
end_index = html_code.rfind("</table>") + len("</table>")
table_res_dict["table_res"]["html"] = html_code[start_index:end_index]
# OCR det
if self.enable_ocr_det_batch:
# batch mode: group crops by language and resolution
# collect every cropped image that needs OCR detection
all_cropped_images_info = []
for ocr_res_list_dict in ocr_res_list_all_page:
_lang = ocr_res_list_dict['lang']
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['np_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# RGB -> BGR for OpenCV
bgr_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
all_cropped_images_info.append((
bgr_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
))
# group by language
lang_groups = defaultdict(list)
for crop_info in all_cropped_images_info:
lang = crop_info[5]
lang_groups[lang].append(crop_info)
# for each language, bucket crops by resolution and batch-process
for lang, lang_crop_list in lang_groups.items():
if not lang_crop_list:
continue
# logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")
# fetch the OCR model for this language
ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
det_db_box_thresh=0.3,
lang=lang
)
# bucket crops by resolution; padding to a shared size happens per group below
# RESOLUTION_GROUP_STRIDE = 32
RESOLUTION_GROUP_STRIDE = 64  # stride used to bucket resolutions
resolution_groups = defaultdict(list)
for crop_info in lang_crop_list:
cropped_img = crop_info[0]
h, w = cropped_img.shape[:2]
# a coarse bucket size keeps the number of groups small
# normalize both dimensions up to a multiple of RESOLUTION_GROUP_STRIDE
normalized_h = ((h + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
normalized_w = ((w + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
group_key = (normalized_h, normalized_w)
resolution_groups[group_key].append(crop_info)
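# Worked example of the bucketing (hypothetical crop sizes): with
# RESOLUTION_GROUP_STRIDE = 64, a 610x420 crop gives
# normalized_h = ((610 + 64) // 64) * 64 = 640 and
# normalized_w = ((420 + 64) // 64) * 64 = 448, so group_key == (640, 448);
# every crop that rounds to the same pair is padded and batched together.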
# batch-process each resolution group
for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
# target size: the group's max dimensions, rounded up to the stride
max_h = max(crop_info[0].shape[0] for crop_info in group_crops)
max_w = max(crop_info[0].shape[1] for crop_info in group_crops)
target_h = ((max_h + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
target_w = ((max_w + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
# pad every image in the group to the shared target size
batch_images = []
for crop_info in group_crops:
img = crop_info[0]
h, w = img.shape[:2]
# white canvas at the target size
padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
# paste the original image into the top-left corner
padded_img[:h, :w] = img
batch_images.append(padded_img)
# batched detection
det_batch_size = min(len(batch_images), self.batch_ratio * OCR_DET_BASE_BATCH_SIZE)  # cap the batch size
# logger.debug(f"OCR-det batch: {det_batch_size} images, target size: {target_h}x{target_w}")
batch_results = ocr_model.text_detector.batch_predict(batch_images, det_batch_size)
# unpack the batch results
for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
bgr_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
if dt_boxes is not None and len(dt_boxes) > 0:
# apply the key post-processing steps from the single-image OCR path
# 1. sort detection boxes top-to-bottom, left-to-right
dt_boxes_sorted = sorted_boxes(dt_boxes)
# 2. merge adjacent detection boxes
dt_boxes_merged = merge_det_boxes(dt_boxes_sorted) if dt_boxes_sorted else []
# 3. update detection boxes around formula regions (essential step)
if dt_boxes_merged and adjusted_mfdetrec_res:
dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
else:
dt_boxes_final = dt_boxes_merged
# convert to the OCR result format
ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]
if ocr_res:
ocr_result_list = get_ocr_result_list(
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], bgr_image, _lang
)
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
else:
# original single-image mode
for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
# Process each area that requires OCR processing
_lang = ocr_res_list_dict['lang']
# Get OCR results for this language's images
ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=_lang
)
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['np_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# OCR-det
bgr_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
ocr_res = ocr_model.ocr(
bgr_image, mfd_res=adjusted_mfdetrec_res, rec=False
)[0]
# Integration results
if ocr_res:
ocr_result_list = get_ocr_result_list(
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],bgr_image, _lang
)
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
# OCR rec
# Create dictionaries to store items by language
need_ocr_lists_by_lang = {} # Dict of lists for each language
img_crop_lists_by_lang = {} # Dict of lists for each language
for layout_res in images_layout_res:
for layout_res_item in layout_res:
if layout_res_item['category_id'] in [15]:
if 'np_img' in layout_res_item and 'lang' in layout_res_item:
lang = layout_res_item['lang']
# Initialize lists for this language if not exist
if lang not in need_ocr_lists_by_lang:
need_ocr_lists_by_lang[lang] = []
img_crop_lists_by_lang[lang] = []
# Add to the appropriate language-specific lists
need_ocr_lists_by_lang[lang].append(layout_res_item)
img_crop_lists_by_lang[lang].append(layout_res_item['np_img'])
# Remove the fields after adding to lists
layout_res_item.pop('np_img')
layout_res_item.pop('lang')
if len(img_crop_lists_by_lang) > 0:
# Process OCR by language
total_processed = 0
# Process each language separately
for lang, img_crop_list in img_crop_lists_by_lang.items():
if len(img_crop_list) > 0:
# Get OCR results for this language's images
ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
det_db_box_thresh=0.3,
lang=lang
)
ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
# Verify we have matching counts
assert len(ocr_res_list) == len(
need_ocr_lists_by_lang[lang]), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_lists_by_lang[lang])} for lang: {lang}'
# Process OCR results for this language
for index, layout_res_item in enumerate(need_ocr_lists_by_lang[lang]):
ocr_text, ocr_score = ocr_res_list[index]
layout_res_item['text'] = ocr_text
layout_res_item['score'] = float(f"{ocr_score:.3f}")
if ocr_score < OcrConfidence.min_confidence:
layout_res_item['category_id'] = 16
else:
layout_res_bbox = [layout_res_item['poly'][0], layout_res_item['poly'][1],
layout_res_item['poly'][4], layout_res_item['poly'][5]]
layout_res_width = layout_res_bbox[2] - layout_res_bbox[0]
layout_res_height = layout_res_bbox[3] - layout_res_bbox[1]
if ocr_text in ['(204号', '(20', '(2', '(2号', '(20号', '号', '(204'] and ocr_score < 0.8 and layout_res_width < layout_res_height:
layout_res_item['category_id'] = 16
total_processed += len(img_crop_list)
return images_layout_res
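# Usage sketch (illustrative only, assuming a configured ModelSingleton named
# `model_manager` and a PIL page image `pil_img`):
#
#   analyzer = BatchAnalyze(model_manager, batch_ratio=1,
#                           formula_enable=True, table_enable=True)
#   layout_res_per_page = analyzer([(pil_img, True, 'ch')])
#
# Input items are (PIL.Image, ocr_enable, lang) triples; the output holds one
# layout-result list per page. In the pipeline this class is driven by
# batch_image_analyze() rather than called directly.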
import os
import torch
from loguru import logger
from .model_list import AtomicModel
from ...model.layout.doclayoutyolo import DocLayoutYOLOModel
from ...model.mfd.yolo_v8 import YOLOv8MFDModel
from ...model.mfr.unimernet.Unimernet import UnimernetModel
from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
# from ...model.table.rec.RapidTable import RapidTableModel
from ...model.table.rec.slanet_plus.main import RapidTableModel
from ...model.table.rec.unet_table.main import UnetTableModel
from ...utils.enum_class import ModelPath
from ...utils.models_download_utils import auto_download_and_get_model_root_path
def img_orientation_cls_model_init():
atom_model_manager = AtomModelSingleton()
ocr_engine = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
lang="ch_lite",
enable_merge_det_boxes=False
)
cls_model = PaddleOrientationClsModel(ocr_engine)
return cls_model
def table_cls_model_init():
return PaddleTableClsModel()
def wired_table_model_init(lang=None):
atom_model_manager = AtomModelSingleton()
ocr_engine = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
lang=lang,
enable_merge_det_boxes=False
)
table_model = UnetTableModel(ocr_engine)
return table_model
def wireless_table_model_init(lang=None):
atom_model_manager = AtomModelSingleton()
ocr_engine = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
lang=lang,
enable_merge_det_boxes=False
)
table_model = RapidTableModel(ocr_engine)
return table_model
def mfd_model_init(weight, device='cpu'):
if str(device).startswith('npu'):
device = torch.device(device)
mfd_model = YOLOv8MFDModel(weight, device)
return mfd_model
def mfr_model_init(weight_dir, device='cpu'):
mfr_model = UnimernetModel(weight_dir, device)
return mfr_model
def doclayout_yolo_model_init(weight, device='cpu'):
if str(device).startswith('npu'):
device = torch.device(device)
model = DocLayoutYOLOModel(weight, device)
return model
def ocr_model_init(det_db_box_thresh=0.3,
lang=None,
det_db_unclip_ratio=1.8,
enable_merge_det_boxes=True
):
ocr_kwargs = dict(
det_db_box_thresh=det_db_box_thresh,
use_dilation=True,
det_db_unclip_ratio=det_db_unclip_ratio,
enable_merge_det_boxes=enable_merge_det_boxes,
)
if lang is not None and lang != '':
ocr_kwargs['lang'] = lang
model = PytorchPaddleOCR(**ocr_kwargs)
return model
class AtomModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_atom_model(self, atom_model_name: str, **kwargs):
lang = kwargs.get('lang', None)
if atom_model_name in [AtomicModel.WiredTable, AtomicModel.WirelessTable]:
key = (
atom_model_name,
lang
)
elif atom_model_name in [AtomicModel.OCR]:
key = (
atom_model_name,
kwargs.get('det_db_box_thresh', 0.3),
lang,
kwargs.get('det_db_unclip_ratio', 1.8),
kwargs.get('enable_merge_det_boxes', True)
)
else:
key = atom_model_name
if key not in self._models:
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
return self._models[key]
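# Cache-key sketch: OCR engines are cached per
# (name, det_db_box_thresh, lang, det_db_unclip_ratio, enable_merge_det_boxes),
# so repeated lookups with identical kwargs return the same instance:
#
#   mgr = AtomModelSingleton()
#   m1 = mgr.get_atom_model(atom_model_name=AtomicModel.OCR,
#                           det_db_box_thresh=0.3, lang='ch')
#   m2 = mgr.get_atom_model(atom_model_name=AtomicModel.OCR,
#                           det_db_box_thresh=0.3, lang='ch')
#   assert m1 is m2  # same key -> same cached engine
#
# Changing any keyed kwarg (e.g. det_db_box_thresh=0.5) creates and caches a
# separate engine.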
def atom_model_init(model_name: str, **kwargs):
atom_model = None
if model_name == AtomicModel.Layout:
atom_model = doclayout_yolo_model_init(
kwargs.get('doclayout_yolo_weights'),
kwargs.get('device')
)
elif model_name == AtomicModel.MFD:
atom_model = mfd_model_init(
kwargs.get('mfd_weights'),
kwargs.get('device')
)
elif model_name == AtomicModel.MFR:
atom_model = mfr_model_init(
kwargs.get('mfr_weight_dir'),
kwargs.get('device')
)
elif model_name == AtomicModel.OCR:
atom_model = ocr_model_init(
kwargs.get('det_db_box_thresh', 0.3),
kwargs.get('lang'),
kwargs.get('det_db_unclip_ratio', 1.8),
kwargs.get('enable_merge_det_boxes', True)
)
elif model_name == AtomicModel.WirelessTable:
atom_model = wireless_table_model_init(
kwargs.get('lang'),
)
elif model_name == AtomicModel.WiredTable:
atom_model = wired_table_model_init(
kwargs.get('lang'),
)
elif model_name == AtomicModel.TableCls:
atom_model = table_cls_model_init()
elif model_name == AtomicModel.ImgOrientationCls:
atom_model = img_orientation_cls_model_init()
else:
logger.error('model name not allowed')
exit(1)
if atom_model is None:
logger.error('model init failed')
exit(1)
else:
return atom_model
class MineruPipelineModel:
def __init__(self, **kwargs):
self.formula_config = kwargs.get('formula_config')
self.apply_formula = self.formula_config.get('enable', True)
self.table_config = kwargs.get('table_config')
self.apply_table = self.table_config.get('enable', True)
self.lang = kwargs.get('lang', None)
self.device = kwargs.get('device', 'cpu')
logger.info(
'DocAnalysis init, this may take some time...'
)
atom_model_manager = AtomModelSingleton()
if self.apply_formula:
# init the formula detection (MFD) model
self.mfd_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFD,
mfd_weights=str(
os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd), ModelPath.yolo_v8_mfd)
),
device=self.device,
)
# init the formula recognition (MFR) model
mfr_weight_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.unimernet_small), ModelPath.unimernet_small)
self.mfr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFR,
mfr_weight_dir=mfr_weight_dir,
device=self.device,
)
# init the layout model
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
doclayout_yolo_weights=str(
os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
),
device=self.device,
)
# init OCR
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
det_db_box_thresh=0.3,
lang=self.lang
)
# init table model
if self.apply_table:
self.wired_table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.WiredTable,
lang=self.lang,
)
self.wireless_table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.WirelessTable,
lang=self.lang,
)
self.table_cls_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.TableCls,
)
self.img_orientation_cls_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.ImgOrientationCls,
lang=self.lang,
)
logger.info('DocAnalysis init done!')
# Copyright (c) Opendatalab. All rights reserved.
import os
import time
from loguru import logger
from tqdm import tqdm
from mineru.utils.config_reader import get_device, get_llm_aided_config, get_formula_enable
from mineru.backend.pipeline.model_init import AtomModelSingleton
from mineru.backend.pipeline.para_split import para_split
from mineru.utils.block_pre_proc import prepare_block_bboxes, process_groups
from mineru.utils.block_sort import sort_blocks_by_bbox
from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.enum_class import ContentType
from mineru.utils.llm_aided import llm_aided_title
from mineru.utils.model_utils import clean_memory
from mineru.backend.pipeline.pipeline_magic_model import MagicModel
from mineru.utils.ocr_utils import OcrConfidence
from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans
from mineru.utils.span_pre_proc import remove_outside_spans, remove_overlaps_low_confidence_spans, \
remove_overlaps_min_spans, txt_spans_extract
from mineru.utils.table_merge import merge_table
from mineru.version import __version__
from mineru.utils.hash_utils import bytes_md5
def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, ocr_enable=False, formula_enabled=True):
scale = image_dict["scale"]
page_pil_img = image_dict["img_pil"]
# page_img_md5 = str_md5(image_dict["img_base64"])
page_img_md5 = bytes_md5(page_pil_img.tobytes())
page_w, page_h = map(int, page.get_size())
magic_model = MagicModel(page_model_info, scale)
"""从magic_model对象中获取后面会用到的区块信息"""
discarded_blocks = magic_model.get_discarded()
text_blocks = magic_model.get_text_blocks()
title_blocks = magic_model.get_title_blocks()
inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations()
img_groups = magic_model.get_imgs()
table_groups = magic_model.get_tables()
"""对image和table的区块分组"""
img_body_blocks, img_caption_blocks, img_footnote_blocks, maybe_text_image_blocks = process_groups(
img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
)
table_body_blocks, table_caption_blocks, table_footnote_blocks, _ = process_groups(
table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
)
"""获取所有的spans信息"""
spans = magic_model.get_all_spans()
"""某些图可能是文本块,通过简单的规则判断一下"""
if len(maybe_text_image_blocks) > 0:
for block in maybe_text_image_blocks:
should_add_to_text_blocks = False
if ocr_enable:
# find text spans overlapping the current block
span_in_block_list = [
span for span in spans
if span['type'] == 'text' and
calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block['bbox']) > 0.7
]
if len(span_in_block_list) > 0:
# total area of those spans
spans_area = sum(
(span['bbox'][2] - span['bbox'][0]) * (span['bbox'][3] - span['bbox'][1])
for span in span_in_block_list
)
# block area
block_area = (block['bbox'][2] - block['bbox'][0]) * (block['bbox'][3] - block['bbox'][1])
# qualifies as a text image if the spans cover enough of the block
if block_area > 0 and spans_area / block_area > 0.25:
should_add_to_text_blocks = True
# route the block to the appropriate list
if should_add_to_text_blocks:
block.pop('group_id', None)  # drop group_id
text_blocks.append(block)
else:
img_body_blocks.append(block)
"""将所有区块的bbox整理到一起"""
if formula_enabled:
interline_equation_blocks = []
if len(interline_equation_blocks) > 0:
for block in interline_equation_blocks:
spans.append({
"type": ContentType.INTERLINE_EQUATION,
'score': block['score'],
"bbox": block['bbox'],
"content": "",
})
all_bboxes, all_discarded_blocks, footnote_blocks = prepare_block_bboxes(
img_body_blocks, img_caption_blocks, img_footnote_blocks,
table_body_blocks, table_caption_blocks, table_footnote_blocks,
discarded_blocks,
text_blocks,
title_blocks,
interline_equation_blocks,
page_w,
page_h,
)
else:
all_bboxes, all_discarded_blocks, footnote_blocks = prepare_block_bboxes(
img_body_blocks, img_caption_blocks, img_footnote_blocks,
table_body_blocks, table_caption_blocks, table_footnote_blocks,
discarded_blocks,
text_blocks,
title_blocks,
interline_equations,
page_w,
page_h,
)
"""在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span"""
"""顺便删除大水印并保留abandon的span"""
spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks)
"""删除重叠spans中置信度较低的那些"""
spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
"""删除重叠spans中较小的那些"""
spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
"""根据parse_mode,构造spans,主要是文本类的字符填充"""
if ocr_enable:
pass
else:
"""使用新版本的混合ocr方案."""
spans = txt_spans_extract(page, spans, page_pil_img, scale, all_bboxes, all_discarded_blocks)
"""先处理不需要排版的discarded_blocks"""
discarded_block_with_spans, spans = fill_spans_in_blocks(
all_discarded_blocks, spans, 0.4
)
fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
"""如果当前页面没有有效的bbox则跳过"""
if len(all_bboxes) == 0:
return None
"""对image/table/interline_equation截图"""
for span in spans:
if span['type'] in [ContentType.IMAGE, ContentType.TABLE, ContentType.INTERLINE_EQUATION]:
span = cut_image_and_table(
span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale
)
"""span填充进block"""
block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
"""对block进行fix操作"""
fix_blocks = fix_block_spans(block_with_spans)
"""对block进行排序"""
sorted_blocks = sort_blocks_by_bbox(fix_blocks, page_w, page_h, footnote_blocks)
"""构造page_info"""
page_info = make_page_info_dict(sorted_blocks, page_index, page_w, page_h, fix_discarded_blocks)
return page_info
def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=None, ocr_enable=False, formula_enabled=True):
middle_json = {"pdf_info": [], "_backend":"pipeline", "_version_name": __version__}
formula_enabled = get_formula_enable(formula_enabled)
for page_index, page_model_info in tqdm(enumerate(model_list), total=len(model_list), desc="Processing pages"):
page = pdf_doc[page_index]
image_dict = images_list[page_index]
page_info = page_model_info_to_page_info(
page_model_info, image_dict, page, image_writer, page_index, ocr_enable=ocr_enable, formula_enabled=formula_enabled
)
if page_info is None:
page_w, page_h = map(int, page.get_size())
page_info = make_page_info_dict([], page_index, page_w, page_h, [])
middle_json["pdf_info"].append(page_info)
"""后置ocr处理"""
need_ocr_list = []
img_crop_list = []
text_block_list = []
for page_info in middle_json["pdf_info"]:
for block in page_info['preproc_blocks']:
if block['type'] in ['table', 'image']:
for sub_block in block['blocks']:
if sub_block['type'] in ['image_caption', 'image_footnote', 'table_caption', 'table_footnote']:
text_block_list.append(sub_block)
elif block['type'] in ['text', 'title']:
text_block_list.append(block)
for block in page_info['discarded_blocks']:
text_block_list.append(block)
for block in text_block_list:
for line in block['lines']:
for span in line['spans']:
if 'np_img' in span:
need_ocr_list.append(span)
img_crop_list.append(span['np_img'])
span.pop('np_img')
if len(img_crop_list) > 0:
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
det_db_box_thresh=0.3,
lang=lang
)
ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
assert len(ocr_res_list) == len(
need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
for index, span in enumerate(need_ocr_list):
ocr_text, ocr_score = ocr_res_list[index]
if ocr_score > OcrConfidence.min_confidence:
span['content'] = ocr_text
span['score'] = float(f"{ocr_score:.3f}")
else:
span['content'] = ''
span['score'] = 0.0
"""分段"""
para_split(middle_json["pdf_info"])
"""表格跨页合并"""
merge_table(middle_json["pdf_info"])
"""llm优化"""
llm_aided_config = get_llm_aided_config()
if llm_aided_config is not None:
"""标题优化"""
title_aided_config = llm_aided_config.get('title_aided', None)
if title_aided_config is not None:
if title_aided_config.get('enable', False):
llm_aided_title_start_time = time.time()
llm_aided_title(middle_json["pdf_info"], title_aided_config)
logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
"""清理内存"""
pdf_doc.close()
if os.getenv('MINERU_DONOT_CLEAN_MEM') is None and len(model_list) >= 10:
clean_memory(get_device())
return middle_json
def make_page_info_dict(blocks, page_id, page_w, page_h, discarded_blocks):
return_dict = {
'preproc_blocks': blocks,
'page_idx': page_id,
'page_size': [page_w, page_h],
'discarded_blocks': discarded_blocks,
}
return return_dict
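# Example (illustrative values): make_page_info_dict([], 0, 612, 792, [])
# returns
#   {'preproc_blocks': [], 'page_idx': 0, 'page_size': [612, 792],
#    'discarded_blocks': []}
# which is exactly the empty-page fallback used in result_to_middle_json.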
class AtomicModel:
Layout = "layout"
MFD = "mfd"
MFR = "mfr"
OCR = "ocr"
WirelessTable = "wireless_table"
WiredTable = "wired_table"
TableCls = "table_cls"
ImgOrientationCls = "img_ori_cls"
import copy
from loguru import logger
from mineru.utils.enum_class import ContentType, BlockType, SplitFlag
from mineru.utils.language import detect_lang
LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';')
LIST_END_FLAG = ('.', '。', ';', ';')
class ListLineTag:
IS_LIST_START_LINE = 'is_list_start_line'
IS_LIST_END_LINE = 'is_list_end_line'
def __process_blocks(blocks):
# Preprocess all blocks:
# 1. split blocks into groups at title and interline_equation blocks
# 2. reset each bbox boundary from its line info
result = []
current_group = []
for i in range(len(blocks)):
current_block = blocks[i]
# if the current block is a text block
if current_block['type'] == 'text':
current_block['bbox_fs'] = copy.deepcopy(current_block['bbox'])
if 'lines' in current_block and len(current_block['lines']) > 0:
current_block['bbox_fs'] = [
min([line['bbox'][0] for line in current_block['lines']]),
min([line['bbox'][1] for line in current_block['lines']]),
max([line['bbox'][2] for line in current_block['lines']]),
max([line['bbox'][3] for line in current_block['lines']]),
]
current_group.append(current_block)
# check whether a next block exists
if i + 1 < len(blocks):
next_block = blocks[i + 1]
# if the next block is a title or interline_equation, close the current group
if next_block['type'] in ['title', 'interline_equation']:
result.append(current_group)
current_group = []
# handle the final group
if current_group:
result.append(current_group)
return result
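# Grouping sketch (hypothetical block sequence): given
#   [text1, text2, title, text3, interline_equation, text4]
# only text blocks are collected, and a group is closed whenever the *next*
# block is a title or interline_equation, so the result is
#   [[text1, text2], [text3], [text4]]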
def __is_list_or_index_block(block):
# A block is a list block if it satisfies all of one of these feature sets:
# 1. multiple lines; 2. several lines flush with the left edge; 3. several lines not flush on the right (ragged edge)
# 1. multiple lines; 2. several lines flush with the left edge; 3. several lines ending with an end flag
# 1. multiple lines; 2. several lines flush with the left edge; 3. several lines not flush on the left
# An index block is a special kind of list block.
# A block is an index block if it satisfies all of:
# 1. multiple lines; 2. several lines flush on both edges; 3. every line starts or ends with a digit
if len(block['lines']) >= 2:
first_line = block['lines'][0]
line_height = first_line['bbox'][3] - first_line['bbox'][1]
block_weight = block['bbox_fs'][2] - block['bbox_fs'][0]
block_height = block['bbox_fs'][3] - block['bbox_fs'][1]
page_weight, page_height = block['page_size']
left_close_num = 0
left_not_close_num = 0
right_not_close_num = 0
right_close_num = 0
lines_text_list = []
center_close_num = 0
external_sides_not_close_num = 0
multiple_para_flag = False
last_line = block['lines'][-1]
if page_weight == 0:
block_weight_radio = 0
else:
block_weight_radio = block_weight / page_weight
# logger.info(f"block_weight_radio: {block_weight_radio}")
# first line indented on the left, last line flush left but short of the right edge (the first line's right edge may be loose): treat as multiple paragraphs
if (
first_line['bbox'][0] - block['bbox_fs'][0] > line_height / 2
and abs(last_line['bbox'][0] - block['bbox_fs'][0]) < line_height / 2
and block['bbox_fs'][2] - last_line['bbox'][2] > line_height
):
multiple_para_flag = True
block_text = ''
for line in block['lines']:
line_text = ''
for span in line['spans']:
span_type = span['type']
if span_type == ContentType.TEXT:
line_text += span['content'].strip()
# append every line's text, including empty lines, to stay aligned with block['lines']
lines_text_list.append(line_text)
block_text = ''.join(lines_text_list)
block_lang = detect_lang(block_text)
# logger.info(f"block_lang: {block_lang}")
for line in block['lines']:
line_mid_x = (line['bbox'][0] + line['bbox'][2]) / 2
block_mid_x = (block['bbox_fs'][0] + block['bbox_fs'][2]) / 2
if (
line['bbox'][0] - block['bbox_fs'][0] > 0.7 * line_height
and block['bbox_fs'][2] - line['bbox'][2] > 0.7 * line_height
):
external_sides_not_close_num += 1
if abs(line_mid_x - block_mid_x) < line_height / 2:
center_close_num += 1
# count lines flush with the left edge; "flush" means abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
left_close_num += 1
elif line['bbox'][0] - block['bbox_fs'][0] > line_height:
left_not_close_num += 1
# check whether the right edge is flush
if abs(block['bbox_fs'][2] - line['bbox'][2]) < line_height:
right_close_num += 1
else:
# CJK-like scripts have no overlong words, so one threshold fits all
if block_lang in ['zh', 'ja', 'ko']:
closed_area = 0.26 * block_weight
else:
# when the right edge is not flush, require a real gap; roughly 0.3 of the block width as a rule-of-thumb threshold
# wide blocks can take a smaller threshold, narrow blocks need a larger one
if block_weight_radio >= 0.5:
closed_area = 0.26 * block_weight
else:
closed_area = 0.36 * block_weight
if block['bbox_fs'][2] - line['bbox'][2] > closed_area:
right_not_close_num += 1
# check whether more than 80% of the lines end with a LIST_END_FLAG
line_end_flag = False
# check whether more than 80% of the lines start, or end, with a digit
line_num_flag = False
num_start_count = 0
num_end_count = 0
flag_end_count = 0
if len(lines_text_list) > 0:
for line_text in lines_text_list:
if len(line_text) > 0:
if line_text[-1] in LIST_END_FLAG:
flag_end_count += 1
if line_text[0].isdigit():
num_start_count += 1
if line_text[-1].isdigit():
num_end_count += 1
if (
num_start_count / len(lines_text_list) >= 0.8
or num_end_count / len(lines_text_list) >= 0.8
):
line_num_flag = True
if flag_end_count / len(lines_text_list) >= 0.8:
line_end_flag = True
# some tables of contents are not flush on the right; currently a block counts as an index if either edge is fully flush and the digit rule holds
if (
left_close_num / len(block['lines']) >= 0.8
or right_close_num / len(block['lines']) >= 0.8
) and line_num_flag:
for line in block['lines']:
line[ListLineTag.IS_LIST_START_LINE] = True
return BlockType.INDEX
# special case: a fully centered list where every line breaks; multiple lines, most lines loose on both sides, and the line mid-x values nearly coincide
# with an extra constraint on the block's aspect ratio
elif (
external_sides_not_close_num >= 2
and center_close_num == len(block['lines'])
and external_sides_not_close_num / len(block['lines']) >= 0.5
and block_height / block_weight > 0.4
):
for line in block['lines']:
line[ListLineTag.IS_LIST_START_LINE] = True
return BlockType.LIST
elif (
left_close_num >= 2
and (right_not_close_num >= 2 or line_end_flag or left_not_close_num >= 2)
and not multiple_para_flag
# and block_weight_radio > 0.27
):
# special un-indented list: every line is flush left, so the right-side gap marks item ends
if left_close_num / len(block['lines']) > 0.8:
# short-item list: each item is a single line and all lines are flush left
if flag_end_count == 0 and right_close_num / len(block['lines']) < 0.5:
for line in block['lines']:
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
line[ListLineTag.IS_LIST_START_LINE] = True
# most line items carry an end flag; split items at those flags
elif line_end_flag:
for i, line in enumerate(block['lines']):
if (
len(lines_text_list[i]) > 0
and lines_text_list[i][-1] in LIST_END_FLAG
):
line[ListLineTag.IS_LIST_END_LINE] = True
if i + 1 < len(block['lines']):
block['lines'][i + 1][
ListLineTag.IS_LIST_START_LINE
] = True
# items have neither end flags nor indentation; use the right-side gap to mark item ends
else:
line_start_flag = False
for i, line in enumerate(block['lines']):
if line_start_flag:
line[ListLineTag.IS_LIST_START_LINE] = True
line_start_flag = False
if (
abs(block['bbox_fs'][2] - line['bbox'][2])
> 0.1 * block_weight
):
line[ListLineTag.IS_LIST_END_LINE] = True
line_start_flag = True
# an indented ordered list: start lines are not flush left and begin with a digit; end lines end with a LIST_END_FLAG and match the start-line count
elif num_start_count >= 2 and num_start_count == flag_end_count:
for i, line in enumerate(block['lines']):
if len(lines_text_list[i]) > 0:
if lines_text_list[i][0].isdigit():
line[ListLineTag.IS_LIST_START_LINE] = True
if lines_text_list[i][-1] in LIST_END_FLAG:
line[ListLineTag.IS_LIST_END_LINE] = True
else:
# the normal indented-list case
for line in block['lines']:
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
line[ListLineTag.IS_LIST_START_LINE] = True
if abs(block['bbox_fs'][2] - line['bbox'][2]) > line_height:
line[ListLineTag.IS_LIST_END_LINE] = True
return BlockType.LIST
else:
return BlockType.TEXT
else:
return BlockType.TEXT
def __merge_2_text_blocks(block1, block2):
if len(block1['lines']) > 0:
first_line = block1['lines'][0]
line_height = first_line['bbox'][3] - first_line['bbox'][1]
block1_weight = block1['bbox'][2] - block1['bbox'][0]
block2_weight = block2['bbox'][2] - block2['bbox'][0]
min_block_weight = min(block1_weight, block2_weight)
if abs(block1['bbox_fs'][0] - first_line['bbox'][0]) < line_height / 2:
last_line = block2['lines'][-1]
if len(last_line['spans']) > 0:
last_span = last_line['spans'][-1]
line_height = last_line['bbox'][3] - last_line['bbox'][1]
if len(first_line['spans']) > 0:
first_span = first_line['spans'][0]
if len(first_span['content']) > 0:
span_start_with_num = first_span['content'][0].isdigit()
span_start_with_big_char = first_span['content'][0].isupper()
if (
# block2's last line ends within line_height of block2's right edge
abs(block2['bbox_fs'][2] - last_line['bbox'][2]) < line_height
# block2's last span does not end with a stop character
and not last_span['content'].endswith(LINE_STOP_FLAG)
# do not merge when the two blocks' widths differ by more than 2x
and abs(block1_weight - block2_weight) < min_block_weight
# block1's first character is not a digit
and not span_start_with_num
# block1's first character is not an uppercase letter
and not span_start_with_big_char
):
if block1['page_num'] != block2['page_num']:
for line in block1['lines']:
for span in line['spans']:
span[SplitFlag.CROSS_PAGE] = True
block2['lines'].extend(block1['lines'])
block1['lines'] = []
block1[SplitFlag.LINES_DELETED] = True
return block1, block2
def __merge_2_list_blocks(block1, block2):
if block1['page_num'] != block2['page_num']:
for line in block1['lines']:
for span in line['spans']:
span[SplitFlag.CROSS_PAGE] = True
block2['lines'].extend(block1['lines'])
block1['lines'] = []
block1[SplitFlag.LINES_DELETED] = True
return block1, block2
def __is_list_group(text_blocks_group):
# A list group is one where every block in the group satisfies:
# 1. at most 3 lines per block; 2. left edges roughly aligned (rule left out for now to keep the logic simple)
for block in text_blocks_group:
if len(block['lines']) > 3:
return False
return True
def __para_merge_page(blocks):
page_text_blocks_groups = __process_blocks(blocks)
for text_blocks_group in page_text_blocks_groups:
if len(text_blocks_group) > 0:
# before merging, classify every block as list / index / text
for block in text_blocks_group:
block_type = __is_list_or_index_block(block)
block['type'] = block_type
# logger.info(f"{block['type']}:{block}")
if len(text_blocks_group) > 1:
# decide whether this group is a list group before merging
is_list_group = __is_list_group(text_blocks_group)
# iterate in reverse
for i in range(len(text_blocks_group) - 1, -1, -1):
current_block = text_blocks_group[i]
# check for a previous block
if i - 1 >= 0:
prev_block = text_blocks_group[i - 1]
if (
current_block['type'] == 'text'
and prev_block['type'] == 'text'
and not is_list_group
):
__merge_2_text_blocks(current_block, prev_block)
elif (
current_block['type'] == BlockType.LIST
and prev_block['type'] == BlockType.LIST
) or (
current_block['type'] == BlockType.INDEX
and prev_block['type'] == BlockType.INDEX
):
__merge_2_list_blocks(current_block, prev_block)
else:
continue
def para_split(page_info_list):
all_blocks = []
for page_info in page_info_list:
blocks = copy.deepcopy(page_info['preproc_blocks'])
for block in blocks:
block['page_num'] = page_info['page_idx']
block['page_size'] = page_info['page_size']
all_blocks.extend(blocks)
__para_merge_page(all_blocks)
for page_info in page_info_list:
page_info['para_blocks'] = []
for block in all_blocks:
if 'page_num' in block:
if block['page_num'] == page_info['page_idx']:
page_info['para_blocks'].append(block)
# strip the temporary page_num and page_size fields from the block
del block['page_num']
del block['page_size']
if __name__ == '__main__':
input_blocks = []
# run the grouping
groups = __process_blocks(input_blocks)
for group_index, group in enumerate(groups):
print(f'Group {group_index}: {group}')
import os
import time
from typing import List, Tuple
from PIL import Image
from loguru import logger
from .model_init import MineruPipelineModel
from mineru.utils.config_reader import get_device
from ...utils.enum_class import ImageType
from ...utils.pdf_classify import classify
from ...utils.pdf_image_tools import load_images_from_pdf
from ...utils.model_utils import get_vram, clean_memory
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS to fall back to CPU ops
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable the albumentations update check
class ModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(
self,
lang=None,
formula_enable=None,
table_enable=None,
):
key = (lang, formula_enable, table_enable)
if key not in self._models:
self._models[key] = custom_model_init(
lang=lang,
formula_enable=formula_enable,
table_enable=table_enable,
)
return self._models[key]
def custom_model_init(
lang=None,
formula_enable=True,
table_enable=True,
):
model_init_start = time.time()
# read model-dir and device from the config file
device = get_device()
formula_config = {"enable": formula_enable}
table_config = {"enable": table_enable}
model_input = {
'device': device,
'table_config': table_config,
'formula_config': formula_config,
'lang': lang,
}
custom_model = MineruPipelineModel(**model_input)
model_init_cost = time.time() - model_init_start
logger.info(f'model init cost: {model_init_cost}')
return custom_model
def doc_analyze(
pdf_bytes_list,
lang_list,
parse_method: str = 'auto',
formula_enable=True,
table_enable=True,
):
"""
Raising MIN_BATCH_INFERENCE_SIZE appropriately can improve throughput, at the cost of more memory.
It can be set via the MINERU_MIN_BATCH_INFERENCE_SIZE environment variable; the default is 384.
"""
min_batch_inference_size = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 384))
# collect info for every page
all_pages_info = []  # stores (pdf_idx, page_idx, img_pil, ocr_enable, lang)
all_image_lists = []
all_pdf_docs = []
ocr_enabled_list = []
for pdf_idx, pdf_bytes in enumerate(pdf_bytes_list):
# decide the OCR setting
_ocr_enable = False
if parse_method == 'auto':
if classify(pdf_bytes) == 'ocr':
_ocr_enable = True
elif parse_method == 'ocr':
_ocr_enable = True
ocr_enabled_list.append(_ocr_enable)
_lang = lang_list[pdf_idx]
# load the pages of each PDF
images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
all_image_lists.append(images_list)
all_pdf_docs.append(pdf_doc)
for page_idx in range(len(images_list)):
img_dict = images_list[page_idx]
all_pages_info.append((
pdf_idx, page_idx,
img_dict['img_pil'], _ocr_enable, _lang,
))
# prepare the batches
images_with_extra_info = [(info[2], info[3], info[4]) for info in all_pages_info]
batch_size = min_batch_inference_size
batch_images = [
images_with_extra_info[i:i + batch_size]
for i in range(0, len(images_with_extra_info), batch_size)
]
# run the batches
results = []
processed_images_count = 0
for index, batch_image in enumerate(batch_images):
processed_images_count += len(batch_image)
logger.info(
f'Batch {index + 1}/{len(batch_images)}: '
f'{processed_images_count} pages/{len(images_with_extra_info)} pages'
)
batch_results = batch_image_analyze(batch_image, formula_enable, table_enable)
results.extend(batch_results)
# assemble the per-PDF results
infer_results = []
for _ in range(len(pdf_bytes_list)):
infer_results.append([])
for i, page_info in enumerate(all_pages_info):
pdf_idx, page_idx, pil_img, _, _ = page_info
result = results[i]
page_info_dict = {'page_no': page_idx, 'width': pil_img.width, 'height': pil_img.height}
page_dict = {'layout_dets': result, 'page_info': page_info_dict}
infer_results[pdf_idx].append(page_dict)
return infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list
def batch_image_analyze(
images_with_extra_info: List[Tuple[Image.Image, bool, str]],
formula_enable=True,
table_enable=True):
from .batch_analyze import BatchAnalyze
model_manager = ModelSingleton()
batch_ratio = 1
device = get_device()
if str(device).startswith('npu'):
try:
import torch_npu
if torch_npu.npu.is_available():
torch_npu.npu.set_compile_mode(jit_compile=False)
except Exception as e:
raise RuntimeError(
"NPU is selected as device, but torch_npu is not available. "
"Please ensure that the torch_npu package is installed correctly."
) from e
if str(device).startswith('npu') or str(device).startswith('cuda'):
vram = get_vram(device)
if vram is not None:
gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
if gpu_memory >= 16:
batch_ratio = 16
elif gpu_memory >= 12:
batch_ratio = 8
elif gpu_memory >= 8:
batch_ratio = 4
elif gpu_memory >= 6:
batch_ratio = 2
else:
batch_ratio = 1
logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
else:
# Default batch_ratio when VRAM can't be determined
batch_ratio = 1
logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
# check the torch version
import torch
from packaging import version
if version.parse(torch.__version__) >= version.parse("2.8.0") or str(device).startswith('mps'):
enable_ocr_det_batch = False
else:
enable_ocr_det_batch = True
batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable, enable_ocr_det_batch)
results = batch_model(images_with_extra_info)
clean_memory(get_device())
return results
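# Usage sketch (hedged; the file path and language are illustrative):
#
#   with open('demo.pdf', 'rb') as f:
#       pdf_bytes = f.read()
#   infer_results, image_lists, pdf_docs, langs, ocr_flags = doc_analyze(
#       [pdf_bytes], ['ch'], parse_method='auto')
#
# infer_results[0] is then a list of per-page dicts shaped like
#   {'layout_dets': [...], 'page_info': {'page_no': 0, 'width': W, 'height': H}}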
from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, get_minbox_if_overlap_by_ratio
from mineru.utils.enum_class import CategoryId, ContentType
from mineru.utils.magic_model_utils import tie_up_category_by_distance_v3, reduct_overlap
class MagicModel:
"""每个函数没有得到元素的时候返回空list."""
def __init__(self, page_model_info: dict, scale: float):
self.__page_model_info = page_model_info
self.__scale = scale
"""为所有模型数据添加bbox信息(缩放,poly->bbox)"""
self.__fix_axis()
"""删除置信度特别低的模型数据(<0.05),提高质量"""
self.__fix_by_remove_low_confidence()
"""删除高iou(>0.9)数据中置信度较低的那个"""
self.__fix_by_remove_high_iou_and_low_confidence()
"""将部分tbale_footnote修正为image_footnote"""
self.__fix_footnote()
"""处理重叠的image_body和table_body"""
self.__fix_by_remove_overlap_image_table_body()
def __fix_by_remove_overlap_image_table_body(self):
need_remove_list = []
layout_dets = self.__page_model_info['layout_dets']
image_blocks = list(filter(
lambda x: x['category_id'] == CategoryId.ImageBody, layout_dets
))
table_blocks = list(filter(
lambda x: x['category_id'] == CategoryId.TableBody, layout_dets
))
def add_need_remove_block(blocks):
for i in range(len(blocks)):
for j in range(i + 1, len(blocks)):
block1 = blocks[i]
block2 = blocks[j]
overlap_box = get_minbox_if_overlap_by_ratio(
block1['bbox'], block2['bbox'], 0.8
)
if overlap_box is not None:
# the smaller of the two blocks is the one removed
area1 = (block1['bbox'][2] - block1['bbox'][0]) * (block1['bbox'][3] - block1['bbox'][1])
area2 = (block2['bbox'][2] - block2['bbox'][0]) * (block2['bbox'][3] - block2['bbox'][1])
if area1 <= area2:
block_to_remove = block1
large_block = block2
else:
block_to_remove = block2
large_block = block1
if block_to_remove not in need_remove_list:
# expand the larger block's bbox to cover both
x1, y1, x2, y2 = large_block['bbox']
sx1, sy1, sx2, sy2 = block_to_remove['bbox']
x1 = min(x1, sx1)
y1 = min(y1, sy1)
x2 = max(x2, sx2)
y2 = max(y2, sy2)
large_block['bbox'] = [x1, y1, x2, y2]
need_remove_list.append(block_to_remove)
# image-image overlaps
add_need_remove_block(image_blocks)
# table-table overlaps
add_need_remove_block(table_blocks)
# remove the marked blocks from the layout
for need_remove in need_remove_list:
if need_remove in layout_dets:
layout_dets.remove(need_remove)
def __fix_axis(self):
need_remove_list = []
layout_dets = self.__page_model_info['layout_dets']
for layout_det in layout_dets:
x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
bbox = [
int(x0 / self.__scale),
int(y0 / self.__scale),
int(x1 / self.__scale),
int(y1 / self.__scale),
]
layout_det['bbox'] = bbox
# drop spans whose width or height is <= 0
if bbox[2] - bbox[0] <= 0 or bbox[3] - bbox[1] <= 0:
need_remove_list.append(layout_det)
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
def __fix_by_remove_low_confidence(self):
need_remove_list = []
layout_dets = self.__page_model_info['layout_dets']
for layout_det in layout_dets:
if layout_det['score'] <= 0.05:
need_remove_list.append(layout_det)
else:
continue
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
def __fix_by_remove_high_iou_and_low_confidence(self):
need_remove_list = []
layout_dets = list(filter(
lambda x: x['category_id'] in [
CategoryId.Title,
CategoryId.Text,
CategoryId.ImageBody,
CategoryId.ImageCaption,
CategoryId.TableBody,
CategoryId.TableCaption,
CategoryId.TableFootnote,
CategoryId.InterlineEquation_Layout,
CategoryId.InterlineEquationNumber_Layout,
], self.__page_model_info['layout_dets']
)
)
for i in range(len(layout_dets)):
for j in range(i + 1, len(layout_dets)):
layout_det1 = layout_dets[i]
layout_det2 = layout_dets[j]
if calculate_iou(layout_det1['bbox'], layout_det2['bbox']) > 0.9:
layout_det_need_remove = layout_det1 if layout_det1['score'] < layout_det2['score'] else layout_det2
if layout_det_need_remove not in need_remove_list:
need_remove_list.append(layout_det_need_remove)
for need_remove in need_remove_list:
self.__page_model_info['layout_dets'].remove(need_remove)
def __fix_footnote(self):
footnotes = []
figures = []
tables = []
for obj in self.__page_model_info['layout_dets']:
if obj['category_id'] == CategoryId.TableFootnote:
footnotes.append(obj)
elif obj['category_id'] == CategoryId.ImageBody:
figures.append(obj)
elif obj['category_id'] == CategoryId.TableBody:
tables.append(obj)
if len(footnotes) * len(figures) == 0:
return
dis_figure_footnote = {}
dis_table_footnote = {}
for i in range(len(footnotes)):
for j in range(len(figures)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], figures[j]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
dis_figure_footnote[i] = min(
self._bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
dis_figure_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
for j in range(len(tables)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], tables[j]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
dis_table_footnote[i] = min(
self._bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
dis_table_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
if i not in dis_figure_footnote:
continue
if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
footnotes[i]['category_id'] = CategoryId.ImageFootnote
def _bbox_distance(self, bbox1, bbox2):
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
flags = [left, right, bottom, top]
count = sum([1 if v else 0 for v in flags])
if count > 1:
return float('inf')
if left or right:
l1 = bbox1[3] - bbox1[1]
l2 = bbox2[3] - bbox2[1]
else:
l1 = bbox1[2] - bbox1[0]
l2 = bbox2[2] - bbox2[0]
if l2 > l1 and (l2 - l1) / l1 > 0.3:
return float('inf')
return bbox_distance(bbox1, bbox2)
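# Distance sketch: for horizontally adjacent boxes (left/right) the compared
# length is the height (bbox[3] - bbox[1]); for vertically adjacent boxes it is
# the width. If bbox2 is more than 30% longer than bbox1 on that axis, the
# pair is rejected with an infinite distance.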
def __tie_up_category_by_distance_v3(self, subject_category_id, object_category_id):
# helpers that fetch the subject and object boxes
def get_subjects():
return reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == subject_category_id,
self.__page_model_info['layout_dets'],
),
)
)
)
def get_objects():
return reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == object_category_id,
self.__page_model_info['layout_dets'],
),
)
)
)
# delegate to the shared routine
return tie_up_category_by_distance_v3(
get_subjects,
get_objects
)
def get_imgs(self):
with_captions = self.__tie_up_category_by_distance_v3(
CategoryId.ImageBody, CategoryId.ImageCaption
)
with_footnotes = self.__tie_up_category_by_distance_v3(
CategoryId.ImageBody, CategoryId.ImageFootnote
)
ret = []
for v in with_captions:
record = {
'image_body': v['sub_bbox'],
'image_caption_list': v['obj_bboxes'],
}
filter_idx = v['sub_idx']
d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
record['image_footnote_list'] = d['obj_bboxes']
ret.append(record)
return ret
def get_tables(self) -> list:
with_captions = self.__tie_up_category_by_distance_v3(
CategoryId.TableBody, CategoryId.TableCaption
)
with_footnotes = self.__tie_up_category_by_distance_v3(
CategoryId.TableBody, CategoryId.TableFootnote
)
ret = []
for v in with_captions:
record = {
'table_body': v['sub_bbox'],
'table_caption_list': v['obj_bboxes'],
}
filter_idx = v['sub_idx']
d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
record['table_footnote_list'] = d['obj_bboxes']
ret.append(record)
return ret
def get_equations(self) -> tuple[list, list, list]:  # has coordinates and text content
inline_equations = self.__get_blocks_by_type(
CategoryId.InlineEquation, ['latex']
)
interline_equations = self.__get_blocks_by_type(
CategoryId.InterlineEquation_YOLO, ['latex']
)
interline_equations_blocks = self.__get_blocks_by_type(
CategoryId.InterlineEquation_Layout
)
return inline_equations, interline_equations, interline_equations_blocks
def get_discarded(self) -> list:  # in-house model output: coordinates only
blocks = self.__get_blocks_by_type(CategoryId.Abandon)
return blocks
def get_text_blocks(self) -> list:  # in-house model output: coordinates only, no text
blocks = self.__get_blocks_by_type(CategoryId.Text)
return blocks
def get_title_blocks(self) -> list:  # in-house model output: coordinates only, no text
blocks = self.__get_blocks_by_type(CategoryId.Title)
return blocks
def get_all_spans(self) -> list:
def remove_duplicate_spans(spans):
new_spans = []
for span in spans:
if not any(span == existing_span for existing_span in new_spans):
new_spans.append(span)
return new_spans
all_spans = []
layout_dets = self.__page_model_info['layout_dets']
allow_category_id_list = [
CategoryId.ImageBody,
CategoryId.TableBody,
CategoryId.InlineEquation,
CategoryId.InterlineEquation_YOLO,
CategoryId.OcrText,
]
"""当成span拼接的"""
for layout_det in layout_dets:
category_id = layout_det['category_id']
if category_id in allow_category_id_list:
span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
if category_id == CategoryId.ImageBody:
span['type'] = ContentType.IMAGE
elif category_id == CategoryId.TableBody:
# table model output
latex = layout_det.get('latex', None)
html = layout_det.get('html', None)
if latex:
span['latex'] = latex
elif html:
span['html'] = html
span['type'] = ContentType.TABLE
elif category_id == CategoryId.InlineEquation:
span['content'] = layout_det['latex']
span['type'] = ContentType.INLINE_EQUATION
elif category_id == CategoryId.InterlineEquation_YOLO:
span['content'] = layout_det['latex']
span['type'] = ContentType.INTERLINE_EQUATION
elif category_id == CategoryId.OcrText:
span['content'] = layout_det['text']
span['type'] = ContentType.TEXT
all_spans.append(span)
return remove_duplicate_spans(all_spans)
def __get_blocks_by_type(
self, category_type: int, extra_col=None
) -> list:
if extra_col is None:
extra_col = []
blocks = []
layout_dets = self.__page_model_info.get('layout_dets', [])
for item in layout_dets:
category_id = item.get('category_id', -1)
bbox = item.get('bbox', None)
if category_id == category_type:
block = {
'bbox': bbox,
'score': item.get('score'),
}
for col in extra_col:
block[col] = item.get(col, None)
blocks.append(block)
return blocks
import re
from loguru import logger
from mineru.utils.config_reader import get_latex_delimiter_config
from mineru.backend.pipeline.para_split import ListLineTag
from mineru.utils.enum_class import BlockType, ContentType, MakeMode
from mineru.utils.language import detect_lang
def __is_hyphen_at_line_end(line):
"""Check if a line ends with one or more letters followed by a hyphen.
Args:
line (str): The line of text to check.
Returns:
bool: True if the line ends with one or more letters followed by a hyphen, False otherwise.
"""
# Use regex to check if the line ends with one or more letters followed by a hyphen
return bool(re.search(r'[A-Za-z]+-\s*$', line))
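# Examples:
#   __is_hyphen_at_line_end('state-of-the-art mod-')  -> True
#   __is_hyphen_at_line_end('a well-known result')    -> False
#   __is_hyphen_at_line_end('section 3-')             -> False (digit, not a letter, before the hyphen)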
def make_blocks_to_markdown(paras_of_layout,
mode,
img_buket_path='',
):
page_markdown = []
for para_block in paras_of_layout:
para_text = ''
para_type = para_block['type']
if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
para_text = merge_para_with_text(para_block)
elif para_type == BlockType.TITLE:
title_level = get_title_level(para_block)
para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
elif para_type == BlockType.INTERLINE_EQUATION:
if len(para_block['lines']) == 0 or len(para_block['lines'][0]['spans']) == 0:
continue
if para_block['lines'][0]['spans'][0].get('content', ''):
para_text = merge_para_with_text(para_block)
else:
para_text += f"![]({img_buket_path}/{para_block['lines'][0]['spans'][0]['image_path']})"
elif para_type == BlockType.IMAGE:
if mode == MakeMode.NLP_MD:
continue
elif mode == MakeMode.MM_MD:
# check whether an image footnote exists
has_image_footnote = any(block['type'] == BlockType.IMAGE_FOOTNOTE for block in para_block['blocks'])
# if it does, append the footnote after the image body
if has_image_footnote:
for block in para_block['blocks']:  # 1st: append image_caption
if block['type'] == BlockType.IMAGE_CAPTION:
para_text += merge_para_with_text(block) + ' \n'
for block in para_block['blocks']:  # 2nd: append image_body
if block['type'] == BlockType.IMAGE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.IMAGE:
if span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
for block in para_block['blocks']:  # 3rd: append image_footnote
if block['type'] == BlockType.IMAGE_FOOTNOTE:
para_text += ' \n' + merge_para_with_text(block)
else:
for block in para_block['blocks']:  # 1st: append image_body
if block['type'] == BlockType.IMAGE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.IMAGE:
if span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
for block in para_block['blocks']:  # 2nd: append image_caption
if block['type'] == BlockType.IMAGE_CAPTION:
para_text += ' \n' + merge_para_with_text(block)
elif para_type == BlockType.TABLE:
if mode == MakeMode.NLP_MD:
continue
elif mode == MakeMode.MM_MD:
for block in para_block['blocks']:  # 1st: append table_caption
if block['type'] == BlockType.TABLE_CAPTION:
para_text += merge_para_with_text(block) + ' \n'
for block in para_block['blocks']:  # 2nd: append table_body
if block['type'] == BlockType.TABLE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.TABLE:
# if processed by the table model
if span.get('html', ''):
para_text += f"\n{span['html']}\n"
elif span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
for block in para_block['blocks']:  # 3rd: append table_footnote
if block['type'] == BlockType.TABLE_FOOTNOTE:
para_text += '\n' + merge_para_with_text(block) + ' '
if para_text.strip() == '':
continue
page_markdown.append(para_text.strip())
return page_markdown
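# Example usage (sketch): render one page's para_blocks to markdown chunks, then join them.
#   page_md = make_blocks_to_markdown(page_info['para_blocks'], MakeMode.MM_MD, img_buket_path='images')
#   markdown_str = '\n\n'.join(page_md)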
def full_to_half(text: str) -> str:
"""Convert full-width characters to half-width characters using code point manipulation.
Args:
text: String containing full-width characters
Returns:
String with full-width characters converted to half-width
"""
result = []
for char in text:
code = ord(char)
# Full-width letters and numbers (FF21-FF3A for A-Z, FF41-FF5A for a-z, FF10-FF19 for 0-9)
if (0xFF21 <= code <= 0xFF3A) or (0xFF41 <= code <= 0xFF5A) or (0xFF10 <= code <= 0xFF19):
result.append(chr(code - 0xFEE0)) # Shift to ASCII range
else:
result.append(char)
return ''.join(result)
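# Illustrative examples:
#   full_to_half('ABC123') -> 'ABC123'   (full-width letters/digits shifted to ASCII)
#   full_to_half('?!')       -> '?!'       (full-width punctuation is intentionally left unchanged)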
latex_delimiters_config = get_latex_delimiter_config()
default_delimiters = {
'display': {'left': '$$', 'right': '$$'},
'inline': {'left': '$', 'right': '$'}
}
delimiters = latex_delimiters_config if latex_delimiters_config else default_delimiters
display_left_delimiter = delimiters['display']['left']
display_right_delimiter = delimiters['display']['right']
inline_left_delimiter = delimiters['inline']['left']
inline_right_delimiter = delimiters['inline']['right']
def merge_para_with_text(para_block):
block_text = ''
for line in para_block['lines']:
for span in line['spans']:
if span['type'] in [ContentType.TEXT]:
span['content'] = full_to_half(span['content'])
block_text += span['content']
block_lang = detect_lang(block_text)
para_text = ''
for i, line in enumerate(para_block['lines']):
if i >= 1 and line.get(ListLineTag.IS_LIST_START_LINE, False):
para_text += ' \n'
for j, span in enumerate(line['spans']):
span_type = span['type']
content = ''
if span_type == ContentType.TEXT:
content = escape_special_markdown_char(span['content'])
elif span_type == ContentType.INLINE_EQUATION:
if span.get('content', ''):
content = f"{inline_left_delimiter}{span['content']}{inline_right_delimiter}"
elif span_type == ContentType.INTERLINE_EQUATION:
if span.get('content', ''):
content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"
content = content.strip()
if content:
langs = ['zh', 'ja', 'ko']
# logger.info(f'block_lang: {block_lang}, content: {content}')
if block_lang in langs:  # in Chinese/Japanese/Korean text, no space is needed across line breaks, except after an inline formula
if j == len(line['spans']) - 1 and span_type not in [ContentType.INLINE_EQUATION]:
para_text += content
else:
para_text += f'{content} '
else:
if span_type in [ContentType.TEXT, ContentType.INLINE_EQUATION]:
# if this is the last span of the line and it ends with a hyphen, drop the hyphen and add no trailing space
if j == len(line['spans'])-1 and span_type == ContentType.TEXT and __is_hyphen_at_line_end(content):
para_text += content[:-1]
else:  # in Western text, contents are separated by spaces
para_text += f'{content} '
elif span_type == ContentType.INTERLINE_EQUATION:
para_text += content
else:
continue
return para_text
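# Illustrative example (hypothetical data, assuming detect_lang does not report zh/ja/ko):
#   merge_para_with_text({'lines': [{'spans': [{'type': ContentType.TEXT, 'content': 'Hello world'}]}]})
# returns 'Hello world ' -- spans are joined with a separating space, so a trailing space remains.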
def make_blocks_to_content_list(para_block, img_buket_path, page_idx, page_size):
para_type = para_block['type']
para_content = {}
if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
para_content = {
'type': ContentType.TEXT,
'text': merge_para_with_text(para_block),
}
elif para_type == BlockType.TITLE:
para_content = {
'type': ContentType.TEXT,
'text': merge_para_with_text(para_block),
}
title_level = get_title_level(para_block)
if title_level != 0:
para_content['text_level'] = title_level
elif para_type == BlockType.INTERLINE_EQUATION:
if len(para_block['lines']) == 0 or len(para_block['lines'][0]['spans']) == 0:
return None
para_content = {
'type': ContentType.EQUATION,
'img_path': f"{img_buket_path}/{para_block['lines'][0]['spans'][0].get('image_path', '')}",
}
if para_block['lines'][0]['spans'][0].get('content', ''):
para_content['text'] = merge_para_with_text(para_block)
para_content['text_format'] = 'latex'
elif para_type == BlockType.IMAGE:
para_content = {'type': ContentType.IMAGE, 'img_path': '', BlockType.IMAGE_CAPTION: [], BlockType.IMAGE_FOOTNOTE: []}
for block in para_block['blocks']:
if block['type'] == BlockType.IMAGE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.IMAGE:
if span.get('image_path', ''):
para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
if block['type'] == BlockType.IMAGE_CAPTION:
para_content[BlockType.IMAGE_CAPTION].append(merge_para_with_text(block))
if block['type'] == BlockType.IMAGE_FOOTNOTE:
para_content[BlockType.IMAGE_FOOTNOTE].append(merge_para_with_text(block))
elif para_type == BlockType.TABLE:
para_content = {'type': ContentType.TABLE, 'img_path': '', BlockType.TABLE_CAPTION: [], BlockType.TABLE_FOOTNOTE: []}
for block in para_block['blocks']:
if block['type'] == BlockType.TABLE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.TABLE:
if span.get('html', ''):
para_content[BlockType.TABLE_BODY] = f"{span['html']}"
if span.get('image_path', ''):
para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
if block['type'] == BlockType.TABLE_CAPTION:
para_content[BlockType.TABLE_CAPTION].append(merge_para_with_text(block))
if block['type'] == BlockType.TABLE_FOOTNOTE:
para_content[BlockType.TABLE_FOOTNOTE].append(merge_para_with_text(block))
page_width, page_height = page_size
para_bbox = para_block.get('bbox')
if para_bbox:
x0, y0, x1, y1 = para_bbox
para_content['bbox'] = [
int(x0 * 1000 / page_width),
int(y0 * 1000 / page_height),
int(x1 * 1000 / page_width),
int(y1 * 1000 / page_height),
]
para_content['page_idx'] = page_idx
return para_content
def union_make(pdf_info_dict: list,
make_mode: str,
img_buket_path: str = '',
):
output_content = []
for page_info in pdf_info_dict:
paras_of_layout = page_info.get('para_blocks')
page_idx = page_info.get('page_idx')
page_size = page_info.get('page_size')
if not paras_of_layout:
continue
if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
page_markdown = make_blocks_to_markdown(paras_of_layout, make_mode, img_buket_path)
output_content.extend(page_markdown)
elif make_mode == MakeMode.CONTENT_LIST:
for para_block in paras_of_layout:
para_content = make_blocks_to_content_list(para_block, img_buket_path, page_idx, page_size)
if para_content:
output_content.append(para_content)
if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
return '\n\n'.join(output_content)
elif make_mode == MakeMode.CONTENT_LIST:
return output_content
else:
logger.error(f"Unsupported make mode: {make_mode}")
return None
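# Example usage (sketch): middle_json['pdf_info'] in, markdown string or content list out.
#   md_str = union_make(middle_json['pdf_info'], MakeMode.MM_MD, 'images')
#   content_list = union_make(middle_json['pdf_info'], MakeMode.CONTENT_LIST, 'images')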
def get_title_level(block):
title_level = block.get('level', 1)
if title_level > 4:
title_level = 4
elif title_level < 1:
title_level = 0
return title_level
def escape_special_markdown_char(content):
"""
转义正文里对markdown语法有特殊意义的字符
"""
special_chars = ["*", "`", "~", "$"]
for char in special_chars:
content = content.replace(char, "\\" + char)
return content
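# Illustrative example:
#   escape_special_markdown_char('cost: $5 *per* item') returns 'cost: \$5 \*per\* item'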
\ No newline at end of file
# Copyright (c) Opendatalab. All rights reserved.
import os
import time
import cv2
import numpy as np
from loguru import logger
from mineru.backend.vlm.vlm_magic_model import MagicModel
from mineru.utils.config_reader import get_table_enable, get_llm_aided_config
from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.enum_class import ContentType
from mineru.utils.hash_utils import bytes_md5
from mineru.utils.pdf_image_tools import get_crop_img
from mineru.utils.table_merge import merge_table
from mineru.version import __version__
heading_level_import_success = False
llm_aided_config = get_llm_aided_config()
if llm_aided_config:
title_aided_config = llm_aided_config.get('title_aided', {})
if title_aided_config.get('enable', False):
try:
from mineru.utils.llm_aided import llm_aided_title
from mineru.backend.pipeline.model_init import AtomModelSingleton
heading_level_import_success = True
except Exception as e:
logger.warning("The heading level feature cannot be used. If you need to use the heading level feature, "
"please execute `pip install mineru[core]` to install the required packages.")
def blocks_to_page_info(page_blocks, image_dict, page, image_writer, page_index) -> dict:
"""将blocks转换为页面信息"""
scale = image_dict["scale"]
# page_pil_img = image_dict["img_pil"]
page_pil_img = image_dict["img_pil"]
page_img_md5 = bytes_md5(page_pil_img.tobytes())
width, height = map(int, page.get_size())
magic_model = MagicModel(page_blocks, width, height)
image_blocks = magic_model.get_image_blocks()
table_blocks = magic_model.get_table_blocks()
title_blocks = magic_model.get_title_blocks()
discarded_blocks = magic_model.get_discarded_blocks()
code_blocks = magic_model.get_code_blocks()
ref_text_blocks = magic_model.get_ref_text_blocks()
phonetic_blocks = magic_model.get_phonetic_blocks()
list_blocks = magic_model.get_list_blocks()
# if heading-level refinement is enabled, run OCR detection on crops of the title_blocks
if heading_level_import_success:
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang='ch_lite'
)
for title_block in title_blocks:
title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
title_np_img = np.array(title_pil_img)
# pad the title image with a 50-pixel white border on all four sides
title_np_img = cv2.copyMakeBorder(
title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
)
title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
if len(ocr_det_res) > 0:
# compute the average height of all detected boxes
avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
title_block['line_avg_height'] = round(avg_height/scale)
text_blocks = magic_model.get_text_blocks()
interline_equation_blocks = magic_model.get_interline_equation_blocks()
all_spans = magic_model.get_all_spans()
# crop images for image/table/interline_equation spans
for span in all_spans:
if span["type"] in [ContentType.IMAGE, ContentType.TABLE, ContentType.INTERLINE_EQUATION]:
span = cut_image_and_table(span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale)
page_blocks = []
page_blocks.extend([
*image_blocks,
*table_blocks,
*code_blocks,
*ref_text_blocks,
*phonetic_blocks,
*title_blocks,
*text_blocks,
*interline_equation_blocks,
*list_blocks,
])
# sort page_blocks by their index value
page_blocks.sort(key=lambda x: x["index"])
page_info = {"para_blocks": page_blocks, "discarded_blocks": discarded_blocks, "page_size": [width, height], "page_idx": page_index}
return page_info
def result_to_middle_json(model_output_blocks_list, images_list, pdf_doc, image_writer):
middle_json = {"pdf_info": [], "_backend":"vlm", "_version_name": __version__}
for index, page_blocks in enumerate(model_output_blocks_list):
page = pdf_doc[index]
image_dict = images_list[index]
page_info = blocks_to_page_info(page_blocks, image_dict, page, image_writer, index)
middle_json["pdf_info"].append(page_info)
"""表格跨页合并"""
table_enable = get_table_enable(os.getenv('MINERU_VLM_TABLE_ENABLE', 'True').lower() == 'true')
if table_enable:
merge_table(middle_json["pdf_info"])
"""llm优化标题分级"""
if heading_level_import_success:
llm_aided_title_start_time = time.time()
llm_aided_title(middle_json["pdf_info"], title_aided_config)
logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
# 关闭pdf文档
pdf_doc.close()
return middle_json
\ No newline at end of file
# Copyright (c) Opendatalab. All rights reserved.
import os
import time
from loguru import logger
from .model_output_to_middle_json import result_to_middle_json
from ...data.data_reader_writer import DataWriter
from mineru.utils.pdf_image_tools import load_images_from_pdf
from ...utils.config_reader import get_device
from ...utils.enum_class import ImageType
from ...utils.model_utils import get_vram
from ...utils.models_download_utils import auto_download_and_get_model_root_path
from mineru_vl_utils import MinerUClient
class ModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(
self,
backend: str,
model_path: str | None,
server_url: str | None,
**kwargs,
) -> MinerUClient:
key = (backend, model_path, server_url)
if key not in self._models:
start_time = time.time()
model = None
processor = None
vllm_llm = None
vllm_async_llm = None
batch_size = 0
if backend in ['transformers', 'vllm-engine', "vllm-async-engine"] and not model_path:
model_path = auto_download_and_get_model_root_path("/","vlm")
if backend == "transformers":
try:
from transformers import (
AutoProcessor,
Qwen2VLForConditionalGeneration,
)
from transformers import __version__ as transformers_version
except ImportError:
raise ImportError("Please install transformers to use the transformers backend.")
from packaging import version
if version.parse(transformers_version) >= version.parse("4.56.0"):
dtype_key = "dtype"
else:
dtype_key = "torch_dtype"
device = get_device()
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path,
device_map={"": device},
**{dtype_key: "auto"}, # type: ignore
)
processor = AutoProcessor.from_pretrained(
model_path,
use_fast=True,
)
try:
vram = get_vram(device)
if vram is not None:
gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
if gpu_memory >= 16:
batch_size = 8
elif gpu_memory >= 8:
batch_size = 4
else:
batch_size = 1
logger.info(f'gpu_memory: {gpu_memory} GB, batch_size: {batch_size}')
else:
# default batch_size when VRAM can't be determined
batch_size = 1
logger.info(f'Could not determine GPU memory, using default batch_size: {batch_size}')
except Exception as e:
logger.warning(f'Error determining VRAM: {e}, using default batch_size: 1')
batch_size = 1
elif backend == "vllm-engine":
try:
import vllm
except ImportError:
raise ImportError("Please install vllm to use the vllm-engine backend.")
if "gpu_memory_utilization" not in kwargs:
kwargs["gpu_memory_utilization"] = 0.5
if "model" not in kwargs:
kwargs["model"] = model_path
# pass kwargs through as vllm initialization parameters
vllm_llm = vllm.LLM(**kwargs)
elif backend == "vllm-async-engine":
try:
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
except ImportError:
raise ImportError("Please install vllm to use the vllm-async-engine backend.")
if "gpu_memory_utilization" not in kwargs:
kwargs["gpu_memory_utilization"] = 0.5
if "model" not in kwargs:
kwargs["model"] = model_path
# pass kwargs through as vllm initialization parameters
vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs))
self._models[key] = MinerUClient(
backend=backend,
model=model,
processor=processor,
vllm_llm=vllm_llm,
vllm_async_llm=vllm_async_llm,
server_url=server_url,
batch_size=batch_size,
)
elapsed = round(time.time() - start_time, 2)
logger.info(f"get {backend} predictor cost: {elapsed}s")
return self._models[key]
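# Example usage (sketch): instances are cached per (backend, model_path, server_url) key.
#   predictor = ModelSingleton().get_model("transformers", model_path=None, server_url=None)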
def doc_analyze(
pdf_bytes,
image_writer: DataWriter | None,
predictor: MinerUClient | None = None,
backend="transformers",
model_path: str | None = None,
server_url: str | None = None,
**kwargs,
):
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
results = predictor.batch_two_step_extract(images=images_pil_list)
middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
return middle_json, results
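# Example usage (sketch; file paths are hypothetical):
#   from mineru.data.data_reader_writer import FileBasedDataWriter
#   with open("demo.pdf", "rb") as f:
#       middle_json, results = doc_analyze(f.read(), FileBasedDataWriter("./output/images"), backend="transformers")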
async def aio_doc_analyze(
pdf_bytes,
image_writer: DataWriter | None,
predictor: MinerUClient | None = None,
backend="transformers",
model_path: str | None = None,
server_url: str | None = None,
**kwargs,
):
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
return middle_json, results
import re
from typing import Literal
from loguru import logger
from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
from mineru.utils.enum_class import ContentType, BlockType
from mineru.utils.guess_suffix_or_lang import guess_language_by_text
from mineru.utils.magic_model_utils import reduct_overlap, tie_up_category_by_distance_v3
class MagicModel:
def __init__(self, page_blocks: list, width, height):
self.page_blocks = page_blocks
blocks = []
self.all_spans = []
# parse each block
for index, block_info in enumerate(page_blocks):
block_bbox = block_info["bbox"]
try:
x1, y1, x2, y2 = block_bbox
x_1, y_1, x_2, y_2 = (
int(x1 * width),
int(y1 * height),
int(x2 * width),
int(y2 * height),
)
if x_2 < x_1:
x_1, x_2 = x_2, x_1
if y_2 < y_1:
y_1, y_2 = y_2, y_1
block_bbox = (x_1, y_1, x_2, y_2)
block_type = block_info["type"]
block_content = block_info["content"]
block_angle = block_info["angle"]
# print(f"坐标: {block_bbox}")
# print(f"类型: {block_type}")
# print(f"内容: {block_content}")
# print("-" * 50)
except Exception as e:
# 如果解析失败,可能是因为格式不正确,跳过这个块
logger.warning(f"Invalid block format: {block_info}, error: {e}")
continue
span_type = "unknown"
line_type = None
guess_lang = None
if block_type in [
"text",
"title",
"image_caption",
"image_footnote",
"table_caption",
"table_footnote",
"code_caption",
"ref_text",
"phonetic",
"header",
"footer",
"page_number",
"aside_text",
"page_footnote",
"list"
]:
span_type = ContentType.TEXT
elif block_type in ["image"]:
block_type = BlockType.IMAGE_BODY
span_type = ContentType.IMAGE
elif block_type in ["table"]:
block_type = BlockType.TABLE_BODY
span_type = ContentType.TABLE
elif block_type in ["code", "algorithm"]:
block_content = code_content_clean(block_content)
line_type = block_type
block_type = BlockType.CODE_BODY
span_type = ContentType.TEXT
guess_lang = guess_language_by_text(block_content)
elif block_type in ["equation"]:
block_type = BlockType.INTERLINE_EQUATION
span_type = ContentType.INTERLINE_EQUATION
if span_type in ["image", "table"]:
span = {
"bbox": block_bbox,
"type": span_type,
}
if span_type == ContentType.TABLE:
span["html"] = block_content
elif span_type in [ContentType.INTERLINE_EQUATION]:
span = {
"bbox": block_bbox,
"type": span_type,
"content": isolated_formula_clean(block_content),
}
else:
if block_content:
block_content = clean_content(block_content)
if block_content and block_content.count("\\(") == block_content.count("\\)") and block_content.count("\\(") > 0:
# build a span list containing both text and inline formulas
spans = []
last_end = 0
# find every inline formula
for match in re.finditer(r'\\\((.+?)\\\)', block_content):
start, end = match.span()
# add the text before the formula
if start > last_end:
text_before = block_content[last_end:start]
if text_before.strip():
spans.append({
"bbox": block_bbox,
"type": ContentType.TEXT,
"content": text_before
})
# add the formula (with \( and \) stripped)
formula = match.group(1)
spans.append({
"bbox": block_bbox,
"type": ContentType.INLINE_EQUATION,
"content": formula.strip()
})
last_end = end
# add the text after the last formula
if last_end < len(block_content):
text_after = block_content[last_end:]
if text_after.strip():
spans.append({
"bbox": block_bbox,
"type": ContentType.TEXT,
"content": text_after
})
span = spans
else:
span = {
"bbox": block_bbox,
"type": span_type,
"content": block_content,
}
# normalize the span(s) and collect them into all_spans
if isinstance(span, dict) and "bbox" in span:
self.all_spans.append(span)
spans = [span]
elif isinstance(span, list):
self.all_spans.extend(span)
spans = span
else:
raise ValueError(f"Invalid span type: {span_type}, expected dict or list, got {type(span)}")
# build the line object
if block_type in [BlockType.CODE_BODY]:
line = {"bbox": block_bbox, "spans": spans, "extra": {"type": line_type, "guess_lang": guess_lang}}
else:
line = {"bbox": block_bbox, "spans": spans}
blocks.append(
{
"bbox": block_bbox,
"type": block_type,
"angle": block_angle,
"lines": [line],
"index": index,
}
)
self.image_blocks = []
self.table_blocks = []
self.interline_equation_blocks = []
self.text_blocks = []
self.title_blocks = []
self.code_blocks = []
self.discarded_blocks = []
self.ref_text_blocks = []
self.phonetic_blocks = []
self.list_blocks = []
for block in blocks:
if block["type"] in [BlockType.IMAGE_BODY, BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE]:
self.image_blocks.append(block)
elif block["type"] in [BlockType.TABLE_BODY, BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE]:
self.table_blocks.append(block)
elif block["type"] in [BlockType.CODE_BODY, BlockType.CODE_CAPTION]:
self.code_blocks.append(block)
elif block["type"] == BlockType.INTERLINE_EQUATION:
self.interline_equation_blocks.append(block)
elif block["type"] == BlockType.TEXT:
self.text_blocks.append(block)
elif block["type"] == BlockType.TITLE:
self.title_blocks.append(block)
elif block["type"] in [BlockType.REF_TEXT]:
self.ref_text_blocks.append(block)
elif block["type"] in [BlockType.PHONETIC]:
self.phonetic_blocks.append(block)
elif block["type"] in [BlockType.HEADER, BlockType.FOOTER, BlockType.PAGE_NUMBER, BlockType.ASIDE_TEXT, BlockType.PAGE_FOOTNOTE]:
self.discarded_blocks.append(block)
elif block["type"] == BlockType.LIST:
self.list_blocks.append(block)
else:
continue
self.list_blocks, self.text_blocks, self.ref_text_blocks = fix_list_blocks(self.list_blocks, self.text_blocks, self.ref_text_blocks)
self.image_blocks, not_include_image_blocks = fix_two_layer_blocks(self.image_blocks, BlockType.IMAGE)
self.table_blocks, not_include_table_blocks = fix_two_layer_blocks(self.table_blocks, BlockType.TABLE)
self.code_blocks, not_include_code_blocks = fix_two_layer_blocks(self.code_blocks, BlockType.CODE)
for code_block in self.code_blocks:
for block in code_block['blocks']:
if block['type'] == BlockType.CODE_BODY:
if len(block["lines"]) > 0:
line = block["lines"][0]
code_block["sub_type"] = line["extra"]["type"]
if code_block["sub_type"] in ["code"]:
code_block["guess_lang"] = line["extra"]["guess_lang"]
del line["extra"]
else:
code_block["sub_type"] = "code"
code_block["guess_lang"] = "txt"
for block in not_include_image_blocks + not_include_table_blocks + not_include_code_blocks:
block["type"] = BlockType.TEXT
self.text_blocks.append(block)
def get_list_blocks(self):
return self.list_blocks
def get_image_blocks(self):
return self.image_blocks
def get_table_blocks(self):
return self.table_blocks
def get_code_blocks(self):
return self.code_blocks
def get_ref_text_blocks(self):
return self.ref_text_blocks
def get_phonetic_blocks(self):
return self.phonetic_blocks
def get_title_blocks(self):
return self.title_blocks
def get_text_blocks(self):
return self.text_blocks
def get_interline_equation_blocks(self):
return self.interline_equation_blocks
def get_discarded_blocks(self):
return self.discarded_blocks
def get_all_spans(self):
return self.all_spans
def isolated_formula_clean(txt):
latex = txt
if latex.startswith("\\["):
latex = latex[2:]
if latex.endswith("\\]"):
latex = latex[:-2]
latex = latex.strip()
return latex
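# Illustrative example:
#   isolated_formula_clean("\\[ E = mc^2 \\]") returns "E = mc^2"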
def code_content_clean(content):
"""清理代码内容,移除Markdown代码块的开始和结束标记"""
if not content:
return ""
lines = content.splitlines()
start_idx = 0
end_idx = len(lines)
# 处理开头的三个反引号
if lines and lines[0].startswith("```"):
start_idx = 1
# 处理结尾的三个反引号
if lines and end_idx > start_idx and lines[end_idx - 1].strip() == "```":
end_idx -= 1
# 只有在有内容时才进行join操作
if start_idx < end_idx:
return "\n".join(lines[start_idx:end_idx]).strip()
return ""
def clean_content(content):
if content and content.count("\\[") == content.count("\\]") and content.count("\\[") > 0:
# Function to handle each match
def replace_pattern(match):
# Extract content between \[ and \]
inner_content = match.group(1)
return f"[{inner_content}]"
# Find all patterns of \[x\] and apply replacement
pattern = r'\\\[(.*?)\\\]'
content = re.sub(pattern, replace_pattern, content)
return content
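# Illustrative example:
#   clean_content("see \\[1\\] for details") returns "see [1] for details"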
def __tie_up_category_by_distance_v3(blocks, subject_block_type, object_block_type):
# helpers that collect the subject and object blocks
def get_subjects():
return reduct_overlap(
list(
map(
lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"], "angle":x["angle"]},
filter(
lambda x: x["type"] == subject_block_type,
blocks,
),
)
)
)
def get_objects():
return reduct_overlap(
list(
map(
lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"], "angle":x["angle"]},
filter(
lambda x: x["type"] == object_block_type,
blocks,
),
)
)
)
# delegate to the shared helper
return tie_up_category_by_distance_v3(
get_subjects,
get_objects
)
def get_type_blocks(blocks, block_type: Literal["image", "table", "code"]):
with_captions = __tie_up_category_by_distance_v3(blocks, f"{block_type}_body", f"{block_type}_caption")
with_footnotes = __tie_up_category_by_distance_v3(blocks, f"{block_type}_body", f"{block_type}_footnote")
ret = []
for v in with_captions:
record = {
f"{block_type}_body": v["sub_bbox"],
f"{block_type}_caption_list": v["obj_bboxes"],
}
filter_idx = v["sub_idx"]
d = next(filter(lambda x: x["sub_idx"] == filter_idx, with_footnotes))
record[f"{block_type}_footnote_list"] = d["obj_bboxes"]
ret.append(record)
return ret
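# Returns records shaped like the following (illustrative, for block_type="image"):
#   {"image_body": {...}, "image_caption_list": [...], "image_footnote_list": [...]}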
def fix_two_layer_blocks(blocks, fix_type: Literal["image", "table", "code"]):
need_fix_blocks = get_type_blocks(blocks, fix_type)
fixed_blocks = []
not_include_blocks = []
processed_indices = set()
# organize the blocks that need a two-layer structure
for block in need_fix_blocks:
body = block[f"{fix_type}_body"]
caption_list = block[f"{fix_type}_caption_list"]
footnote_list = block[f"{fix_type}_footnote_list"]
body["type"] = f"{fix_type}_body"
for caption in caption_list:
caption["type"] = f"{fix_type}_caption"
processed_indices.add(caption["index"])
for footnote in footnote_list:
footnote["type"] = f"{fix_type}_footnote"
processed_indices.add(footnote["index"])
processed_indices.add(body["index"])
two_layer_block = {
"type": fix_type,
"bbox": body["bbox"],
"blocks": [
body,
],
"index": body["index"],
}
two_layer_block["blocks"].extend([*caption_list, *footnote_list])
fixed_blocks.append(two_layer_block)
# collect blocks that were not folded into any two-layer structure
for block in blocks:
if block["index"] not in processed_indices:
# keep the unprocessed block as-is
not_include_blocks.append(block)
return fixed_blocks, not_include_blocks
def fix_list_blocks(list_blocks, text_blocks, ref_text_blocks):
for list_block in list_blocks:
list_block["blocks"] = []
if "lines" in list_block:
del list_block["lines"]
temp_text_blocks = text_blocks + ref_text_blocks
need_remove_blocks = []
for block in temp_text_blocks:
for list_block in list_blocks:
if calculate_overlap_area_in_bbox1_area_ratio(block["bbox"], list_block["bbox"]) >= 0.8:
list_block["blocks"].append(block)
need_remove_blocks.append(block)
break
for block in need_remove_blocks:
if block in text_blocks:
text_blocks.remove(block)
elif block in ref_text_blocks:
ref_text_blocks.remove(block)
# drop list_blocks whose blocks are empty
list_blocks = [lb for lb in list_blocks if lb["blocks"]]
for list_block in list_blocks:
# count the types of all blocks inside list_block["blocks"]; the mode becomes the list_block's sub_type
type_count = {}
for sub_block in list_block["blocks"]:
sub_block_type = sub_block["type"]
if sub_block_type not in type_count:
type_count[sub_block_type] = 0
type_count[sub_block_type] += 1
if type_count:
list_block["sub_type"] = max(type_count, key=type_count.get)
else:
list_block["sub_type"] = "unknown"
return list_blocks, text_blocks, ref_text_blocks
\ No newline at end of file
import os
from mineru.utils.config_reader import get_latex_delimiter_config, get_formula_enable, get_table_enable
from mineru.utils.enum_class import MakeMode, BlockType, ContentType
latex_delimiters_config = get_latex_delimiter_config()
default_delimiters = {
'display': {'left': '$$', 'right': '$$'},
'inline': {'left': '$', 'right': '$'}
}
delimiters = latex_delimiters_config if latex_delimiters_config else default_delimiters
display_left_delimiter = delimiters['display']['left']
display_right_delimiter = delimiters['display']['right']
inline_left_delimiter = delimiters['inline']['left']
inline_right_delimiter = delimiters['inline']['right']
def merge_para_with_text(para_block, formula_enable=True, img_buket_path=''):
para_text = ''
for line in para_block['lines']:
for j, span in enumerate(line['spans']):
span_type = span['type']
content = ''
if span_type == ContentType.TEXT:
content = span['content']
elif span_type == ContentType.INLINE_EQUATION:
content = f"{inline_left_delimiter}{span['content']}{inline_right_delimiter}"
elif span_type == ContentType.INTERLINE_EQUATION:
if formula_enable:
content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"
else:
if span.get('image_path', ''):
content = f"![]({img_buket_path}/{span['image_path']})"
if content:
if span_type in [ContentType.TEXT, ContentType.INLINE_EQUATION]:
if j == len(line['spans']) - 1:
para_text += content
else:
para_text += f'{content} '
elif span_type == ContentType.INTERLINE_EQUATION:
para_text += content
return para_text
def mk_blocks_to_markdown(para_blocks, make_mode, formula_enable, table_enable, img_buket_path=''):
page_markdown = []
for para_block in para_blocks:
para_text = ''
para_type = para_block['type']
if para_type in [BlockType.TEXT, BlockType.INTERLINE_EQUATION, BlockType.PHONETIC, BlockType.REF_TEXT]:
para_text = merge_para_with_text(para_block, formula_enable=formula_enable, img_buket_path=img_buket_path)
elif para_type == BlockType.LIST:
for block in para_block['blocks']:
item_text = merge_para_with_text(block, formula_enable=formula_enable, img_buket_path=img_buket_path)
para_text += f"{item_text}\n"
elif para_type == BlockType.TITLE:
title_level = get_title_level(para_block)
para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
elif para_type == BlockType.IMAGE:
if make_mode == MakeMode.NLP_MD:
continue
elif make_mode == MakeMode.MM_MD:
# check whether an image footnote exists
has_image_footnote = any(block['type'] == BlockType.IMAGE_FOOTNOTE for block in para_block['blocks'])
# if it does, append the footnote after the image body
if has_image_footnote:
for block in para_block['blocks']:  # 1st: append image_caption
if block['type'] == BlockType.IMAGE_CAPTION:
para_text += merge_para_with_text(block) + ' \n'
for block in para_block['blocks']:  # 2nd: append image_body
if block['type'] == BlockType.IMAGE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.IMAGE:
if span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
for block in para_block['blocks']:  # 3rd: append image_footnote
if block['type'] == BlockType.IMAGE_FOOTNOTE:
para_text += ' \n' + merge_para_with_text(block)
else:
for block in para_block['blocks']:  # 1st: append image_body
if block['type'] == BlockType.IMAGE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.IMAGE:
if span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
for block in para_block['blocks']:  # 2nd: append image_caption
if block['type'] == BlockType.IMAGE_CAPTION:
para_text += ' \n' + merge_para_with_text(block)
elif para_type == BlockType.TABLE:
if make_mode == MakeMode.NLP_MD:
continue
elif make_mode == MakeMode.MM_MD:
for block in para_block['blocks']:  # 1st: append table_caption
if block['type'] == BlockType.TABLE_CAPTION:
para_text += merge_para_with_text(block) + ' \n'
for block in para_block['blocks']:  # 2nd: append table_body
if block['type'] == BlockType.TABLE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.TABLE:
# if processed by the table model
if table_enable:
if span.get('html', ''):
para_text += f"\n{span['html']}\n"
elif span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
else:
if span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
for block in para_block['blocks']:  # 3rd: append table_footnote
if block['type'] == BlockType.TABLE_FOOTNOTE:
para_text += '\n' + merge_para_with_text(block) + ' '
elif para_type == BlockType.CODE:
sub_type = para_block["sub_type"]
for block in para_block['blocks']:  # 1st: append code_caption
if block['type'] == BlockType.CODE_CAPTION:
para_text += merge_para_with_text(block) + ' \n'
for block in para_block['blocks']:  # 2nd: append code_body
if block['type'] == BlockType.CODE_BODY:
if sub_type == BlockType.CODE:
guess_lang = para_block["guess_lang"]
para_text += f"```{guess_lang}\n{merge_para_with_text(block)}\n```"
elif sub_type == BlockType.ALGORITHM:
para_text += merge_para_with_text(block)
if para_text.strip() == '':
continue
page_markdown.append(para_text.strip())
return page_markdown
def make_blocks_to_content_list(para_block, img_buket_path, page_idx, page_size):
para_type = para_block['type']
para_content = {}
if para_type in [
BlockType.TEXT,
BlockType.REF_TEXT,
BlockType.PHONETIC,
BlockType.HEADER,
BlockType.FOOTER,
BlockType.PAGE_NUMBER,
BlockType.ASIDE_TEXT,
BlockType.PAGE_FOOTNOTE,
]:
para_content = {
'type': para_type,
'text': merge_para_with_text(para_block),
}
elif para_type == BlockType.LIST:
para_content = {
'type': para_type,
'sub_type': para_block.get('sub_type', ''),
'list_items':[],
}
for block in para_block['blocks']:
item_text = merge_para_with_text(block)
if item_text.strip():
para_content['list_items'].append(item_text)
elif para_type == BlockType.TITLE:
title_level = get_title_level(para_block)
para_content = {
'type': ContentType.TEXT,
'text': merge_para_with_text(para_block),
}
if title_level != 0:
para_content['text_level'] = title_level
elif para_type == BlockType.INTERLINE_EQUATION:
para_content = {
'type': ContentType.EQUATION,
'text': merge_para_with_text(para_block),
'text_format': 'latex',
}
elif para_type == BlockType.IMAGE:
para_content = {'type': ContentType.IMAGE, 'img_path': '', BlockType.IMAGE_CAPTION: [], BlockType.IMAGE_FOOTNOTE: []}
for block in para_block['blocks']:
if block['type'] == BlockType.IMAGE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.IMAGE:
if span.get('image_path', ''):
para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
if block['type'] == BlockType.IMAGE_CAPTION:
para_content[BlockType.IMAGE_CAPTION].append(merge_para_with_text(block))
if block['type'] == BlockType.IMAGE_FOOTNOTE:
para_content[BlockType.IMAGE_FOOTNOTE].append(merge_para_with_text(block))
elif para_type == BlockType.TABLE:
para_content = {'type': ContentType.TABLE, 'img_path': '', BlockType.TABLE_CAPTION: [], BlockType.TABLE_FOOTNOTE: []}
for block in para_block['blocks']:
if block['type'] == BlockType.TABLE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.TABLE:
if span.get('html', ''):
para_content[BlockType.TABLE_BODY] = f"{span['html']}"
if span.get('image_path', ''):
para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
if block['type'] == BlockType.TABLE_CAPTION:
para_content[BlockType.TABLE_CAPTION].append(merge_para_with_text(block))
if block['type'] == BlockType.TABLE_FOOTNOTE:
para_content[BlockType.TABLE_FOOTNOTE].append(merge_para_with_text(block))
elif para_type == BlockType.CODE:
para_content = {'type': BlockType.CODE, 'sub_type': para_block["sub_type"], BlockType.CODE_CAPTION: []}
for block in para_block['blocks']:
if block['type'] == BlockType.CODE_BODY:
para_content[BlockType.CODE_BODY] = merge_para_with_text(block)
if para_block["sub_type"] == BlockType.CODE:
para_content["guess_lang"] = para_block["guess_lang"]
if block['type'] == BlockType.CODE_CAPTION:
para_content[BlockType.CODE_CAPTION].append(merge_para_with_text(block))
page_width, page_height = page_size
para_bbox = para_block.get('bbox')
if para_bbox:
x0, y0, x1, y1 = para_bbox
para_content['bbox'] = [
int(x0 * 1000 / page_width),
int(y0 * 1000 / page_height),
int(x1 * 1000 / page_width),
int(y1 * 1000 / page_height),
]
para_content['page_idx'] = page_idx
return para_content
def union_make(pdf_info_dict: list,
make_mode: str,
img_buket_path: str = '',
):
formula_enable = get_formula_enable(os.getenv('MINERU_VLM_FORMULA_ENABLE', 'True').lower() == 'true')
table_enable = get_table_enable(os.getenv('MINERU_VLM_TABLE_ENABLE', 'True').lower() == 'true')
output_content = []
for page_info in pdf_info_dict:
paras_of_layout = page_info.get('para_blocks')
paras_of_discarded = page_info.get('discarded_blocks') or []
page_idx = page_info.get('page_idx')
page_size = page_info.get('page_size')
if not paras_of_layout:
continue
if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
page_markdown = mk_blocks_to_markdown(paras_of_layout, make_mode, formula_enable, table_enable, img_buket_path)
output_content.extend(page_markdown)
elif make_mode == MakeMode.CONTENT_LIST:
for para_block in paras_of_layout + paras_of_discarded:
para_content = make_blocks_to_content_list(para_block, img_buket_path, page_idx, page_size)
output_content.append(para_content)
if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
return '\n\n'.join(output_content)
elif make_mode == MakeMode.CONTENT_LIST:
return output_content
return None
def get_title_level(block):
title_level = block.get('level', 1)
if title_level > 4:
title_level = 4
elif title_level < 1:
title_level = 0
return title_level
# Copyright (c) Opendatalab. All rights reserved.
# Copyright (c) Opendatalab. All rights reserved.
import os
import click
from pathlib import Path
from loguru import logger
from mineru.utils.cli_parser import arg_parse
from mineru.utils.config_reader import get_device
from mineru.utils.guess_suffix_or_lang import guess_suffix_by_path
from mineru.utils.model_utils import get_vram
from ..version import __version__
from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.version_option(__version__,
'--version',
'-v',
help='display the version and exit')
@click.option(
'-p',
'--path',
'input_path',
type=click.Path(exists=True),
required=True,
help='local filepath or directory. support pdf, png, jpg, jpeg files',
)
@click.option(
'-o',
'--output',
'output_dir',
type=click.Path(),
required=True,
help='output local directory',
)
@click.option(
'-m',
'--method',
'method',
type=click.Choice(['auto', 'txt', 'ocr']),
help="""the method for parsing pdf:
auto: Automatically determine the method based on the file type.
txt: Use text extraction method.
ocr: Use OCR method for image-based PDFs.
Without method specified, 'auto' will be used by default.
Adapted only for the case where the backend is set to "pipeline".""",
default='auto',
)
@click.option(
'-b',
'--backend',
'backend',
type=click.Choice(['pipeline', 'vlm-transformers', 'vlm-vllm-engine', 'vlm-http-client']),
help="""the backend for parsing pdf:
pipeline: More general.
vlm-transformers: More general.
vlm-vllm-engine: Faster(engine).
vlm-http-client: Faster(client).
Without backend specified, 'pipeline' will be used by default.""",
default='pipeline',
)
@click.option(
'-l',
'--lang',
'lang',
type=click.Choice(['ch', 'ch_server', 'ch_lite', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka', 'th', 'el',
'latin', 'arabic', 'east_slavic', 'cyrillic', 'devanagari']),
help="""
Input the languages in the pdf (if known) to improve OCR accuracy. Optional.
Without languages specified, 'ch' will be used by default.
Adapted only for the case where the backend is set to "pipeline".
""",
default='ch',
)
@click.option(
'-u',
'--url',
'server_url',
type=str,
help="""
When the backend is `vlm-http-client`, you need to specify the server_url, for example:`http://127.0.0.1:30000`
""",
default=None,
)
@click.option(
'-s',
'--start',
'start_page_id',
type=int,
help='The starting page for PDF parsing, beginning from 0.',
default=0,
)
@click.option(
'-e',
'--end',
'end_page_id',
type=int,
help='The ending page for PDF parsing, beginning from 0.',
default=None,
)
@click.option(
'-f',
'--formula',
'formula_enable',
type=bool,
help='Enable formula parsing. Default is True. Adapted only for the case where the backend is set to "pipeline".',
default=True,
)
@click.option(
'-t',
'--table',
'table_enable',
type=bool,
help='Enable table parsing. Default is True. Adapted only for the case where the backend is set to "pipeline".',
default=True,
)
@click.option(
'-d',
'--device',
'device_mode',
type=str,
help='Device mode for model inference, e.g., "cpu", "cuda", "cuda:0", "npu", "npu:0", "mps". Adapted only for the case where the backend is set to "pipeline". ',
default=None,
)
@click.option(
'--vram',
'virtual_vram',
type=int,
help='Upper limit of GPU memory occupied by a single process. Adapted only for the case where the backend is set to "pipeline". ',
default=None,
)
@click.option(
'--source',
'model_source',
type=click.Choice(['huggingface', 'modelscope', 'local']),
help="""
The source of the model repository. Default is 'huggingface'.
""",
default='huggingface',
)
def main(
ctx,
input_path, output_dir, method, backend, lang, server_url,
start_page_id, end_page_id, formula_enable, table_enable,
device_mode, virtual_vram, model_source, **kwargs
):
kwargs.update(arg_parse(ctx))
if not backend.endswith('-client'):
def get_device_mode() -> str:
if device_mode is not None:
return device_mode
else:
return get_device()
if os.getenv('MINERU_DEVICE_MODE', None) is None:
os.environ['MINERU_DEVICE_MODE'] = get_device_mode()
def get_virtual_vram_size() -> int:
if virtual_vram is not None:
return virtual_vram
if get_device_mode().startswith("cuda") or get_device_mode().startswith("npu"):
return round(get_vram(get_device_mode()))
return 1
if os.getenv('MINERU_VIRTUAL_VRAM_SIZE', None) is None:
os.environ['MINERU_VIRTUAL_VRAM_SIZE']= str(get_virtual_vram_size())
if os.getenv('MINERU_MODEL_SOURCE', None) is None:
os.environ['MINERU_MODEL_SOURCE'] = model_source
os.makedirs(output_dir, exist_ok=True)
def parse_doc(path_list: list[Path]):
try:
file_name_list = []
pdf_bytes_list = []
lang_list = []
for path in path_list:
file_name = str(Path(path).stem)
pdf_bytes = read_fn(path)
file_name_list.append(file_name)
pdf_bytes_list.append(pdf_bytes)
lang_list.append(lang)
do_parse(
output_dir=output_dir,
pdf_file_names=file_name_list,
pdf_bytes_list=pdf_bytes_list,
p_lang_list=lang_list,
backend=backend,
parse_method=method,
formula_enable=formula_enable,
table_enable=table_enable,
server_url=server_url,
start_page_id=start_page_id,
end_page_id=end_page_id,
**kwargs,
)
except Exception as e:
logger.exception(e)
if os.path.isdir(input_path):
doc_path_list = []
for doc_path in Path(input_path).glob('*'):
if guess_suffix_by_path(doc_path) in pdf_suffixes + image_suffixes:
doc_path_list.append(doc_path)
parse_doc(doc_path_list)
else:
parse_doc([Path(input_path)])
if __name__ == '__main__':
main()
# Copyright (c) Opendatalab. All rights reserved.
import io
import json
import os
import copy
from pathlib import Path
import pypdfium2 as pdfium
from loguru import logger
from mineru.data.data_reader_writer import FileBasedDataWriter
from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox, draw_line_sort_bbox
from mineru.utils.enum_class import MakeMode
from mineru.utils.guess_suffix_or_lang import guess_suffix_by_bytes
from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes
from mineru.backend.vlm.vlm_middle_json_mkcontent import union_make as vlm_union_make
from mineru.backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
from mineru.backend.vlm.vlm_analyze import aio_doc_analyze as aio_vlm_doc_analyze
pdf_suffixes = ["pdf"]
image_suffixes = ["png", "jpeg", "jp2", "webp", "gif", "bmp", "jpg"]
def read_fn(path):
if not isinstance(path, Path):
path = Path(path)
with open(str(path), "rb") as input_file:
file_bytes = input_file.read()
file_suffix = guess_suffix_by_bytes(file_bytes)
if file_suffix in image_suffixes:
return images_bytes_to_pdf_bytes(file_bytes)
elif file_suffix in pdf_suffixes:
return file_bytes
else:
raise Exception(f"Unknown file suffix: {file_suffix}")
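# Example usage (sketch): images are converted to single-page PDF bytes, PDFs pass through as-is.
#   pdf_bytes = read_fn(Path("scan.jpg"))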
def prepare_env(output_dir, pdf_file_name, parse_method):
local_md_dir = str(os.path.join(output_dir, pdf_file_name, parse_method))
local_image_dir = os.path.join(str(local_md_dir), "images")
os.makedirs(local_image_dir, exist_ok=True)
os.makedirs(local_md_dir, exist_ok=True)
return local_image_dir, local_md_dir
def convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id=0, end_page_id=None):
# load the PDF from bytes
pdf = pdfium.PdfDocument(pdf_bytes)
# determine the last page
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(pdf) - 1
if end_page_id > len(pdf) - 1:
logger.warning("end_page_id is out of range, falling back to the last page of the PDF")
end_page_id = len(pdf) - 1
# create a new PDF document
output_pdf = pdfium.PdfDocument.new()
# pick the page indices to import
page_indices = list(range(start_page_id, end_page_id + 1))
# import the selected pages from the source PDF into the new one
output_pdf.import_pages(pdf, page_indices)
# save the new PDF to an in-memory buffer
output_buffer = io.BytesIO()
output_pdf.save(output_buffer)
# grab the bytes
output_bytes = output_buffer.getvalue()
pdf.close()  # close the source PDF to release resources
output_pdf.close()  # close the new PDF to release resources
return output_bytes
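# Example usage (sketch): keep only pages 0-4 of a PDF.
#   first_pages = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id=0, end_page_id=4)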
def _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id):
"""准备处理PDF字节数据"""
result = []
for pdf_bytes in pdf_bytes_list:
new_pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
result.append(new_pdf_bytes)
return result
def _process_output(
pdf_info,
pdf_bytes,
pdf_file_name,
local_md_dir,
local_image_dir,
md_writer,
f_draw_layout_bbox,
f_draw_span_bbox,
f_dump_orig_pdf,
f_dump_md,
f_dump_content_list,
f_dump_middle_json,
f_dump_model_output,
f_make_md_mode,
middle_json,
model_output=None,
is_pipeline=True
):
"""Write the requested output artifacts."""
from mineru.backend.pipeline.pipeline_middle_json_mkcontent import union_make as pipeline_union_make
f_draw_line_sort_bbox = False
if f_draw_layout_bbox:
draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")
if f_draw_span_bbox:
draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")
if f_dump_orig_pdf:
md_writer.write(
f"{pdf_file_name}_origin.pdf",
pdf_bytes,
)
if f_draw_line_sort_bbox:
draw_line_sort_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_line_sort.pdf")
image_dir = str(os.path.basename(local_image_dir))
if f_dump_md:
make_func = pipeline_union_make if is_pipeline else vlm_union_make
md_content_str = make_func(pdf_info, f_make_md_mode, image_dir)
md_writer.write_string(
f"{pdf_file_name}.md",
md_content_str,
)
if f_dump_content_list:
make_func = pipeline_union_make if is_pipeline else vlm_union_make
content_list = make_func(pdf_info, MakeMode.CONTENT_LIST, image_dir)
md_writer.write_string(
f"{pdf_file_name}_content_list.json",
json.dumps(content_list, ensure_ascii=False, indent=4),
)
if f_dump_middle_json:
md_writer.write_string(
f"{pdf_file_name}_middle.json",
json.dumps(middle_json, ensure_ascii=False, indent=4),
)
if f_dump_model_output:
md_writer.write_string(
f"{pdf_file_name}_model.json",
json.dumps(model_output, ensure_ascii=False, indent=4),
)
logger.info(f"local output dir is {local_md_dir}")
def _process_pipeline(
output_dir,
pdf_file_names,
pdf_bytes_list,
p_lang_list,
parse_method,
p_formula_enable,
p_table_enable,
f_draw_layout_bbox,
f_draw_span_bbox,
f_dump_md,
f_dump_middle_json,
f_dump_model_output,
f_dump_orig_pdf,
f_dump_content_list,
f_make_md_mode,
):
"""处理pipeline后端逻辑"""
from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json as pipeline_result_to_middle_json
from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze
infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list = (
pipeline_doc_analyze(
pdf_bytes_list, p_lang_list, parse_method=parse_method,
formula_enable=p_formula_enable, table_enable=p_table_enable
)
)
for idx, model_list in enumerate(infer_results):
model_json = copy.deepcopy(model_list)
pdf_file_name = pdf_file_names[idx]
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
images_list = all_image_lists[idx]
pdf_doc = all_pdf_docs[idx]
_lang = lang_list[idx]
_ocr_enable = ocr_enabled_list[idx]
middle_json = pipeline_result_to_middle_json(
model_list, images_list, pdf_doc, image_writer,
_lang, _ocr_enable, p_formula_enable
)
pdf_info = middle_json["pdf_info"]
pdf_bytes = pdf_bytes_list[idx]
_process_output(
pdf_info, pdf_bytes, pdf_file_name, local_md_dir, local_image_dir,
md_writer, f_draw_layout_bbox, f_draw_span_bbox, f_dump_orig_pdf,
f_dump_md, f_dump_content_list, f_dump_middle_json, f_dump_model_output,
f_make_md_mode, middle_json, model_json, is_pipeline=True
)
async def _async_process_vlm(
output_dir,
pdf_file_names,
pdf_bytes_list,
backend,
f_draw_layout_bbox,
f_draw_span_bbox,
f_dump_md,
f_dump_middle_json,
f_dump_model_output,
f_dump_orig_pdf,
f_dump_content_list,
f_make_md_mode,
server_url=None,
**kwargs,
):
"""异步处理VLM后端逻辑"""
parse_method = "vlm"
f_draw_span_bbox = False
if not backend.endswith("client"):
server_url = None
for idx, pdf_bytes in enumerate(pdf_bytes_list):
pdf_file_name = pdf_file_names[idx]
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
middle_json, infer_result = await aio_vlm_doc_analyze(
pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url, **kwargs,
)
pdf_info = middle_json["pdf_info"]
_process_output(
pdf_info, pdf_bytes, pdf_file_name, local_md_dir, local_image_dir,
md_writer, f_draw_layout_bbox, f_draw_span_bbox, f_dump_orig_pdf,
f_dump_md, f_dump_content_list, f_dump_middle_json, f_dump_model_output,
f_make_md_mode, middle_json, infer_result, is_pipeline=False
)
def _process_vlm(
output_dir,
pdf_file_names,
pdf_bytes_list,
backend,
f_draw_layout_bbox,
f_draw_span_bbox,
f_dump_md,
f_dump_middle_json,
f_dump_model_output,
f_dump_orig_pdf,
f_dump_content_list,
f_make_md_mode,
server_url=None,
**kwargs,
):
"""同步处理VLM后端逻辑"""
parse_method = "vlm"
f_draw_span_bbox = False
if not backend.endswith("client"):
server_url = None
for idx, pdf_bytes in enumerate(pdf_bytes_list):
pdf_file_name = pdf_file_names[idx]
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
middle_json, infer_result = vlm_doc_analyze(
pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url, **kwargs,
)
pdf_info = middle_json["pdf_info"]
_process_output(
pdf_info, pdf_bytes, pdf_file_name, local_md_dir, local_image_dir,
md_writer, f_draw_layout_bbox, f_draw_span_bbox, f_dump_orig_pdf,
f_dump_md, f_dump_content_list, f_dump_middle_json, f_dump_model_output,
f_make_md_mode, middle_json, infer_result, is_pipeline=False
)
def do_parse(
output_dir,
pdf_file_names: list[str],
pdf_bytes_list: list[bytes],
p_lang_list: list[str],
backend="pipeline",
parse_method="auto",
formula_enable=True,
table_enable=True,
server_url=None,
f_draw_layout_bbox=True,
f_draw_span_bbox=True,
f_dump_md=True,
f_dump_middle_json=True,
f_dump_model_output=True,
f_dump_orig_pdf=True,
f_dump_content_list=True,
f_make_md_mode=MakeMode.MM_MD,
start_page_id=0,
end_page_id=None,
**kwargs,
):
# preprocess the PDF bytes (trim to the requested page range)
pdf_bytes_list = _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id)
if backend == "pipeline":
_process_pipeline(
output_dir, pdf_file_names, pdf_bytes_list, p_lang_list,
parse_method, formula_enable, table_enable,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode
)
else:
if backend.startswith("vlm-"):
backend = backend[4:]
if backend == "vllm-async-engine":
raise Exception("vlm-vllm-async-engine backend is not supported in sync mode, please use vlm-vllm-engine backend")
os.environ['MINERU_VLM_FORMULA_ENABLE'] = str(formula_enable)
os.environ['MINERU_VLM_TABLE_ENABLE'] = str(table_enable)
_process_vlm(
output_dir, pdf_file_names, pdf_bytes_list, backend,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode,
server_url, **kwargs,
)
async def aio_do_parse(
output_dir,
pdf_file_names: list[str],
pdf_bytes_list: list[bytes],
p_lang_list: list[str],
backend="pipeline",
parse_method="auto",
formula_enable=True,
table_enable=True,
server_url=None,
f_draw_layout_bbox=True,
f_draw_span_bbox=True,
f_dump_md=True,
f_dump_middle_json=True,
f_dump_model_output=True,
f_dump_orig_pdf=True,
f_dump_content_list=True,
f_make_md_mode=MakeMode.MM_MD,
start_page_id=0,
end_page_id=None,
**kwargs,
):
# preprocess the PDF bytes (trim to the requested page range)
pdf_bytes_list = _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id)
if backend == "pipeline":
# the pipeline backend does not support async yet; fall back to the sync path
_process_pipeline(
output_dir, pdf_file_names, pdf_bytes_list, p_lang_list,
parse_method, formula_enable, table_enable,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode
)
else:
if backend.startswith("vlm-"):
backend = backend[4:]
if backend == "vllm-engine":
raise Exception("vlm-vllm-engine backend is not supported in async mode, please use vlm-vllm-async-engine backend")
os.environ['MINERU_VLM_FORMULA_ENABLE'] = str(formula_enable)
os.environ['MINERU_VLM_TABLE_ENABLE'] = str(table_enable)
await _async_process_vlm(
output_dir, pdf_file_names, pdf_bytes_list, backend,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode,
server_url, **kwargs,
)
if __name__ == "__main__":
pdf_path = "../../demo/pdfs/demo3.pdf"
try:
do_parse("./output", [Path(pdf_path).stem], [read_fn(Path(pdf_path))], ["ch"],
end_page_id=10,
backend='vlm-transformers'
# backend = 'pipeline'
)
except Exception as e:
logger.exception(e)
import uuid
import os
import re
import tempfile
import asyncio
import uvicorn
import click
import zipfile
from pathlib import Path
import glob
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import JSONResponse, FileResponse
from starlette.background import BackgroundTask
from typing import List, Optional
from loguru import logger
from base64 import b64encode
from mineru.cli.common import aio_do_parse, read_fn, pdf_suffixes, image_suffixes
from mineru.utils.cli_parser import arg_parse
from mineru.utils.guess_suffix_or_lang import guess_suffix_by_path
from mineru.version import __version__
app = FastAPI()
app.add_middleware(GZipMiddleware, minimum_size=1000)
def sanitize_filename(filename: str) -> str:
"""
    Sanitize a filename for use inside the result archive:
    strip path-traversal characters, keep Unicode letters, digits and ``._-``,
    and disallow hidden files (leading dot).
"""
sanitized = re.sub(r'[/\\\.]{2,}|[/\\]', '', filename)
sanitized = re.sub(r'[^\w.-]', '_', sanitized, flags=re.UNICODE)
if sanitized.startswith('.'):
sanitized = '_' + sanitized[1:]
return sanitized or 'unnamed'
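# Behavior sketch (assumed inputs): sanitize_filename("../../etc/passwd") -> "etcpasswd",
# sanitize_filename(".hidden") -> "_hidden", sanitize_filename("") -> "unnamed".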
def cleanup_file(file_path: str) -> None:
"""清理临时 zip 文件"""
try:
if os.path.exists(file_path):
os.remove(file_path)
except Exception as e:
logger.warning(f"fail clean file {file_path}: {e}")
def encode_image(image_path: str) -> str:
"""Encode image using base64"""
with open(image_path, "rb") as f:
return b64encode(f.read()).decode()
def get_infer_result(file_suffix_identifier: str, pdf_name: str, parse_dir: str) -> Optional[str]:
"""从结果文件中读取推理结果"""
result_file_path = os.path.join(parse_dir, f"{pdf_name}{file_suffix_identifier}")
if os.path.exists(result_file_path):
with open(result_file_path, "r", encoding="utf-8") as fp:
return fp.read()
return None
@app.post(path="/file_parse",)
async def parse_pdf(
files: List[UploadFile] = File(...),
output_dir: str = Form("./output"),
lang_list: List[str] = Form(["ch"]),
backend: str = Form("pipeline"),
parse_method: str = Form("auto"),
formula_enable: bool = Form(True),
table_enable: bool = Form(True),
server_url: Optional[str] = Form(None),
return_md: bool = Form(True),
return_middle_json: bool = Form(False),
return_model_output: bool = Form(False),
return_content_list: bool = Form(False),
return_images: bool = Form(False),
response_format_zip: bool = Form(False),
start_page_id: int = Form(0),
end_page_id: int = Form(99999),
):
    # Fetch the CLI config parameters stored on the app state
config = getattr(app.state, "config", {})
try:
        # Create a unique output directory
unique_dir = os.path.join(output_dir, str(uuid.uuid4()))
os.makedirs(unique_dir, exist_ok=True)
        # Process the uploaded files
pdf_file_names = []
pdf_bytes_list = []
for file in files:
content = await file.read()
file_path = Path(file.filename)
            # Write the upload to a temporary file
temp_path = Path(unique_dir) / file_path.name
with open(temp_path, "wb") as f:
f.write(content)
            # If it is an image or a PDF, load it with read_fn
file_suffix = guess_suffix_by_path(temp_path)
if file_suffix in pdf_suffixes + image_suffixes:
try:
pdf_bytes = read_fn(temp_path)
pdf_bytes_list.append(pdf_bytes)
pdf_file_names.append(file_path.stem)
                    os.remove(temp_path)  # remove the temporary file
except Exception as e:
return JSONResponse(
status_code=400,
content={"error": f"Failed to load file: {str(e)}"}
)
else:
return JSONResponse(
status_code=400,
content={"error": f"Unsupported file type: {file_suffix}"}
)
        # Normalize the language list so its length matches the number of files
actual_lang_list = lang_list
if len(actual_lang_list) != len(pdf_file_names):
            # On a length mismatch, repeat the first language (default "ch")
actual_lang_list = [actual_lang_list[0] if actual_lang_list else "ch"] * len(pdf_file_names)
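        # e.g. lang_list=["ch"] with three uploaded files expands to ["ch", "ch", "ch"]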
        # Invoke the async parsing entry point
await aio_do_parse(
output_dir=unique_dir,
pdf_file_names=pdf_file_names,
pdf_bytes_list=pdf_bytes_list,
p_lang_list=actual_lang_list,
backend=backend,
parse_method=parse_method,
formula_enable=formula_enable,
table_enable=table_enable,
server_url=server_url,
f_draw_layout_bbox=False,
f_draw_span_bbox=False,
f_dump_md=return_md,
f_dump_middle_json=return_middle_json,
f_dump_model_output=return_model_output,
f_dump_orig_pdf=False,
f_dump_content_list=return_content_list,
start_page_id=start_page_id,
end_page_id=end_page_id,
**config
)
        # response_format_zip selects between a zip download and a JSON payload
if response_format_zip:
zip_fd, zip_path = tempfile.mkstemp(suffix=".zip", prefix="mineru_results_")
os.close(zip_fd)
with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
for pdf_name in pdf_file_names:
safe_pdf_name = sanitize_filename(pdf_name)
if backend.startswith("pipeline"):
parse_dir = os.path.join(unique_dir, pdf_name, parse_method)
else:
parse_dir = os.path.join(unique_dir, pdf_name, "vlm")
if not os.path.exists(parse_dir):
continue
                    # Write the text-based results
if return_md:
path = os.path.join(parse_dir, f"{pdf_name}.md")
if os.path.exists(path):
zf.write(path, arcname=os.path.join(safe_pdf_name, f"{safe_pdf_name}.md"))
if return_middle_json:
path = os.path.join(parse_dir, f"{pdf_name}_middle.json")
if os.path.exists(path):
zf.write(path, arcname=os.path.join(safe_pdf_name, f"{safe_pdf_name}_middle.json"))
if return_model_output:
if backend.startswith("pipeline"):
path = os.path.join(parse_dir, f"{pdf_name}_model.json")
else:
path = os.path.join(parse_dir, f"{pdf_name}_model_output.txt")
if os.path.exists(path):
zf.write(path, arcname=os.path.join(safe_pdf_name, os.path.basename(path)))
if return_content_list:
path = os.path.join(parse_dir, f"{pdf_name}_content_list.json")
if os.path.exists(path):
zf.write(path, arcname=os.path.join(safe_pdf_name, f"{safe_pdf_name}_content_list.json"))
                    # Write the extracted images
if return_images:
images_dir = os.path.join(parse_dir, "images")
image_paths = glob.glob(os.path.join(glob.escape(images_dir), "*.jpg"))
for image_path in image_paths:
zf.write(image_path, arcname=os.path.join(safe_pdf_name, "images", os.path.basename(image_path)))
return FileResponse(
path=zip_path,
media_type="application/zip",
filename="results.zip",
background=BackgroundTask(cleanup_file, zip_path)
)
else:
            # Build the JSON result
result_dict = {}
for pdf_name in pdf_file_names:
result_dict[pdf_name] = {}
data = result_dict[pdf_name]
if backend.startswith("pipeline"):
parse_dir = os.path.join(unique_dir, pdf_name, parse_method)
else:
parse_dir = os.path.join(unique_dir, pdf_name, "vlm")
if os.path.exists(parse_dir):
if return_md:
data["md_content"] = get_infer_result(".md", pdf_name, parse_dir)
if return_middle_json:
data["middle_json"] = get_infer_result("_middle.json", pdf_name, parse_dir)
if return_model_output:
if backend.startswith("pipeline"):
data["model_output"] = get_infer_result("_model.json", pdf_name, parse_dir)
else:
data["model_output"] = get_infer_result("_model_output.txt", pdf_name, parse_dir)
if return_content_list:
data["content_list"] = get_infer_result("_content_list.json", pdf_name, parse_dir)
if return_images:
images_dir = os.path.join(parse_dir, "images")
safe_pattern = os.path.join(glob.escape(images_dir), "*.jpg")
image_paths = glob.glob(safe_pattern)
data["images"] = {
os.path.basename(
image_path
): f"data:image/jpeg;base64,{encode_image(image_path)}"
for image_path in image_paths
}
return JSONResponse(
status_code=200,
content={
"backend": backend,
"version": __version__,
"results": result_dict
}
)
except Exception as e:
logger.exception(e)
return JSONResponse(
status_code=500,
content={"error": f"Failed to process file: {str(e)}"}
)
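# Request sketch (assuming the service listens on localhost:8000 and `demo.pdf`
# exists in the working directory):
#
#     curl -X POST http://localhost:8000/file_parse \
#         -F "files=@demo.pdf" \
#         -F "backend=pipeline" \
#         -F "lang_list=ch" \
#         -F "return_md=true" \
#         -F "return_content_list=true"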
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.option('--host', default='127.0.0.1', help='Server host (default: 127.0.0.1)')
@click.option('--port', default=8000, type=int, help='Server port (default: 8000)')
@click.option('--reload', is_flag=True, help='Enable auto-reload (development mode)')
def main(ctx, host, port, reload, **kwargs):
    """Command-line entry point for starting the MinerU FastAPI server."""
    kwargs.update(arg_parse(ctx))
    # Store the extra CLI parameters on the application state
    app.state.config = kwargs
    print(f"Starting MinerU FastAPI service: http://{host}:{port}")
    print("The API documentation can be accessed at the following addresses:")
print(f"- Swagger UI: http://{host}:{port}/docs")
print(f"- ReDoc: http://{host}:{port}/redoc")
uvicorn.run(
"mineru.cli.fast_api:app",
host=host,
port=port,
reload=reload
)
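# Launch sketch (assuming the package is installed): `python -m mineru.cli.fast_api
# --host 0.0.0.0 --port 8000` starts the service; the interactive docs are then
# served at /docs and /redoc, as printed above.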
if __name__ == "__main__":
main()
# Copyright (c) Opendatalab. All rights reserved.
import base64
import os
import re
import time
import zipfile
from pathlib import Path
import click
import gradio as gr
from gradio_pdf import PDF
from loguru import logger
from mineru.cli.common import prepare_env, read_fn, aio_do_parse, pdf_suffixes, image_suffixes
from mineru.utils.cli_parser import arg_parse
from mineru.utils.hash_utils import str_sha256
async def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, formula_enable, table_enable, language, backend, url):
os.makedirs(output_dir, exist_ok=True)
try:
file_name = f'{safe_stem(Path(doc_path).stem)}_{time.strftime("%y%m%d_%H%M%S")}'
pdf_data = read_fn(doc_path)
if is_ocr:
parse_method = 'ocr'
else:
parse_method = 'auto'
if backend.startswith("vlm"):
parse_method = "vlm"
local_image_dir, local_md_dir = prepare_env(output_dir, file_name, parse_method)
await aio_do_parse(
output_dir=output_dir,
pdf_file_names=[file_name],
pdf_bytes_list=[pdf_data],
p_lang_list=[language],
parse_method=parse_method,
end_page_id=end_page_id,
formula_enable=formula_enable,
table_enable=table_enable,
backend=backend,
server_url=url,
)
return local_md_dir, file_name
    except Exception as e:
        logger.exception(e)
        # Re-raise so callers that unpack the (local_md_dir, file_name) pair fail
        # loudly instead of hitting a confusing TypeError on a bare None
        raise
def compress_directory_to_zip(directory_path, output_zip_path):
"""压缩指定目录到一个 ZIP 文件。
:param directory_path: 要压缩的目录路径
:param output_zip_path: 输出的 ZIP 文件路径
"""
try:
with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            # Walk all files and subdirectories
            for root, dirs, files in os.walk(directory_path):
                for file in files:
                    # Build the full file path
                    file_path = os.path.join(root, file)
                    # Compute the path relative to the directory root
                    arcname = os.path.relpath(file_path, directory_path)
                    # Add the file to the ZIP archive
                    zipf.write(file_path, arcname)
return 0
except Exception as e:
logger.exception(e)
return -1
def image_to_base64(image_path):
with open(image_path, 'rb') as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
def replace_image_with_base64(markdown_text, image_dir_path):
    # Match Markdown image tags
pattern = r'\!\[(?:[^\]]*)\]\(([^)]+)\)'
    # Replace each image link with an inline base64 data URI
def replace(match):
relative_path = match.group(1)
full_path = os.path.join(image_dir_path, relative_path)
base64_image = image_to_base64(full_path)
return f'![{relative_path}](data:image/jpeg;base64,{base64_image})'
    # Apply the substitution
return re.sub(pattern, replace, markdown_text)
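# Behavior sketch: given markdown containing "![fig](images/p1.jpg)" and an
# image_dir_path holding that file, the tag is rewritten in place to
# "![images/p1.jpg](data:image/jpeg;base64,...)" with the image content inlined.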
async def to_markdown(file_path, end_pages=10, is_ocr=False, formula_enable=True, table_enable=True, language="ch", backend="pipeline", url=None):
file_path = to_pdf(file_path)
    # Parse the document and get the markdown output directory and file name
local_md_dir, file_name = await parse_pdf(file_path, './output', end_pages - 1, is_ocr, formula_enable, table_enable, language, backend, url)
archive_zip_path = os.path.join('./output', str_sha256(local_md_dir) + '.zip')
zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path)
if zip_archive_success == 0:
logger.info('Compression successful')
else:
logger.error('Compression failed')
md_path = os.path.join(local_md_dir, file_name + '.md')
with open(md_path, 'r', encoding='utf-8') as f:
txt_content = f.read()
md_content = replace_image_with_base64(txt_content, local_md_dir)
    # Path of the layout-annotated PDF returned for preview
new_pdf_path = os.path.join(local_md_dir, file_name + '_layout.pdf')
return md_content, txt_content, archive_zip_path, new_pdf_path
latex_delimiters_type_a = [
{'left': '$$', 'right': '$$', 'display': True},
{'left': '$', 'right': '$', 'display': False},
]
latex_delimiters_type_b = [
{'left': '\\(', 'right': '\\)', 'display': False},
{'left': '\\[', 'right': '\\]', 'display': True},
]
latex_delimiters_type_all = latex_delimiters_type_a + latex_delimiters_type_b
header_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'resources', 'header.html')
with open(header_path, 'r') as header_file:
header = header_file.read()
latin_lang = [
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr', # noqa: E126
'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
]
arabic_lang = ['ar', 'fa', 'ug', 'ur']
cyrillic_lang = [
'rs_cyrillic', 'bg', 'mn', 'abq', 'ady', 'kbd', 'ava', # noqa: E126
'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
]
east_slavic_lang = ["ru", "be", "uk"]
devanagari_lang = [
'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom', # noqa: E126
'sa', 'bgc'
]
other_lang = ['ch', 'ch_lite', 'ch_server', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka', "el", "th"]
add_lang = ['latin', 'arabic', 'east_slavic', 'cyrillic', 'devanagari']
# all_lang = ['', 'auto']
all_lang = []
# all_lang.extend([*other_lang, *latin_lang, *arabic_lang, *cyrillic_lang, *devanagari_lang])
all_lang.extend([*other_lang, *add_lang])
def safe_stem(file_path):
stem = Path(file_path).stem
    # Keep only letters, digits, underscores and dots; replace everything else with underscores
return re.sub(r'[^\w.]', '_', stem)
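# e.g. safe_stem("my report (v2).pdf") -> "my_report__v2_"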
def to_pdf(file_path):
if file_path is None:
return None
pdf_bytes = read_fn(file_path)
# unique_filename = f'{uuid.uuid4()}.pdf'
unique_filename = f'{safe_stem(file_path)}.pdf'
    # Build the full temporary file path
tmp_file_path = os.path.join(os.path.dirname(file_path), unique_filename)
    # Write the bytes to the file
with open(tmp_file_path, 'wb') as tmp_pdf_file:
tmp_pdf_file.write(pdf_bytes)
return tmp_file_path
# Update widget visibility to match the selected backend
def update_interface(backend_choice):
if backend_choice in ["vlm-transformers", "vlm-vllm-async-engine"]:
return gr.update(visible=False), gr.update(visible=False)
elif backend_choice in ["vlm-http-client"]:
return gr.update(visible=True), gr.update(visible=False)
elif backend_choice in ["pipeline"]:
return gr.update(visible=False), gr.update(visible=True)
    else:
        # Unknown backend: leave both option groups unchanged rather than
        # returning None, which would break the two-output Gradio callback
        return gr.update(), gr.update()
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.option(
'--enable-example',
'example_enable',
type=bool,
help="Enable example files for input."
"The example files to be input need to be placed in the `example` folder within the directory where the command is currently executed.",
default=True,
)
@click.option(
'--enable-vllm-engine',
'vllm_engine_enable',
type=bool,
help="Enable vLLM engine backend for faster processing.",
default=False,
)
@click.option(
'--enable-api',
'api_enable',
type=bool,
help="Enable gradio API for serving the application.",
default=True,
)
@click.option(
'--max-convert-pages',
'max_convert_pages',
type=int,
help="Set the maximum number of pages to convert from PDF to Markdown.",
default=1000,
)
@click.option(
'--server-name',
'server_name',
type=str,
help="Set the server name for the Gradio app.",
default=None,
)
@click.option(
'--server-port',
'server_port',
type=int,
help="Set the server port for the Gradio app.",
default=None,
)
@click.option(
'--latex-delimiters-type',
'latex_delimiters_type',
type=click.Choice(['a', 'b', 'all']),
help="Set the type of LaTeX delimiters to use in Markdown rendering:"
"'a' for type '$', 'b' for type '()[]', 'all' for both types.",
default='all',
)
def main(ctx,
example_enable, vllm_engine_enable, api_enable, max_convert_pages,
server_name, server_port, latex_delimiters_type, **kwargs
):
kwargs.update(arg_parse(ctx))
if latex_delimiters_type == 'a':
latex_delimiters = latex_delimiters_type_a
elif latex_delimiters_type == 'b':
latex_delimiters = latex_delimiters_type_b
elif latex_delimiters_type == 'all':
latex_delimiters = latex_delimiters_type_all
else:
raise ValueError(f"Invalid latex delimiters type: {latex_delimiters_type}.")
if vllm_engine_enable:
try:
print("Start init vLLM engine...")
from mineru.backend.vlm.vlm_analyze import ModelSingleton
model_singleton = ModelSingleton()
predictor = model_singleton.get_model(
"vllm-async-engine",
None,
None,
**kwargs
)
print("vLLM engine init successfully.")
except Exception as e:
logger.exception(e)
suffixes = [f".{suffix}" for suffix in pdf_suffixes + image_suffixes]
with gr.Blocks() as demo:
gr.HTML(header)
with gr.Row():
with gr.Column(variant='panel', scale=5):
with gr.Row():
input_file = gr.File(label='Please upload a PDF or image', file_types=suffixes)
with gr.Row():
max_pages = gr.Slider(1, max_convert_pages, int(max_convert_pages/2), step=1, label='Max convert pages')
with gr.Row():
if vllm_engine_enable:
drop_list = ["pipeline", "vlm-vllm-async-engine"]
preferred_option = "vlm-vllm-async-engine"
else:
drop_list = ["pipeline", "vlm-transformers", "vlm-http-client"]
preferred_option = "pipeline"
backend = gr.Dropdown(drop_list, label="Backend", value=preferred_option)
with gr.Row(visible=False) as client_options:
url = gr.Textbox(label='Server URL', value='http://localhost:30000', placeholder='http://localhost:30000')
with gr.Row(equal_height=True):
with gr.Column():
gr.Markdown("**Recognition Options:**")
formula_enable = gr.Checkbox(label='Enable formula recognition', value=True)
table_enable = gr.Checkbox(label='Enable table recognition', value=True)
with gr.Column(visible=False) as ocr_options:
language = gr.Dropdown(all_lang, label='Language', value='ch')
is_ocr = gr.Checkbox(label='Force enable OCR', value=False)
with gr.Row():
change_bu = gr.Button('Convert')
clear_bu = gr.ClearButton(value='Clear')
pdf_show = PDF(label='PDF preview', interactive=False, visible=True, height=800)
if example_enable:
example_root = os.path.join(os.getcwd(), 'examples')
if os.path.exists(example_root):
with gr.Accordion('Examples:'):
gr.Examples(
examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
_.endswith(tuple(suffixes))],
inputs=input_file
)
with gr.Column(variant='panel', scale=5):
output_file = gr.File(label='convert result', interactive=False)
with gr.Tabs():
with gr.Tab('Markdown rendering'):
md = gr.Markdown(label='Markdown rendering', height=1100, show_copy_button=True,
latex_delimiters=latex_delimiters,
line_breaks=True)
with gr.Tab('Markdown text'):
md_text = gr.TextArea(lines=45, show_copy_button=True)
        # Wire up event handlers
backend.change(
fn=update_interface,
inputs=[backend],
outputs=[client_options, ocr_options],
api_name=False
)
        # Trigger one interface refresh on page load via demo.load
demo.load(
fn=update_interface,
inputs=[backend],
outputs=[client_options, ocr_options],
api_name=False
)
clear_bu.add([input_file, md, pdf_show, md_text, output_file, is_ocr])
if api_enable:
api_name = None
else:
api_name = False
input_file.change(fn=to_pdf, inputs=input_file, outputs=pdf_show, api_name=api_name)
change_bu.click(
fn=to_markdown,
inputs=[input_file, max_pages, is_ocr, formula_enable, table_enable, language, backend, url],
outputs=[md, md_text, output_file, pdf_show],
api_name=api_name
)
demo.launch(server_name=server_name, server_port=server_port, show_api=api_enable)
if __name__ == '__main__':
main()
import json
import os
import sys
import click
import requests
from loguru import logger
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
def download_json(url):
"""下载JSON文件"""
response = requests.get(url)
response.raise_for_status()
return response.json()
def download_and_modify_json(url, local_filename, modifications):
"""下载JSON并修改内容"""
if os.path.exists(local_filename):
data = json.load(open(local_filename))
config_version = data.get('config_version', '0.0.0')
if config_version < '1.3.0':
data = download_json(url)
else:
data = download_json(url)
# 修改内容
for key, value in modifications.items():
if key in data:
if isinstance(data[key], dict):
# 如果是字典,合并新值
data[key].update(value)
else:
# 否则直接替换
data[key] = value
# 保存修改后的内容
with open(local_filename, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
def configure_model(model_dir, model_type):
"""配置模型"""
json_url = 'https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/mineru.template.json'
config_file_name = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'mineru.json')
home_dir = os.path.expanduser('~')
config_file = os.path.join(home_dir, config_file_name)
json_mods = {
'models-dir': {
f'{model_type}': model_dir
}
}
download_and_modify_json(json_url, config_file, json_mods)
logger.info(f'The configuration file has been successfully configured, the path is: {config_file}')
def download_pipeline_models():
"""下载Pipeline模型"""
model_paths = [
ModelPath.doclayout_yolo,
ModelPath.yolo_v8_mfd,
ModelPath.unimernet_small,
ModelPath.pytorch_paddle,
ModelPath.layout_reader,
ModelPath.slanet_plus,
ModelPath.unet_structure,
ModelPath.paddle_table_cls,
ModelPath.paddle_orientation_classification,
]
download_finish_path = ""
for model_path in model_paths:
logger.info(f"Downloading model: {model_path}")
download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode='pipeline')
logger.info(f"Pipeline models downloaded successfully to: {download_finish_path}")
configure_model(download_finish_path, "pipeline")
def download_vlm_models():
"""下载VLM模型"""
download_finish_path = auto_download_and_get_model_root_path("/", repo_mode='vlm')
logger.info(f"VLM models downloaded successfully to: {download_finish_path}")
configure_model(download_finish_path, "vlm")
@click.command()
@click.option(
'-s',
'--source',
'model_source',
type=click.Choice(['huggingface', 'modelscope']),
help="""
The source of the model repository.
""",
default=None,
)
@click.option(
'-m',
'--model_type',
'model_type',
type=click.Choice(['pipeline', 'vlm', 'all']),
help="""
The type of the model to download.
""",
default=None,
)
def download_models(model_source, model_type):
"""Download MinerU model files.
Supports downloading pipeline or VLM models from ModelScope or HuggingFace.
"""
    # Prompt interactively for the download source when not given explicitly
if model_source is None:
model_source = click.prompt(
"Please select the model download source: ",
type=click.Choice(['huggingface', 'modelscope']),
default='huggingface'
)
if os.getenv('MINERU_MODEL_SOURCE', None) is None:
os.environ['MINERU_MODEL_SOURCE'] = model_source
    # Prompt interactively for the model type when not given explicitly
if model_type is None:
model_type = click.prompt(
"Please select the model type to download: ",
type=click.Choice(['pipeline', 'vlm', 'all']),
default='all'
)
logger.info(f"Downloading {model_type} model from {os.getenv('MINERU_MODEL_SOURCE', None)}...")
try:
if model_type == 'pipeline':
download_pipeline_models()
elif model_type == 'vlm':
download_vlm_models()
elif model_type == 'all':
download_pipeline_models()
download_vlm_models()
else:
click.echo(f"Unsupported model type: {model_type}", err=True)
sys.exit(1)
except Exception as e:
logger.exception(f"An error occurred while downloading models: {str(e)}")
sys.exit(1)
if __name__ == '__main__':
download_models()
from mineru.model.vlm_vllm_model.server import main
if __name__ == "__main__":
main()