Unverified Commit 0c7a0882 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #2611 from myhloli/dev

Dev
parents 3bd0ecf1 a392f445
import cv2
from loguru import logger
from tqdm import tqdm
from collections import defaultdict
import numpy as np
from .model_init import AtomModelSingleton
from ...utils.model_utils import crop_img, get_res_list_from_layout_res, get_coords_and_area
from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list
YOLO_LAYOUT_BASE_BATCH_SIZE = 1
MFD_BASE_BATCH_SIZE = 1
MFR_BASE_BATCH_SIZE = 16
class BatchAnalyze:
def __init__(self, model_manager, batch_ratio: int, formula_enable, table_enable, enable_ocr_det_batch: bool = True):
self.batch_ratio = batch_ratio
self.formula_enable = formula_enable
self.table_enable = table_enable
self.model_manager = model_manager
self.enable_ocr_det_batch = enable_ocr_det_batch
def __call__(self, images_with_extra_info: list) -> list:
if len(images_with_extra_info) == 0:
return []
images_layout_res = []
self.model = self.model_manager.get_model(
lang=None,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
atom_model_manager = AtomModelSingleton()
images = [image for image, _, _ in images_with_extra_info]
# doclayout_yolo
layout_images = []
for image_index, image in enumerate(images):
layout_images.append(image)
images_layout_res += self.model.layout_model.batch_predict(
layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
)
if self.formula_enable:
# 公式检测
images_mfd_res = self.model.mfd_model.batch_predict(
images, MFD_BASE_BATCH_SIZE
)
# 公式识别
images_formula_list = self.model.mfr_model.batch_predict(
images_mfd_res,
images,
batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
)
mfr_count = 0
for image_index in range(len(images)):
images_layout_res[image_index] += images_formula_list[image_index]
mfr_count += len(images_formula_list[image_index])
# 清理显存
# clean_vram(self.model.device, vram_threshold=8)
ocr_res_list_all_page = []
table_res_list_all_page = []
for index in range(len(images)):
_, ocr_enable, _lang = images_with_extra_info[index]
layout_res = images_layout_res[index]
pil_img = images[index]
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
get_res_list_from_layout_res(layout_res)
)
ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
'lang':_lang,
'ocr_enable':ocr_enable,
'pil_img':pil_img,
'single_page_mfdetrec_res':single_page_mfdetrec_res,
'layout_res':layout_res,
})
for table_res in table_res_list:
table_img, _ = crop_img(table_res, pil_img)
table_res_list_all_page.append({'table_res':table_res,
'lang':_lang,
'table_img':table_img,
})
# OCR检测处理
if self.enable_ocr_det_batch:
# 批处理模式 - 按语言和分辨率分组
# 收集所有需要OCR检测的裁剪图像
all_cropped_images_info = []
for ocr_res_list_dict in ocr_res_list_all_page:
_lang = ocr_res_list_dict['lang']
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# BGR转换
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
all_cropped_images_info.append((
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
))
# 按语言分组
lang_groups = defaultdict(list)
for crop_info in all_cropped_images_info:
lang = crop_info[5]
lang_groups[lang].append(crop_info)
# 对每种语言按分辨率分组并批处理
for lang, lang_crop_list in lang_groups.items():
if not lang_crop_list:
continue
# logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")
# 获取OCR模型
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
det_db_box_thresh=0.3,
lang=lang
)
# 按分辨率分组并同时完成padding
resolution_groups = defaultdict(list)
for crop_info in lang_crop_list:
cropped_img = crop_info[0]
h, w = cropped_img.shape[:2]
# 使用更大的分组容差,减少分组数量
# 将尺寸标准化到32的倍数
normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数
normalized_w = ((w + 32) // 32) * 32
group_key = (normalized_h, normalized_w)
resolution_groups[group_key].append(crop_info)
# 对每个分辨率组进行批处理
for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
# 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
max_h = max(crop_info[0].shape[0] for crop_info in group_crops)
max_w = max(crop_info[0].shape[1] for crop_info in group_crops)
target_h = ((max_h + 32 - 1) // 32) * 32
target_w = ((max_w + 32 - 1) // 32) * 32
# 对所有图像进行padding到统一尺寸
batch_images = []
for crop_info in group_crops:
img = crop_info[0]
h, w = img.shape[:2]
# 创建目标尺寸的白色背景
padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
# 将原图像粘贴到左上角
padded_img[:h, :w] = img
batch_images.append(padded_img)
# 批处理检测
batch_size = min(len(batch_images), self.batch_ratio * 16) # 增加批处理大小
# logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)
# 处理批处理结果
for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
if dt_boxes is not None and len(dt_boxes) > 0:
# 直接应用原始OCR流程中的关键处理步骤
from mineru.utils.ocr_utils import (
merge_det_boxes, update_det_boxes, sorted_boxes
)
# 1. 排序检测框
if len(dt_boxes) > 0:
dt_boxes_sorted = sorted_boxes(dt_boxes)
else:
dt_boxes_sorted = []
# 2. 合并相邻检测框
if dt_boxes_sorted:
dt_boxes_merged = merge_det_boxes(dt_boxes_sorted)
else:
dt_boxes_merged = []
# 3. 根据公式位置更新检测框(关键步骤!)
if dt_boxes_merged and adjusted_mfdetrec_res:
dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
else:
dt_boxes_final = dt_boxes_merged
# 构造OCR结果格式
ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]
if ocr_res:
ocr_result_list = get_ocr_result_list(
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
)
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
else:
# 原始单张处理模式
for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
# Process each area that requires OCR processing
_lang = ocr_res_list_dict['lang']
# Get OCR results for this language's images
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=_lang
)
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# OCR-det
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
ocr_res = ocr_model.ocr(
new_image, mfd_res=adjusted_mfdetrec_res, rec=False
)[0]
# Integration results
if ocr_res:
ocr_result_list = get_ocr_result_list(
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],new_image, _lang
)
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
# 表格识别 table recognition
if self.table_enable:
for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
_lang = table_res_dict['lang']
table_model = atom_model_manager.get_atom_model(
atom_model_name='table',
lang=_lang,
)
html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict['table_img'])
# 判断是否返回正常
if html_code:
expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
if expected_ending:
table_res_dict['table_res']['html'] = html_code
else:
logger.warning(
'table recognition processing fails, not found expected HTML table end'
)
else:
logger.warning(
'table recognition processing fails, not get html return'
)
# Create dictionaries to store items by language
need_ocr_lists_by_lang = {} # Dict of lists for each language
img_crop_lists_by_lang = {} # Dict of lists for each language
for layout_res in images_layout_res:
for layout_res_item in layout_res:
if layout_res_item['category_id'] in [15]:
if 'np_img' in layout_res_item and 'lang' in layout_res_item:
lang = layout_res_item['lang']
# Initialize lists for this language if not exist
if lang not in need_ocr_lists_by_lang:
need_ocr_lists_by_lang[lang] = []
img_crop_lists_by_lang[lang] = []
# Add to the appropriate language-specific lists
need_ocr_lists_by_lang[lang].append(layout_res_item)
img_crop_lists_by_lang[lang].append(layout_res_item['np_img'])
# Remove the fields after adding to lists
layout_res_item.pop('np_img')
layout_res_item.pop('lang')
if len(img_crop_lists_by_lang) > 0:
# Process OCR by language
total_processed = 0
# Process each language separately
for lang, img_crop_list in img_crop_lists_by_lang.items():
if len(img_crop_list) > 0:
# Get OCR results for this language's images
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
det_db_box_thresh=0.3,
lang=lang
)
ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
# Verify we have matching counts
assert len(ocr_res_list) == len(
need_ocr_lists_by_lang[lang]), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_lists_by_lang[lang])} for lang: {lang}'
# Process OCR results for this language
for index, layout_res_item in enumerate(need_ocr_lists_by_lang[lang]):
ocr_text, ocr_score = ocr_res_list[index]
layout_res_item['text'] = ocr_text
layout_res_item['score'] = float(f"{ocr_score:.3f}")
if ocr_score < 0.6:
layout_res_item['category_id'] = 16
total_processed += len(img_crop_list)
return images_layout_res
import os
import torch
from loguru import logger
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
# try:
# from magic_pdf_ascend_plugin.libs.license_verifier import (
# LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
# load_license)
# from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
# from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
# license_key = load_license()
# logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
# f' License expired at {license_key["payload"]["date"]["end_date"]}')
# except Exception as e:
# if isinstance(e, ImportError):
# pass
# elif isinstance(e, LicenseFormatError):
# logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
# elif isinstance(e, LicenseSignatureError):
# logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
# elif isinstance(e, LicenseExpiredError):
# logger.error('Ascend Plugin: License has expired. Please renew your license.')
# elif isinstance(e, FileNotFoundError):
# logger.error('Ascend Plugin: Not found License file.')
# else:
# logger.error(f'Ascend Plugin: {e}')
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
# # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
# from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lang=None, table_sub_model_name=None):
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
elif table_model_type == MODEL_NAME.TABLE_MASTER:
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
config = {
'model_dir': model_path,
'device': _device_
}
table_model = TableMasterPaddleModel(config)
elif table_model_type == MODEL_NAME.RAPID_TABLE:
atom_model_manager = AtomModelSingleton()
ocr_engine = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
lang=lang
)
table_model = RapidTableModel(ocr_engine, table_sub_model_name)
else:
logger.error('table model type not allow')
exit(1)
from .model_list import AtomicModel
from ...model.layout.doclayout_yolo import DocLayoutYOLOModel
from ...model.mfd.yolo_v8 import YOLOv8MFDModel
from ...model.mfr.unimernet.Unimernet import UnimernetModel
from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
from ...model.table.rapid_table import RapidTableModel
from ...utils.enum_class import ModelPath
from ...utils.models_download_utils import auto_download_and_get_model_root_path
def table_model_init(lang=None):
atom_model_manager = AtomModelSingleton()
ocr_engine = atom_model_manager.get_atom_model(
atom_model_name='ocr',
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
lang=lang
)
table_model = RapidTableModel(ocr_engine)
return table_model
......@@ -71,50 +32,31 @@ def mfd_model_init(weight, device='cpu'):
return mfd_model
def mfr_model_init(weight_dir, cfg_path, device='cpu'):
mfr_model = UnimernetModel(weight_dir, cfg_path, device)
def mfr_model_init(weight_dir, device='cpu'):
mfr_model = UnimernetModel(weight_dir, device)
return mfr_model
def layout_model_init(weight, config_file, device):
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
model = Layoutlmv3_Predictor(weight, config_file, device)
return model
def doclayout_yolo_model_init(weight, device='cpu'):
if str(device).startswith('npu'):
device = torch.device(device)
model = DocLayoutYOLOModel(weight, device)
return model
def langdetect_model_init(langdetect_model_weight, device='cpu'):
if str(device).startswith('npu'):
device = torch.device(device)
model = YOLOv11LangDetModel(langdetect_model_weight, device)
return model
def ocr_model_init(show_log: bool = False,
det_db_box_thresh=0.3,
def ocr_model_init(det_db_box_thresh=0.3,
lang=None,
use_dilation=True,
det_db_unclip_ratio=1.8,
):
if lang is not None and lang != '':
# model = ModifiedPaddleOCR(
model = PytorchPaddleOCR(
show_log=show_log,
det_db_box_thresh=det_db_box_thresh,
lang=lang,
use_dilation=use_dilation,
det_db_unclip_ratio=det_db_unclip_ratio,
)
else:
# model = ModifiedPaddleOCR(
model = PytorchPaddleOCR(
show_log=show_log,
det_db_box_thresh=det_db_box_thresh,
use_dilation=use_dilation,
det_db_unclip_ratio=det_db_unclip_ratio,
......@@ -134,13 +76,10 @@ class AtomModelSingleton:
def get_atom_model(self, atom_model_name: str, **kwargs):
lang = kwargs.get('lang', None)
layout_model_name = kwargs.get('layout_model_name', None)
table_model_name = kwargs.get('table_model_name', None)
if atom_model_name in [AtomicModel.OCR]:
key = (atom_model_name, lang)
elif atom_model_name in [AtomicModel.Layout]:
key = (atom_model_name, layout_model_name)
elif atom_model_name in [AtomicModel.Table]:
key = (atom_model_name, table_model_name, lang)
else:
......@@ -153,20 +92,10 @@ class AtomModelSingleton:
def atom_model_init(model_name: str, **kwargs):
atom_model = None
if model_name == AtomicModel.Layout:
if kwargs.get('layout_model_name') == MODEL_NAME.LAYOUTLMv3:
atom_model = layout_model_init(
kwargs.get('layout_weights'),
kwargs.get('layout_config_file'),
kwargs.get('device')
)
elif kwargs.get('layout_model_name') == MODEL_NAME.DocLayout_YOLO:
atom_model = doclayout_yolo_model_init(
kwargs.get('doclayout_yolo_weights'),
kwargs.get('device')
)
else:
logger.error('layout model name not allow')
exit(1)
atom_model = doclayout_yolo_model_init(
kwargs.get('doclayout_yolo_weights'),
kwargs.get('device')
)
elif model_name == AtomicModel.MFD:
atom_model = mfd_model_init(
kwargs.get('mfd_weights'),
......@@ -175,33 +104,17 @@ def atom_model_init(model_name: str, **kwargs):
elif model_name == AtomicModel.MFR:
atom_model = mfr_model_init(
kwargs.get('mfr_weight_dir'),
kwargs.get('mfr_cfg_path'),
kwargs.get('device')
)
elif model_name == AtomicModel.OCR:
atom_model = ocr_model_init(
kwargs.get('ocr_show_log'),
kwargs.get('det_db_box_thresh'),
kwargs.get('lang'),
)
elif model_name == AtomicModel.Table:
atom_model = table_model_init(
kwargs.get('table_model_name'),
kwargs.get('table_model_path'),
kwargs.get('table_max_time'),
kwargs.get('device'),
kwargs.get('lang'),
kwargs.get('table_sub_model_name')
)
elif model_name == AtomicModel.LangDetect:
if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
atom_model = langdetect_model_init(
kwargs.get('langdetect_model_weight'),
kwargs.get('device')
)
else:
logger.error('langdetect model name not allow')
exit(1)
else:
logger.error('model name not allow')
exit(1)
......@@ -211,3 +124,59 @@ def atom_model_init(model_name: str, **kwargs):
exit(1)
else:
return atom_model
class MineruPipelineModel:
def __init__(self, **kwargs):
self.formula_config = kwargs.get('formula_config')
self.apply_formula = self.formula_config.get('enable', True)
self.table_config = kwargs.get('table_config')
self.apply_table = self.table_config.get('enable', True)
self.lang = kwargs.get('lang', None)
self.device = kwargs.get('device', 'cpu')
logger.info(
'DocAnalysis init, this may take some times......'
)
atom_model_manager = AtomModelSingleton()
if self.apply_formula:
# 初始化公式检测模型
self.mfd_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFD,
mfd_weights=str(
os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd), ModelPath.yolo_v8_mfd)
),
device=self.device,
)
# 初始化公式解析模型
mfr_weight_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.unimernet_small), ModelPath.unimernet_small)
self.mfr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFR,
mfr_weight_dir=mfr_weight_dir,
device=self.device,
)
# 初始化layout模型
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
doclayout_yolo_weights=str(
os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
),
device=self.device,
)
# 初始化ocr
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
det_db_box_thresh=0.3,
lang=self.lang
)
# init table model
if self.apply_table:
self.table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Table,
lang=self.lang,
)
logger.info('DocAnalysis init done!')
\ No newline at end of file
# Copyright (c) Opendatalab. All rights reserved.
import time
from loguru import logger
from mineru.utils.config_reader import get_device, get_llm_aided_config
from mineru.backend.pipeline.model_init import AtomModelSingleton
from mineru.backend.pipeline.para_split import para_split
from mineru.utils.block_pre_proc import prepare_block_bboxes, process_groups
from mineru.utils.block_sort import sort_blocks_by_bbox
from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.enum_class import ContentType
from mineru.utils.llm_aided import llm_aided_title
from mineru.utils.model_utils import clean_memory
from mineru.backend.pipeline.pipeline_magic_model import MagicModel
from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans
from mineru.utils.span_pre_proc import remove_outside_spans, remove_overlaps_low_confidence_spans, \
remove_overlaps_min_spans, txt_spans_extract
from mineru.version import __version__
from mineru.utils.hash_utils import str_md5
def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, ocr_enable=False, formula_enabled=True):
scale = image_dict["scale"]
page_pil_img = image_dict["img_pil"]
page_img_md5 = str_md5(image_dict["img_base64"])
page_w, page_h = map(int, page.get_size())
magic_model = MagicModel(page_model_info, scale)
"""从magic_model对象中获取后面会用到的区块信息"""
discarded_blocks = magic_model.get_discarded()
text_blocks = magic_model.get_text_blocks()
title_blocks = magic_model.get_title_blocks()
inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations()
img_groups = magic_model.get_imgs()
table_groups = magic_model.get_tables()
"""对image和table的区块分组"""
img_body_blocks, img_caption_blocks, img_footnote_blocks, maybe_text_image_blocks = process_groups(
img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
)
table_body_blocks, table_caption_blocks, table_footnote_blocks, _ = process_groups(
table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
)
"""获取所有的spans信息"""
spans = magic_model.get_all_spans()
"""某些图可能是文本块,通过简单的规则判断一下"""
if len(maybe_text_image_blocks) > 0:
for block in maybe_text_image_blocks:
span_in_block_list = []
for span in spans:
if span['type'] == 'text' and calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block['bbox']) > 0.7:
span_in_block_list.append(span)
if len(span_in_block_list) > 0:
# span_in_block_list中所有bbox的面积之和
spans_area = sum((span['bbox'][2] - span['bbox'][0]) * (span['bbox'][3] - span['bbox'][1]) for span in span_in_block_list)
# 求ocr_res_area和res的面积的比值
block_area = (block['bbox'][2] - block['bbox'][0]) * (block['bbox'][3] - block['bbox'][1])
if block_area > 0:
ratio = spans_area / block_area
if ratio > 0.25 and ocr_enable:
# 移除block的group_id
block.pop('group_id', None)
# 符合文本图的条件就把块加入到文本块列表中
text_blocks.append(block)
else:
# 如果不符合文本图的条件,就把块加回到图片块列表中
img_body_blocks.append(block)
else:
img_body_blocks.append(block)
"""将所有区块的bbox整理到一起"""
if formula_enabled:
interline_equation_blocks = []
if len(interline_equation_blocks) > 0:
for block in interline_equation_blocks:
spans.append({
"type": ContentType.INTERLINE_EQUATION,
'score': block['score'],
"bbox": block['bbox'],
})
all_bboxes, all_discarded_blocks, footnote_blocks = prepare_block_bboxes(
img_body_blocks, img_caption_blocks, img_footnote_blocks,
table_body_blocks, table_caption_blocks, table_footnote_blocks,
discarded_blocks,
text_blocks,
title_blocks,
interline_equation_blocks,
page_w,
page_h,
)
else:
all_bboxes, all_discarded_blocks, footnote_blocks = prepare_block_bboxes(
img_body_blocks, img_caption_blocks, img_footnote_blocks,
table_body_blocks, table_caption_blocks, table_footnote_blocks,
discarded_blocks,
text_blocks,
title_blocks,
interline_equations,
page_w,
page_h,
)
"""在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span"""
"""顺便删除大水印并保留abandon的span"""
spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks)
"""删除重叠spans中置信度较低的那些"""
spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
"""删除重叠spans中较小的那些"""
spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
"""根据parse_mode,构造spans,主要是文本类的字符填充"""
if ocr_enable:
pass
else:
"""使用新版本的混合ocr方案."""
spans = txt_spans_extract(page, spans, page_pil_img, scale, all_bboxes, all_discarded_blocks)
"""先处理不需要排版的discarded_blocks"""
discarded_block_with_spans, spans = fill_spans_in_blocks(
all_discarded_blocks, spans, 0.4
)
fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
"""如果当前页面没有有效的bbox则跳过"""
if len(all_bboxes) == 0:
return None
"""对image/table/interline_equation截图"""
for span in spans:
if span['type'] in [ContentType.IMAGE, ContentType.TABLE, ContentType.INTERLINE_EQUATION]:
span = cut_image_and_table(
span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale
)
"""span填充进block"""
block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
"""对block进行fix操作"""
fix_blocks = fix_block_spans(block_with_spans)
"""同一行被断开的titile合并"""
# merge_title_blocks(fix_blocks)
"""对block进行排序"""
sorted_blocks = sort_blocks_by_bbox(fix_blocks, page_w, page_h, footnote_blocks)
"""构造page_info"""
page_info = make_page_info_dict(sorted_blocks, page_index, page_w, page_h, fix_discarded_blocks)
return page_info
def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=None, ocr_enable=False, formula_enabled=True):
middle_json = {"pdf_info": [], "_backend":"pipeline", "_version_name": __version__}
for page_index, page_model_info in enumerate(model_list):
page = pdf_doc[page_index]
image_dict = images_list[page_index]
page_info = page_model_info_to_page_info(
page_model_info, image_dict, page, image_writer, page_index, ocr_enable=ocr_enable, formula_enabled=formula_enabled
)
if page_info is None:
page_w, page_h = map(int, page.get_size())
page_info = make_page_info_dict([], page_index, page_w, page_h, [])
middle_json["pdf_info"].append(page_info)
"""后置ocr处理"""
need_ocr_list = []
img_crop_list = []
text_block_list = []
for page_info in middle_json["pdf_info"]:
for block in page_info['preproc_blocks']:
if block['type'] in ['table', 'image']:
for sub_block in block['blocks']:
if sub_block['type'] in ['image_caption', 'image_footnote', 'table_caption', 'table_footnote']:
text_block_list.append(sub_block)
elif block['type'] in ['text', 'title']:
text_block_list.append(block)
for block in page_info['discarded_blocks']:
text_block_list.append(block)
for block in text_block_list:
for line in block['lines']:
for span in line['spans']:
if 'np_img' in span:
need_ocr_list.append(span)
img_crop_list.append(span['np_img'])
span.pop('np_img')
if len(img_crop_list) > 0:
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang
)
ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
assert len(ocr_res_list) == len(
need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
for index, span in enumerate(need_ocr_list):
ocr_text, ocr_score = ocr_res_list[index]
if ocr_score > 0.6:
span['content'] = ocr_text
span['score'] = float(f"{ocr_score:.3f}")
else:
span['content'] = ''
span['score'] = 0.0
"""分段"""
para_split(middle_json["pdf_info"])
"""llm优化"""
llm_aided_config = get_llm_aided_config()
if llm_aided_config is not None:
"""标题优化"""
title_aided_config = llm_aided_config.get('title_aided', None)
if title_aided_config is not None:
if title_aided_config.get('enable', False):
llm_aided_title_start_time = time.time()
llm_aided_title(middle_json["pdf_info"], title_aided_config)
logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
"""清理内存"""
pdf_doc.close()
clean_memory(get_device())
return middle_json
def make_page_info_dict(blocks, page_id, page_w, page_h, discarded_blocks):
return_dict = {
'preproc_blocks': blocks,
'page_idx': page_id,
'page_size': [page_w, page_h],
'discarded_blocks': discarded_blocks,
}
return return_dict
\ No newline at end of file
class MODEL:
Paddle = "pp_structure_v2"
PEK = "pdf_extract_kit"
class AtomicModel:
Layout = "layout"
MFD = "mfd"
MFR = "mfr"
OCR = "ocr"
Table = "table"
LangDetect = "langdetect"
import copy
from loguru import logger
from mineru.utils.enum_class import ContentType, BlockType, SplitFlag
from mineru.utils.language import detect_lang
from magic_pdf.config.constants import CROSS_PAGE, LINES_DELETED
from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.libs.language import detect_lang
LINE_STOP_FLAG = (
'.',
'!',
'?',
'。',
'!',
'?',
')',
')',
'"',
'”',
':',
':',
';',
';',
)
LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';')
LIST_END_FLAG = ('.', '。', ';', ';')
......@@ -115,7 +98,7 @@ def __is_list_or_index_block(block):
for span in line['spans']:
span_type = span['type']
if span_type == ContentType.Text:
if span_type == ContentType.TEXT:
line_text += span['content'].strip()
# 添加所有文本,包括空行,保持与block['lines']长度一致
lines_text_list.append(line_text)
......@@ -191,7 +174,7 @@ def __is_list_or_index_block(block):
) and line_num_flag:
for line in block['lines']:
line[ListLineTag.IS_LIST_START_LINE] = True
return BlockType.Index
return BlockType.INDEX
# 全部line都居中的特殊list识别,每行都需要换行,特征是多行,且大多数行都前后not_close,每line中点x坐标接近
# 补充条件block的长宽比有要求
......@@ -203,7 +186,7 @@ def __is_list_or_index_block(block):
):
for line in block['lines']:
line[ListLineTag.IS_LIST_START_LINE] = True
return BlockType.List
return BlockType.LIST
elif (
left_close_num >= 2
......@@ -260,11 +243,11 @@ def __is_list_or_index_block(block):
if abs(block['bbox_fs'][2] - line['bbox'][2]) > line_height:
line[ListLineTag.IS_LIST_END_LINE] = True
return BlockType.List
return BlockType.LIST
else:
return BlockType.Text
return BlockType.TEXT
else:
return BlockType.Text
return BlockType.TEXT
def __merge_2_text_blocks(block1, block2):
......@@ -299,10 +282,10 @@ def __merge_2_text_blocks(block1, block2):
if block1['page_num'] != block2['page_num']:
for line in block1['lines']:
for span in line['spans']:
span[CROSS_PAGE] = True
span[SplitFlag.CROSS_PAGE] = True
block2['lines'].extend(block1['lines'])
block1['lines'] = []
block1[LINES_DELETED] = True
block1[SplitFlag.LINES_DELETED] = True
return block1, block2
......@@ -311,10 +294,10 @@ def __merge_2_list_blocks(block1, block2):
if block1['page_num'] != block2['page_num']:
for line in block1['lines']:
for span in line['spans']:
span[CROSS_PAGE] = True
span[SplitFlag.CROSS_PAGE] = True
block2['lines'].extend(block1['lines'])
block1['lines'] = []
block1[LINES_DELETED] = True
block1[SplitFlag.LINES_DELETED] = True
return block1, block2
......@@ -357,11 +340,11 @@ def __para_merge_page(blocks):
):
__merge_2_text_blocks(current_block, prev_block)
elif (
current_block['type'] == BlockType.List
and prev_block['type'] == BlockType.List
current_block['type'] == BlockType.LIST
and prev_block['type'] == BlockType.LIST
) or (
current_block['type'] == BlockType.Index
and prev_block['type'] == BlockType.Index
current_block['type'] == BlockType.INDEX
and prev_block['type'] == BlockType.INDEX
):
__merge_2_list_blocks(current_block, prev_block)
......@@ -369,21 +352,21 @@ def __para_merge_page(blocks):
continue
def para_split(pdf_info_dict):
def para_split(page_info_list):
all_blocks = []
for page_num, page in pdf_info_dict.items():
blocks = copy.deepcopy(page['preproc_blocks'])
for page_info in page_info_list:
blocks = copy.deepcopy(page_info['preproc_blocks'])
for block in blocks:
block['page_num'] = page_num
block['page_size'] = page['page_size']
block['page_num'] = page_info['page_idx']
block['page_size'] = page_info['page_size']
all_blocks.extend(blocks)
__para_merge_page(all_blocks)
for page_num, page in pdf_info_dict.items():
page['para_blocks'] = []
for page_info in page_info_list:
page_info['para_blocks'] = []
for block in all_blocks:
if block['page_num'] == page_num:
page['para_blocks'].append(block)
if block['page_num'] == page_info['page_idx']:
page_info['para_blocks'].append(block)
if __name__ == '__main__':
......
import os
import time
import numpy as np
import torch
from .model_init import MineruPipelineModel
from mineru.utils.config_reader import get_device, get_formula_config, get_table_recog_config
from ...utils.pdf_classify import classify
from ...utils.pdf_image_tools import load_images_from_pdf
from loguru import logger
from ...utils.model_utils import get_vram, clean_memory
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 让mps可以fallback
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
class ModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(
self,
lang=None,
formula_enable=None,
table_enable=None,
):
key = (lang, formula_enable, table_enable)
if key not in self._models:
self._models[key] = custom_model_init(
lang=lang,
formula_enable=formula_enable,
table_enable=table_enable,
)
return self._models[key]
def custom_model_init(
lang=None,
formula_enable=None,
table_enable=None,
):
model_init_start = time.time()
# 从配置文件读取model-dir和device
device = get_device()
formula_config = get_formula_config()
if formula_enable is not None:
formula_config['enable'] = formula_enable
table_config = get_table_recog_config()
if table_enable is not None:
table_config['enable'] = table_enable
model_input = {
'device': device,
'table_config': table_config,
'formula_config': formula_config,
'lang': lang,
}
custom_model = MineruPipelineModel(**model_input)
model_init_cost = time.time() - model_init_start
logger.info(f'model init cost: {model_init_cost}')
return custom_model
def doc_analyze(
pdf_bytes_list,
lang_list,
parse_method: str = 'auto',
formula_enable=None,
table_enable=None,
):
MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
# 收集所有页面信息
all_pages_info = [] # 存储(dataset_index, page_index, img, ocr, lang, width, height)
all_image_lists = []
all_pdf_docs = []
ocr_enabled_list = []
for pdf_idx, pdf_bytes in enumerate(pdf_bytes_list):
# 确定OCR设置
_ocr_enable = False
if parse_method == 'auto':
if classify(pdf_bytes) == 'ocr':
_ocr_enable = True
elif parse_method == 'ocr':
_ocr_enable = True
ocr_enabled_list.append(_ocr_enable)
_lang = lang_list[pdf_idx]
# 收集每个数据集中的页面
images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
all_image_lists.append(images_list)
all_pdf_docs.append(pdf_doc)
for page_idx in range(len(images_list)):
img_dict = images_list[page_idx]
all_pages_info.append((
pdf_idx, page_idx,
img_dict['img_pil'], _ocr_enable, _lang,
))
# 准备批处理
images_with_extra_info = [(info[2], info[3], info[4]) for info in all_pages_info]
batch_size = MIN_BATCH_INFERENCE_SIZE
batch_images = [
images_with_extra_info[i:i + batch_size]
for i in range(0, len(images_with_extra_info), batch_size)
]
# 执行批处理
results = []
processed_images_count = 0
for index, batch_image in enumerate(batch_images):
processed_images_count += len(batch_image)
logger.info(
f'Batch {index + 1}/{len(batch_images)}: '
f'{processed_images_count} pages/{len(images_with_extra_info)} pages'
)
batch_results = batch_image_analyze(batch_image, formula_enable, table_enable)
results.extend(batch_results)
# 构建返回结果
infer_results = []
for _ in range(len(pdf_bytes_list)):
infer_results.append([])
for i, page_info in enumerate(all_pages_info):
pdf_idx, page_idx, pil_img, _, _ = page_info
result = results[i]
page_info_dict = {'page_no': page_idx, 'width': pil_img.width, 'height': pil_img.height}
page_dict = {'layout_dets': result, 'page_info': page_info_dict}
infer_results[pdf_idx].append(page_dict)
return infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list
def batch_image_analyze(
images_with_extra_info: list[(np.ndarray, bool, str)],
formula_enable=None,
table_enable=None):
# os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
from .batch_analyze import BatchAnalyze
model_manager = ModelSingleton()
batch_ratio = 1
device = get_device()
if str(device).startswith('npu'):
import torch_npu
if torch_npu.npu.is_available():
torch.npu.set_compile_mode(jit_compile=False)
if str(device).startswith('npu') or str(device).startswith('cuda'):
vram = get_vram(device)
if vram is not None:
gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
if gpu_memory >= 16:
batch_ratio = 16
elif gpu_memory >= 12:
batch_ratio = 8
elif gpu_memory >= 8:
batch_ratio = 4
elif gpu_memory >= 6:
batch_ratio = 2
else:
batch_ratio = 1
logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
else:
# Default batch_ratio when VRAM can't be determined
batch_ratio = 1
logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable)
results = batch_model(images_with_extra_info)
clean_memory(get_device())
return results
\ No newline at end of file
import enum
from magic_pdf.config.model_block_type import ModelBlockTypeEnum
from magic_pdf.config.ocr_content_type import CategoryId, ContentType
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.boxbase import (_is_in, bbox_distance, bbox_relative_pos,
calculate_iou)
from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
CAPATION_OVERLAP_AREA_RATIO = 0.6
MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
class PosRelationEnum(enum.Enum):
LEFT = 'left'
RIGHT = 'right'
UP = 'up'
BOTTOM = 'bottom'
ALL = 'all'
from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, is_in
from mineru.utils.enum_class import CategoryId, ContentType
class MagicModel:
"""每个函数没有得到元素的时候返回空list."""
def __init__(self, page_model_info: dict, scale: float):
self.__page_model_info = page_model_info
self.__scale = scale
"""为所有模型数据添加bbox信息(缩放,poly->bbox)"""
self.__fix_axis()
"""删除置信度特别低的模型数据(<0.05),提高质量"""
self.__fix_by_remove_low_confidence()
"""删除高iou(>0.9)数据中置信度较低的那个"""
self.__fix_by_remove_high_iou_and_low_confidence()
self.__fix_footnote()
def __fix_axis(self):
for model_page_info in self.__model_list:
need_remove_list = []
page_no = model_page_info['page_info']['page_no']
horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
model_page_info, self.__docs.get_page(page_no)
)
layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets:
if layout_det.get('bbox') is not None:
# 兼容直接输出bbox的模型数据,如paddle
x0, y0, x1, y1 = layout_det['bbox']
else:
# 兼容直接输出poly的模型数据,如xxx
x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
bbox = [
int(x0 / horizontal_scale_ratio),
int(y0 / vertical_scale_ratio),
int(x1 / horizontal_scale_ratio),
int(y1 / vertical_scale_ratio),
]
layout_det['bbox'] = bbox
# 删除高度或者宽度小于等于0的spans
if bbox[2] - bbox[0] <= 0 or bbox[3] - bbox[1] <= 0:
need_remove_list.append(layout_det)
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
need_remove_list = []
layout_dets = self.__page_model_info['layout_dets']
for layout_det in layout_dets:
x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
bbox = [
int(x0 / self.__scale),
int(y0 / self.__scale),
int(x1 / self.__scale),
int(y1 / self.__scale),
]
layout_det['bbox'] = bbox
# 删除高度或者宽度小于等于0的spans
if bbox[2] - bbox[0] <= 0 or bbox[3] - bbox[1] <= 0:
need_remove_list.append(layout_det)
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
def __fix_by_remove_low_confidence(self):
for model_page_info in self.__model_list:
need_remove_list = []
layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets:
if layout_det['score'] <= 0.05:
need_remove_list.append(layout_det)
else:
continue
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
need_remove_list = []
layout_dets = self.__page_model_info['layout_dets']
for layout_det in layout_dets:
if layout_det['score'] <= 0.05:
need_remove_list.append(layout_det)
else:
continue
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
def __fix_by_remove_high_iou_and_low_confidence(self):
for model_page_info in self.__model_list:
need_remove_list = []
layout_dets = model_page_info['layout_dets']
for layout_det1 in layout_dets:
for layout_det2 in layout_dets:
if layout_det1 == layout_det2:
continue
if layout_det1['category_id'] in [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
] and layout_det2['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
if (
calculate_iou(layout_det1['bbox'], layout_det2['bbox'])
> 0.9
):
if layout_det1['score'] < layout_det2['score']:
layout_det_need_remove = layout_det1
else:
layout_det_need_remove = layout_det2
if layout_det_need_remove not in need_remove_list:
need_remove_list.append(layout_det_need_remove)
need_remove_list = []
layout_dets = self.__page_model_info['layout_dets']
for layout_det1 in layout_dets:
for layout_det2 in layout_dets:
if layout_det1 == layout_det2:
continue
if layout_det1['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] and layout_det2['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
if (
calculate_iou(layout_det1['bbox'], layout_det2['bbox'])
> 0.9
):
if layout_det1['score'] < layout_det2['score']:
layout_det_need_remove = layout_det1
else:
continue
layout_det_need_remove = layout_det2
if layout_det_need_remove not in need_remove_list:
need_remove_list.append(layout_det_need_remove)
else:
continue
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
else:
continue
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
def __init__(self, model_list: list, docs: Dataset):
self.__model_list = model_list
self.__docs = docs
"""为所有模型数据添加bbox信息(缩放,poly->bbox)"""
self.__fix_axis()
"""删除置信度特别低的模型数据(<0.05),提高质量"""
self.__fix_by_remove_low_confidence()
"""删除高iou(>0.9)数据中置信度较低的那个"""
self.__fix_by_remove_high_iou_and_low_confidence()
self.__fix_footnote()
def __fix_footnote(self):
# 3: figure, 5: table, 7: footnote
footnotes = []
figures = []
tables = []
for obj in self.__page_model_info['layout_dets']:
if obj['category_id'] == 7:
footnotes.append(obj)
elif obj['category_id'] == 3:
figures.append(obj)
elif obj['category_id'] == 5:
tables.append(obj)
if len(footnotes) * len(figures) == 0:
continue
dis_figure_footnote = {}
dis_table_footnote = {}
for i in range(len(footnotes)):
for j in range(len(figures)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], figures[j]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
dis_figure_footnote[i] = min(
self._bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
dis_figure_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
for j in range(len(tables)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], tables[j]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
dis_table_footnote[i] = min(
self._bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
dis_table_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
if i not in dis_figure_footnote:
continue
if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
footnotes[i]['category_id'] = CategoryId.ImageFootnote
def _bbox_distance(self, bbox1, bbox2):
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
......@@ -132,68 +149,6 @@ class MagicModel:
return bbox_distance(bbox1, bbox2)
def __fix_footnote(self):
# 3: figure, 5: table, 7: footnote
for model_page_info in self.__model_list:
footnotes = []
figures = []
tables = []
for obj in model_page_info['layout_dets']:
if obj['category_id'] == 7:
footnotes.append(obj)
elif obj['category_id'] == 3:
figures.append(obj)
elif obj['category_id'] == 5:
tables.append(obj)
if len(footnotes) * len(figures) == 0:
continue
dis_figure_footnote = {}
dis_table_footnote = {}
for i in range(len(footnotes)):
for j in range(len(figures)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], figures[j]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
dis_figure_footnote[i] = min(
self._bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
dis_figure_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
for j in range(len(tables)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], tables[j]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
dis_table_footnote[i] = min(
self._bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
dis_table_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
if i not in dis_figure_footnote:
continue
if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
footnotes[i]['category_id'] = CategoryId.ImageFootnote
def __reduct_overlap(self, bboxes):
N = len(bboxes)
keep = [True] * N
......@@ -201,262 +156,14 @@ class MagicModel:
for j in range(N):
if i == j:
continue
if _is_in(bboxes[i]['bbox'], bboxes[j]['bbox']):
if is_in(bboxes[i]['bbox'], bboxes[j]['bbox']):
keep[i] = False
return [bboxes[i] for i in range(N) if keep[i]]
def __tie_up_category_by_distance_v2(
self,
page_no: int,
subject_category_id: int,
object_category_id: int,
priority_pos: PosRelationEnum,
):
"""_summary_
Args:
page_no (int): _description_
subject_category_id (int): _description_
object_category_id (int): _description_
priority_pos (PosRelationEnum): _description_
Returns:
_type_: _description_
"""
AXIS_MULPLICITY = 0.5
subjects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == subject_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
objects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == object_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
M = len(objects)
subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
sub_obj_map_h = {i: [] for i in range(len(subjects))}
dis_by_directions = {
'top': [[-1, float('inf')]] * M,
'bottom': [[-1, float('inf')]] * M,
'left': [[-1, float('inf')]] * M,
'right': [[-1, float('inf')]] * M,
}
for i, obj in enumerate(objects):
l_x_axis, l_y_axis = (
obj['bbox'][2] - obj['bbox'][0],
obj['bbox'][3] - obj['bbox'][1],
)
axis_unit = min(l_x_axis, l_y_axis)
for j, sub in enumerate(subjects):
bbox1, bbox2, _ = _remove_overlap_between_bbox(
objects[i]['bbox'], subjects[j]['bbox']
)
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
flags = [left, right, bottom, top]
if sum([1 if v else 0 for v in flags]) > 1:
continue
if left:
if dis_by_directions['left'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['left'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if right:
if dis_by_directions['right'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['right'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if bottom:
if dis_by_directions['bottom'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['bottom'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if top:
if dis_by_directions['top'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['top'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if (
dis_by_directions['top'][i][1] != float('inf')
and dis_by_directions['bottom'][i][1] != float('inf')
and priority_pos in (PosRelationEnum.BOTTOM, PosRelationEnum.UP)
):
RATIO = 3
if (
abs(
dis_by_directions['top'][i][1]
- dis_by_directions['bottom'][i][1]
)
< RATIO * axis_unit
):
if priority_pos == PosRelationEnum.BOTTOM:
sub_obj_map_h[dis_by_directions['bottom'][i][0]].append(i)
else:
sub_obj_map_h[dis_by_directions['top'][i][0]].append(i)
continue
if dis_by_directions['left'][i][1] != float('inf') or dis_by_directions[
'right'
][i][1] != float('inf'):
if dis_by_directions['left'][i][1] != float(
'inf'
) and dis_by_directions['right'][i][1] != float('inf'):
if AXIS_MULPLICITY * axis_unit >= abs(
dis_by_directions['left'][i][1]
- dis_by_directions['right'][i][1]
):
left_sub_bbox = subjects[dis_by_directions['left'][i][0]][
'bbox'
]
right_sub_bbox = subjects[dis_by_directions['right'][i][0]][
'bbox'
]
left_sub_bbox_y_axis = left_sub_bbox[3] - left_sub_bbox[1]
right_sub_bbox_y_axis = right_sub_bbox[3] - right_sub_bbox[1]
if (
abs(left_sub_bbox_y_axis - l_y_axis)
+ dis_by_directions['left'][i][0]
> abs(right_sub_bbox_y_axis - l_y_axis)
+ dis_by_directions['right'][i][0]
):
left_or_right = dis_by_directions['right'][i]
else:
left_or_right = dis_by_directions['left'][i]
else:
left_or_right = dis_by_directions['left'][i]
if left_or_right[1] > dis_by_directions['right'][i][1]:
left_or_right = dis_by_directions['right'][i]
else:
left_or_right = dis_by_directions['left'][i]
if left_or_right[1] == float('inf'):
left_or_right = dis_by_directions['right'][i]
else:
left_or_right = [-1, float('inf')]
if dis_by_directions['top'][i][1] != float('inf') or dis_by_directions[
'bottom'
][i][1] != float('inf'):
if dis_by_directions['top'][i][1] != float('inf') and dis_by_directions[
'bottom'
][i][1] != float('inf'):
if AXIS_MULPLICITY * axis_unit >= abs(
dis_by_directions['top'][i][1]
- dis_by_directions['bottom'][i][1]
):
top_bottom = subjects[dis_by_directions['bottom'][i][0]]['bbox']
bottom_top = subjects[dis_by_directions['top'][i][0]]['bbox']
top_bottom_x_axis = top_bottom[2] - top_bottom[0]
bottom_top_x_axis = bottom_top[2] - bottom_top[0]
if (
abs(top_bottom_x_axis - l_x_axis)
+ dis_by_directions['bottom'][i][1]
> abs(bottom_top_x_axis - l_x_axis)
+ dis_by_directions['top'][i][1]
):
top_or_bottom = dis_by_directions['top'][i]
else:
top_or_bottom = dis_by_directions['bottom'][i]
else:
top_or_bottom = dis_by_directions['top'][i]
if top_or_bottom[1] > dis_by_directions['bottom'][i][1]:
top_or_bottom = dis_by_directions['bottom'][i]
else:
top_or_bottom = dis_by_directions['top'][i]
if top_or_bottom[1] == float('inf'):
top_or_bottom = dis_by_directions['bottom'][i]
else:
top_or_bottom = [-1, float('inf')]
if left_or_right[1] != float('inf') or top_or_bottom[1] != float('inf'):
if left_or_right[1] != float('inf') and top_or_bottom[1] != float(
'inf'
):
if AXIS_MULPLICITY * axis_unit >= abs(
left_or_right[1] - top_or_bottom[1]
):
y_axis_bbox = subjects[left_or_right[0]]['bbox']
x_axis_bbox = subjects[top_or_bottom[0]]['bbox']
if (
abs((x_axis_bbox[2] - x_axis_bbox[0]) - l_x_axis) / l_x_axis
> abs((y_axis_bbox[3] - y_axis_bbox[1]) - l_y_axis)
/ l_y_axis
):
sub_obj_map_h[left_or_right[0]].append(i)
else:
sub_obj_map_h[top_or_bottom[0]].append(i)
else:
if left_or_right[1] > top_or_bottom[1]:
sub_obj_map_h[top_or_bottom[0]].append(i)
else:
sub_obj_map_h[left_or_right[0]].append(i)
else:
if left_or_right[1] != float('inf'):
sub_obj_map_h[left_or_right[0]].append(i)
else:
sub_obj_map_h[top_or_bottom[0]].append(i)
ret = []
for i in sub_obj_map_h.keys():
ret.append(
{
'sub_bbox': {
'bbox': subjects[i]['bbox'],
'score': subjects[i]['score'],
},
'obj_bboxes': [
{'score': objects[j]['score'], 'bbox': objects[j]['bbox']}
for j in sub_obj_map_h[i]
],
'sub_idx': i,
}
)
return ret
def __tie_up_category_by_distance_v3(
self,
page_no: int,
subject_category_id: int,
object_category_id: int,
priority_pos: PosRelationEnum,
):
subjects = self.__reduct_overlap(
list(
......@@ -464,7 +171,7 @@ class MagicModel:
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == subject_category_id,
self.__model_list[page_no]['layout_dets'],
self.__page_model_info['layout_dets'],
),
)
)
......@@ -475,7 +182,7 @@ class MagicModel:
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == object_category_id,
self.__model_list[page_no]['layout_dets'],
self.__page_model_info['layout_dets'],
),
)
)
......@@ -605,13 +312,12 @@ class MagicModel:
return ret
def get_imgs_v2(self, page_no: int):
def get_imgs(self):
with_captions = self.__tie_up_category_by_distance_v3(
page_no, 3, 4, PosRelationEnum.BOTTOM
3, 4
)
with_footnotes = self.__tie_up_category_by_distance_v3(
page_no, 3, CategoryId.ImageFootnote, PosRelationEnum.ALL
3, CategoryId.ImageFootnote
)
ret = []
for v in with_captions:
......@@ -625,12 +331,12 @@ class MagicModel:
ret.append(record)
return ret
def get_tables_v2(self, page_no: int) -> list:
def get_tables(self) -> list:
with_captions = self.__tie_up_category_by_distance_v3(
page_no, 5, 6, PosRelationEnum.UP
5, 6
)
with_footnotes = self.__tie_up_category_by_distance_v3(
page_no, 5, 7, PosRelationEnum.ALL
5, 7
)
ret = []
for v in with_captions:
......@@ -644,52 +350,31 @@ class MagicModel:
ret.append(record)
return ret
def get_imgs(self, page_no: int):
return self.get_imgs_v2(page_no)
def get_tables(
self, page_no: int
) -> list: # 3个坐标, caption, table主体,table-note
return self.get_tables_v2(page_no)
def get_equations(self, page_no: int) -> list: # 有坐标,也有字
def get_equations(self) -> tuple[list, list, list]: # 有坐标,也有字
inline_equations = self.__get_blocks_by_type(
ModelBlockTypeEnum.EMBEDDING.value, page_no, ['latex']
CategoryId.InlineEquation, ['latex']
)
interline_equations = self.__get_blocks_by_type(
ModelBlockTypeEnum.ISOLATED.value, page_no, ['latex']
CategoryId.InterlineEquation_YOLO, ['latex']
)
interline_equations_blocks = self.__get_blocks_by_type(
ModelBlockTypeEnum.ISOLATE_FORMULA.value, page_no
CategoryId.InterlineEquation_Layout
)
return inline_equations, interline_equations, interline_equations_blocks
def get_discarded(self, page_no: int) -> list: # 自研模型,只有坐标
blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.ABANDON.value, page_no)
def get_discarded(self) -> list: # 自研模型,只有坐标
blocks = self.__get_blocks_by_type(CategoryId.Abandon)
return blocks
def get_text_blocks(self, page_no: int) -> list: # 自研模型搞的,只有坐标,没有字
blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.PLAIN_TEXT.value, page_no)
def get_text_blocks(self) -> list: # 自研模型搞的,只有坐标,没有字
blocks = self.__get_blocks_by_type(CategoryId.Text)
return blocks
def get_title_blocks(self, page_no: int) -> list: # 自研模型,只有坐标,没字
blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.TITLE.value, page_no)
def get_title_blocks(self) -> list: # 自研模型,只有坐标,没字
blocks = self.__get_blocks_by_type(CategoryId.Title)
return blocks
def get_ocr_text(self, page_no: int) -> list: # paddle 搞的,有字也有坐标
text_spans = []
model_page_info = self.__model_list[page_no]
layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets:
if layout_det['category_id'] == '15':
span = {
'bbox': layout_det['bbox'],
'content': layout_det['text'],
}
text_spans.append(span)
return text_spans
def get_all_spans(self, page_no: int) -> list:
def get_all_spans(self) -> list:
def remove_duplicate_spans(spans):
new_spans = []
......@@ -699,8 +384,7 @@ class MagicModel:
return new_spans
all_spans = []
model_page_info = self.__model_list[page_no]
layout_dets = model_page_info['layout_dets']
layout_dets = self.__page_model_info['layout_dets']
allow_category_id_list = [3, 5, 13, 14, 15]
"""当成span拼接的"""
# 3: 'image', # 图片
......@@ -713,7 +397,7 @@ class MagicModel:
if category_id in allow_category_id_list:
span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
if category_id == 3:
span['type'] = ContentType.Image
span['type'] = ContentType.IMAGE
elif category_id == 5:
# 获取table模型结果
latex = layout_det.get('latex', None)
......@@ -722,50 +406,36 @@ class MagicModel:
span['latex'] = latex
elif html:
span['html'] = html
span['type'] = ContentType.Table
span['type'] = ContentType.TABLE
elif category_id == 13:
span['content'] = layout_det['latex']
span['type'] = ContentType.InlineEquation
span['type'] = ContentType.INLINE_EQUATION
elif category_id == 14:
span['content'] = layout_det['latex']
span['type'] = ContentType.InterlineEquation
span['type'] = ContentType.INTERLINE_EQUATION
elif category_id == 15:
span['content'] = layout_det['text']
span['type'] = ContentType.Text
span['type'] = ContentType.TEXT
all_spans.append(span)
return remove_duplicate_spans(all_spans)
def get_page_size(self, page_no: int): # 获取页面宽高
# 获取当前页的page对象
page = self.__docs.get_page(page_no).get_page_info()
# 获取当前页的宽高
page_w = page.w
page_h = page.h
return page_w, page_h
def __get_blocks_by_type(
self, type: int, page_no: int, extra_col: list[str] = []
self, category_type: int, extra_col=None
) -> list:
if extra_col is None:
extra_col = []
blocks = []
for page_dict in self.__model_list:
layout_dets = page_dict.get('layout_dets', [])
page_info = page_dict.get('page_info', {})
page_number = page_info.get('page_no', -1)
if page_no != page_number:
continue
for item in layout_dets:
category_id = item.get('category_id', -1)
bbox = item.get('bbox', None)
if category_id == type:
block = {
'bbox': bbox,
'score': item.get('score'),
}
for col in extra_col:
block[col] = item.get(col, None)
blocks.append(block)
layout_dets = self.__page_model_info.get('layout_dets', [])
for item in layout_dets:
category_id = item.get('category_id', -1)
bbox = item.get('bbox', None)
if category_id == category_type:
block = {
'bbox': bbox,
'score': item.get('score'),
}
for col in extra_col:
block[col] = item.get(col, None)
blocks.append(block)
return blocks
def get_model_list(self, page_no):
return self.__model_list[page_no]
import re
from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.libs.commons import join_path
from magic_pdf.libs.config_reader import get_latex_delimiter_config
from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
from magic_pdf.post_proc.para_split_v3 import ListLineTag
from mineru.utils.config_reader import get_latex_delimiter_config
from mineru.backend.pipeline.para_split import ListLineTag
from mineru.utils.enum_class import BlockType, ContentType, MakeMode
from mineru.utils.language import detect_lang
def __is_hyphen_at_line_end(line):
......@@ -24,34 +20,7 @@ def __is_hyphen_at_line_end(line):
return bool(re.search(r'[A-Za-z]+-\s*$', line))
def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list,
img_buket_path):
markdown_with_para_and_pagination = []
page_no = 0
for page_info in pdf_info_dict:
paras_of_layout = page_info.get('para_blocks')
if not paras_of_layout:
markdown_with_para_and_pagination.append({
'page_no':
page_no,
'md_content':
'',
})
page_no += 1
continue
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'mm', img_buket_path)
markdown_with_para_and_pagination.append({
'page_no':
page_no,
'md_content':
'\n\n'.join(page_markdown)
})
page_no += 1
return markdown_with_para_and_pagination
def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
def make_blocks_to_markdown(paras_of_layout,
mode,
img_buket_path='',
):
......@@ -59,64 +28,67 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
for para_block in paras_of_layout:
para_text = ''
para_type = para_block['type']
if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
para_text = merge_para_with_text(para_block)
elif para_type == BlockType.Title:
elif para_type == BlockType.TITLE:
title_level = get_title_level(para_block)
para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
elif para_type == BlockType.InterlineEquation:
para_text = merge_para_with_text(para_block)
elif para_type == BlockType.Image:
if mode == 'nlp':
elif para_type == BlockType.INTERLINE_EQUATION:
if para_block['lines'][0]['spans'][0].get('content', ''):
para_text = merge_para_with_text(para_block)
else:
para_text += f"![]({img_buket_path}/{para_block['lines'][0]['spans'][0]['image_path']})"
elif para_type == BlockType.IMAGE:
if mode == MakeMode.NLP_MD:
continue
elif mode == 'mm':
elif mode == MakeMode.MM_MD:
# 检测是否存在图片脚注
has_image_footnote = any(block['type'] == BlockType.ImageFootnote for block in para_block['blocks'])
has_image_footnote = any(block['type'] == BlockType.IMAGE_FOOTNOTE for block in para_block['blocks'])
# 如果存在图片脚注,则将图片脚注拼接到图片正文后面
if has_image_footnote:
for block in para_block['blocks']: # 1st.拼image_caption
if block['type'] == BlockType.ImageCaption:
if block['type'] == BlockType.IMAGE_CAPTION:
para_text += merge_para_with_text(block) + ' \n'
for block in para_block['blocks']: # 2nd.拼image_body
if block['type'] == BlockType.ImageBody:
if block['type'] == BlockType.IMAGE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.Image:
if span['type'] == ContentType.IMAGE:
if span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
for block in para_block['blocks']: # 3rd.拼image_footnote
if block['type'] == BlockType.ImageFootnote:
if block['type'] == BlockType.IMAGE_FOOTNOTE:
para_text += ' \n' + merge_para_with_text(block)
else:
for block in para_block['blocks']: # 1st.拼image_body
if block['type'] == BlockType.ImageBody:
if block['type'] == BlockType.IMAGE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.Image:
if span['type'] == ContentType.IMAGE:
if span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
for block in para_block['blocks']: # 2nd.拼image_caption
if block['type'] == BlockType.ImageCaption:
if block['type'] == BlockType.IMAGE_CAPTION:
para_text += ' \n' + merge_para_with_text(block)
elif para_type == BlockType.Table:
if mode == 'nlp':
elif para_type == BlockType.TABLE:
if mode == MakeMode.NLP_MD:
continue
elif mode == 'mm':
elif mode == MakeMode.MM_MD:
for block in para_block['blocks']: # 1st.拼table_caption
if block['type'] == BlockType.TableCaption:
if block['type'] == BlockType.TABLE_CAPTION:
para_text += merge_para_with_text(block) + ' \n'
for block in para_block['blocks']: # 2nd.拼table_body
if block['type'] == BlockType.TableBody:
if block['type'] == BlockType.TABLE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.Table:
if span['type'] == ContentType.TABLE:
# if processed by table model
if span.get('html', ''):
para_text += f"\n{span['html']}\n"
elif span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
for block in para_block['blocks']: # 3rd.拼table_footnote
if block['type'] == BlockType.TableFootnote:
if block['type'] == BlockType.TABLE_FOOTNOTE:
para_text += '\n' + merge_para_with_text(block) + ' '
if para_text.strip() == '':
......@@ -128,19 +100,6 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
return page_markdown
def detect_language(text):
en_pattern = r'[a-zA-Z]+'
en_matches = re.findall(en_pattern, text)
en_length = sum(len(match) for match in en_matches)
if len(text) > 0:
if en_length / len(text) >= 0.5:
return 'en'
else:
return 'unknown'
else:
return 'empty'
def full_to_half(text: str) -> str:
"""Convert full-width characters to half-width characters using code point manipulation.
......@@ -178,7 +137,7 @@ def merge_para_with_text(para_block):
block_text = ''
for line in para_block['lines']:
for span in line['spans']:
if span['type'] in [ContentType.Text]:
if span['type'] in [ContentType.TEXT]:
span['content'] = full_to_half(span['content'])
block_text += span['content']
block_lang = detect_lang(block_text)
......@@ -193,11 +152,11 @@ def merge_para_with_text(para_block):
span_type = span['type']
content = ''
if span_type == ContentType.Text:
content = ocr_escape_special_markdown_char(span['content'])
elif span_type == ContentType.InlineEquation:
if span_type == ContentType.TEXT:
content = escape_special_markdown_char(span['content'])
elif span_type == ContentType.INLINE_EQUATION:
content = f"{inline_left_delimiter}{span['content']}{inline_right_delimiter}"
elif span_type == ContentType.InterlineEquation:
elif span_type == ContentType.INTERLINE_EQUATION:
content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"
content = content.strip()
......@@ -206,36 +165,34 @@ def merge_para_with_text(para_block):
langs = ['zh', 'ja', 'ko']
# logger.info(f'block_lang: {block_lang}, content: {content}')
if block_lang in langs: # 中文/日语/韩文语境下,换行不需要空格分隔,但是如果是行内公式结尾,还是要加空格
if j == len(line['spans']) - 1 and span_type not in [ContentType.InlineEquation]:
if j == len(line['spans']) - 1 and span_type not in [ContentType.INLINE_EQUATION]:
para_text += content
else:
para_text += f'{content} '
else:
if span_type in [ContentType.Text, ContentType.InlineEquation]:
if span_type in [ContentType.TEXT, ContentType.INLINE_EQUATION]:
# 如果span是line的最后一个且末尾带有-连字符,那么末尾不应该加空格,同时应该把-删除
if j == len(line['spans'])-1 and span_type == ContentType.Text and __is_hyphen_at_line_end(content):
if j == len(line['spans'])-1 and span_type == ContentType.TEXT and __is_hyphen_at_line_end(content):
para_text += content[:-1]
else: # 西方文本语境下 content间需要空格分隔
para_text += f'{content} '
elif span_type == ContentType.InterlineEquation:
elif span_type == ContentType.INTERLINE_EQUATION:
para_text += content
else:
continue
# 连写字符拆分
# para_text = __replace_ligatures(para_text)
return para_text
def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason=None):
def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
para_type = para_block['type']
para_content = {}
if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
para_content = {
'type': 'text',
'text': merge_para_with_text(para_block),
}
elif para_type == BlockType.Title:
elif para_type == BlockType.TITLE:
para_content = {
'type': 'text',
'text': merge_para_with_text(para_block),
......@@ -243,32 +200,34 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason
title_level = get_title_level(para_block)
if title_level != 0:
para_content['text_level'] = title_level
elif para_type == BlockType.InterlineEquation:
elif para_type == BlockType.INTERLINE_EQUATION:
para_content = {
'type': 'equation',
'text': merge_para_with_text(para_block),
'text_format': 'latex',
'img_path': f"{img_buket_path}/{para_block['lines'][0]['spans'][0].get('image_path', '')}",
}
elif para_type == BlockType.Image:
if para_block['lines'][0]['spans'][0].get('content', ''):
para_content['text'] = merge_para_with_text(para_block)
para_content['text_format'] = 'latex'
elif para_type == BlockType.IMAGE:
para_content = {'type': 'image', 'img_path': '', 'img_caption': [], 'img_footnote': []}
for block in para_block['blocks']:
if block['type'] == BlockType.ImageBody:
if block['type'] == BlockType.IMAGE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.Image:
if span['type'] == ContentType.IMAGE:
if span.get('image_path', ''):
para_content['img_path'] = join_path(img_buket_path, span['image_path'])
if block['type'] == BlockType.ImageCaption:
para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
if block['type'] == BlockType.IMAGE_CAPTION:
para_content['img_caption'].append(merge_para_with_text(block))
if block['type'] == BlockType.ImageFootnote:
if block['type'] == BlockType.IMAGE_FOOTNOTE:
para_content['img_footnote'].append(merge_para_with_text(block))
elif para_type == BlockType.Table:
elif para_type == BlockType.TABLE:
para_content = {'type': 'table', 'img_path': '', 'table_caption': [], 'table_footnote': []}
for block in para_block['blocks']:
if block['type'] == BlockType.TableBody:
if block['type'] == BlockType.TABLE_BODY:
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.Table:
if span['type'] == ContentType.TABLE:
if span.get('latex', ''):
para_content['table_body'] = f"{span['latex']}"
......@@ -276,71 +235,43 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason
para_content['table_body'] = f"{span['html']}"
if span.get('image_path', ''):
para_content['img_path'] = join_path(img_buket_path, span['image_path'])
para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
if block['type'] == BlockType.TableCaption:
if block['type'] == BlockType.TABLE_CAPTION:
para_content['table_caption'].append(merge_para_with_text(block))
if block['type'] == BlockType.TableFootnote:
if block['type'] == BlockType.TABLE_FOOTNOTE:
para_content['table_footnote'].append(merge_para_with_text(block))
para_content['page_idx'] = page_idx
if drop_reason is not None:
para_content['drop_reason'] = drop_reason
return para_content
def union_make(pdf_info_dict: list,
make_mode: str,
drop_mode: str,
img_buket_path: str = '',
):
output_content = []
for page_info in pdf_info_dict:
drop_reason_flag = False
drop_reason = None
if page_info.get('need_drop', False):
drop_reason = page_info.get('drop_reason')
if drop_mode == DropMode.NONE:
pass
elif drop_mode == DropMode.NONE_WITH_REASON:
drop_reason_flag = True
elif drop_mode == DropMode.WHOLE_PDF:
raise Exception((f'drop_mode is {DropMode.WHOLE_PDF} ,'
f'drop_reason is {drop_reason}'))
elif drop_mode == DropMode.SINGLE_PAGE:
logger.warning((f'drop_mode is {DropMode.SINGLE_PAGE} ,'
f'drop_reason is {drop_reason}'))
continue
else:
raise Exception('drop_mode can not be null')
paras_of_layout = page_info.get('para_blocks')
page_idx = page_info.get('page_idx')
if not paras_of_layout:
continue
if make_mode == MakeMode.MM_MD:
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'mm', img_buket_path)
if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
page_markdown = make_blocks_to_markdown(paras_of_layout, make_mode, img_buket_path)
output_content.extend(page_markdown)
elif make_mode == MakeMode.NLP_MD:
page_markdown = ocr_mk_markdown_with_para_core_v2(
paras_of_layout, 'nlp')
output_content.extend(page_markdown)
elif make_mode == MakeMode.STANDARD_FORMAT:
elif make_mode == MakeMode.CONTENT_LIST:
for para_block in paras_of_layout:
if drop_reason_flag:
para_content = para_to_standard_format_v2(
para_block, img_buket_path, page_idx)
else:
para_content = para_to_standard_format_v2(
para_block, img_buket_path, page_idx)
para_content = make_blocks_to_content_list(para_block, img_buket_path, page_idx)
output_content.append(para_content)
if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
return '\n\n'.join(output_content)
elif make_mode == MakeMode.STANDARD_FORMAT:
elif make_mode == MakeMode.CONTENT_LIST:
return output_content
else:
logger.error(f"Unsupported make mode: {make_mode}")
return None
def get_title_level(block):
......@@ -349,4 +280,15 @@ def get_title_level(block):
title_level = 4
elif title_level < 1:
title_level = 0
return title_level
\ No newline at end of file
return title_level
def escape_special_markdown_char(content):
"""
转义正文里对markdown语法有特殊意义的字符
"""
special_chars = ["*", "`", "~", "$"]
for char in special_chars:
content = content.replace(char, "\\" + char)
return content
\ No newline at end of file
import asyncio
from abc import ABC, abstractmethod
from typing import AsyncIterable, Iterable, List, Optional, Union
DEFAULT_SYSTEM_PROMPT = (
"A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers."
)
DEFAULT_USER_PROMPT = "Document Parsing:"
DEFAULT_TEMPERATURE = 0.0
DEFAULT_TOP_P = 0.01
DEFAULT_TOP_K = 1
DEFAULT_REPETITION_PENALTY = 1.0
DEFAULT_PRESENCE_PENALTY = 0.0
DEFAULT_NO_REPEAT_NGRAM_SIZE = 100
DEFAULT_MAX_NEW_TOKENS = 16384
class BasePredictor(ABC):
system_prompt = DEFAULT_SYSTEM_PROMPT
def __init__(
self,
temperature: float = DEFAULT_TEMPERATURE,
top_p: float = DEFAULT_TOP_P,
top_k: int = DEFAULT_TOP_K,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
) -> None:
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.repetition_penalty = repetition_penalty
self.presence_penalty = presence_penalty
self.no_repeat_ngram_size = no_repeat_ngram_size
self.max_new_tokens = max_new_tokens
@abstractmethod
def predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> str: ...
@abstractmethod
def batch_predict(
self,
images: List[str] | List[bytes],
prompts: Union[List[str], str] = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> List[str]: ...
@abstractmethod
def stream_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> Iterable[str]: ...
async def aio_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> str:
return await asyncio.to_thread(
self.predict,
image,
prompt,
temperature,
top_p,
top_k,
repetition_penalty,
presence_penalty,
no_repeat_ngram_size,
max_new_tokens,
)
async def aio_batch_predict(
self,
images: List[str] | List[bytes],
prompts: Union[List[str], str] = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> List[str]:
return await asyncio.to_thread(
self.batch_predict,
images,
prompts,
temperature,
top_p,
top_k,
repetition_penalty,
presence_penalty,
no_repeat_ngram_size,
max_new_tokens,
)
async def aio_stream_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> AsyncIterable[str]:
queue = asyncio.Queue()
loop = asyncio.get_running_loop()
def synced_predict():
for chunk in self.stream_predict(
image=image,
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
):
asyncio.run_coroutine_threadsafe(queue.put(chunk), loop)
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
asyncio.create_task(
asyncio.to_thread(synced_predict),
)
while True:
chunk = await queue.get()
if chunk is None:
return
assert isinstance(chunk, str)
yield chunk
def build_prompt(self, prompt: str) -> str:
if prompt.startswith("<|im_start|>"):
return prompt
if not prompt:
prompt = DEFAULT_USER_PROMPT
return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n"
# Modify here. We add <|box_start|> at the end of the prompt to force the model to generate bounding box.
# if "Document OCR" in prompt:
# return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n<|box_start|>"
# else:
# return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n"
def close(self):
pass
from io import BytesIO
from typing import Iterable, List, Optional, Union
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer, BitsAndBytesConfig
from ...model.vlm_hf_model import Mineru2QwenForCausalLM
from ...model.vlm_hf_model.image_processing_mineru2 import process_images
from .base_predictor import (
DEFAULT_MAX_NEW_TOKENS,
DEFAULT_NO_REPEAT_NGRAM_SIZE,
DEFAULT_PRESENCE_PENALTY,
DEFAULT_REPETITION_PENALTY,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
BasePredictor,
)
from .utils import load_resource
class HuggingfacePredictor(BasePredictor):
def __init__(
self,
model_path: str,
device_map="auto",
device="cuda",
torch_dtype="auto",
load_in_8bit=False,
load_in_4bit=False,
use_flash_attn=False,
temperature: float = DEFAULT_TEMPERATURE,
top_p: float = DEFAULT_TOP_P,
top_k: int = DEFAULT_TOP_K,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
**kwargs,
):
super().__init__(
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)
kwargs = {"device_map": device_map, **kwargs}
if device != "cuda":
kwargs["device_map"] = {"": device}
if load_in_8bit:
kwargs["load_in_8bit"] = True
elif load_in_4bit:
kwargs["load_in_4bit"] = True
kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
else:
kwargs["torch_dtype"] = torch_dtype
if use_flash_attn:
kwargs["attn_implementation"] = "flash_attention_2"
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = Mineru2QwenForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
**kwargs,
)
setattr(self.model.config, "_name_or_path", model_path)
self.model.eval()
vision_tower = self.model.get_model().vision_tower
if device_map != "auto":
vision_tower.to(device=device_map, dtype=self.model.dtype)
self.image_processor = vision_tower.image_processor
self.eos_token_id = self.model.config.eos_token_id
def predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
**kwargs,
) -> str:
prompt = self.build_prompt(prompt)
if temperature is None:
temperature = self.temperature
if top_p is None:
top_p = self.top_p
if top_k is None:
top_k = self.top_k
if repetition_penalty is None:
repetition_penalty = self.repetition_penalty
if no_repeat_ngram_size is None:
no_repeat_ngram_size = self.no_repeat_ngram_size
if max_new_tokens is None:
max_new_tokens = self.max_new_tokens
do_sample = (temperature > 0.0) and (top_k > 1)
generate_kwargs = {
"repetition_penalty": repetition_penalty,
"no_repeat_ngram_size": no_repeat_ngram_size,
"max_new_tokens": max_new_tokens,
"do_sample": do_sample,
}
if do_sample:
generate_kwargs["temperature"] = temperature
generate_kwargs["top_p"] = top_p
generate_kwargs["top_k"] = top_k
if isinstance(image, str):
image = load_resource(image)
image_obj = Image.open(BytesIO(image))
image_tensor = process_images([image_obj], self.image_processor, self.model.config)
image_tensor = image_tensor[0].unsqueeze(0)
image_tensor = image_tensor.to(device=self.model.device, dtype=self.model.dtype)
image_sizes = [[*image_obj.size]]
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
input_ids = input_ids.to(device=self.model.device)
with torch.inference_mode():
output_ids = self.model.generate(
input_ids,
images=image_tensor,
image_sizes=image_sizes,
use_cache=True,
**generate_kwargs,
**kwargs,
)
# Remove the last token if it is the eos_token_id
if len(output_ids[0]) > 0 and output_ids[0, -1] == self.eos_token_id:
output_ids = output_ids[:, :-1]
output = self.tokenizer.batch_decode(
output_ids,
skip_special_tokens=False,
)[0].strip()
return output
def batch_predict(
self,
images: List[str] | List[bytes],
prompts: Union[List[str], str] = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None, # not supported by hf
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
**kwargs,
) -> List[str]:
if not isinstance(prompts, list):
prompts = [prompts] * len(images)
assert len(prompts) == len(images), "Length of prompts and images must match."
outputs = []
for prompt, image in tqdm(zip(prompts, images), total=len(images), desc="Predict"):
output = self.predict(
image,
prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
**kwargs,
)
outputs.append(output)
return outputs
def stream_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> Iterable[str]:
raise NotImplementedError("Streaming is not supported yet.")
# Copyright (c) Opendatalab. All rights reserved.
import time
from loguru import logger
from .base_predictor import (
DEFAULT_MAX_NEW_TOKENS,
DEFAULT_NO_REPEAT_NGRAM_SIZE,
DEFAULT_PRESENCE_PENALTY,
DEFAULT_REPETITION_PENALTY,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
BasePredictor,
)
from .sglang_client_predictor import SglangClientPredictor
hf_loaded = False
try:
from .hf_predictor import HuggingfacePredictor
hf_loaded = True
except ImportError as e:
logger.warning("hf is not installed. If you are not using huggingface, you can ignore this warning.")
engine_loaded = False
try:
from sglang.srt.server_args import ServerArgs
from .sglang_engine_predictor import SglangEnginePredictor
engine_loaded = True
except Exception as e:
logger.warning("sglang is not installed. If you are not using sglang, you can ignore this warning.")
def get_predictor(
backend: str = "sglang-client",
model_path: str | None = None,
server_url: str | None = None,
temperature: float = DEFAULT_TEMPERATURE,
top_p: float = DEFAULT_TOP_P,
top_k: int = DEFAULT_TOP_K,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
http_timeout: int = 600,
**kwargs,
) -> BasePredictor:
start_time = time.time()
if backend == "huggingface":
if not model_path:
raise ValueError("model_path must be provided for huggingface backend.")
if not hf_loaded:
raise ImportError(
"transformers is not installed, so huggingface backend cannot be used. "
"If you need to use huggingface backend, please install transformers first."
)
predictor = HuggingfacePredictor(
model_path=model_path,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
**kwargs,
)
elif backend == "sglang-engine":
if not model_path:
raise ValueError("model_path must be provided for sglang-engine backend.")
if not engine_loaded:
raise ImportError(
"sglang is not installed, so sglang-engine backend cannot be used. "
"If you need to use sglang-engine backend for inference, "
"please install sglang[all]==0.4.6.post4 or a newer version."
)
predictor = SglangEnginePredictor(
server_args=ServerArgs(model_path, **kwargs),
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)
elif backend == "sglang-client":
if not server_url:
raise ValueError("server_url must be provided for sglang-client backend.")
predictor = SglangClientPredictor(
server_url=server_url,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
http_timeout=http_timeout,
)
else:
raise ValueError(f"Unsupported backend: {backend}. Supports: huggingface, sglang-engine, sglang-client.")
elapsed = round(time.time() - start_time, 2)
logger.info(f"get_predictor cost: {elapsed}s")
return predictor
import asyncio
import json
import re
from base64 import b64encode
from typing import AsyncIterable, Iterable, List, Optional, Set, Tuple, Union
import httpx
from .base_predictor import (
DEFAULT_MAX_NEW_TOKENS,
DEFAULT_NO_REPEAT_NGRAM_SIZE,
DEFAULT_PRESENCE_PENALTY,
DEFAULT_REPETITION_PENALTY,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
BasePredictor,
)
from .utils import aio_load_resource, load_resource
class SglangClientPredictor(BasePredictor):
def __init__(
self,
server_url: str,
temperature: float = DEFAULT_TEMPERATURE,
top_p: float = DEFAULT_TOP_P,
top_k: int = DEFAULT_TOP_K,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
http_timeout: int = 600,
) -> None:
super().__init__(
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)
self.http_timeout = http_timeout
base_url = self.get_base_url(server_url)
self.check_server_health(base_url)
self.model_path = self.get_model_path(base_url)
self.server_url = f"{base_url}/generate"
@staticmethod
def get_base_url(server_url: str) -> str:
matched = re.match(r"^(https?://[^/]+)", server_url)
if not matched:
raise ValueError(f"Invalid server URL: {server_url}")
return matched.group(1)
def check_server_health(self, base_url: str):
try:
response = httpx.get(f"{base_url}/health_generate", timeout=self.http_timeout)
except httpx.ConnectError:
raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
if response.status_code != 200:
raise RuntimeError(
f"Server {base_url} is not healthy. Status code: {response.status_code}, response body: {response.text}"
)
def get_model_path(self, base_url: str) -> str:
try:
response = httpx.get(f"{base_url}/get_model_info", timeout=self.http_timeout)
except httpx.ConnectError:
raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
if response.status_code != 200:
raise RuntimeError(
f"Failed to get model info from {base_url}. Status code: {response.status_code}, response body: {response.text}"
)
return response.json()["model_path"]
def build_sampling_params(
self,
temperature: Optional[float],
top_p: Optional[float],
top_k: Optional[int],
repetition_penalty: Optional[float],
presence_penalty: Optional[float],
no_repeat_ngram_size: Optional[int],
max_new_tokens: Optional[int],
) -> dict:
if temperature is None:
temperature = self.temperature
if top_p is None:
top_p = self.top_p
if top_k is None:
top_k = self.top_k
if repetition_penalty is None:
repetition_penalty = self.repetition_penalty
if presence_penalty is None:
presence_penalty = self.presence_penalty
if no_repeat_ngram_size is None:
no_repeat_ngram_size = self.no_repeat_ngram_size
if max_new_tokens is None:
max_new_tokens = self.max_new_tokens
# see SamplingParams for more details
return {
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": repetition_penalty,
"presence_penalty": presence_penalty,
"custom_params": {
"no_repeat_ngram_size": no_repeat_ngram_size,
},
"max_new_tokens": max_new_tokens,
"skip_special_tokens": False,
}
def build_request_body(
self,
image: bytes,
prompt: str,
sampling_params: dict,
) -> dict:
image_base64 = b64encode(image).decode("utf-8")
return {
"text": prompt,
"image_data": image_base64,
"sampling_params": sampling_params,
"modalities": ["image"],
}
def predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> str:
prompt = self.build_prompt(prompt)
sampling_params = self.build_sampling_params(
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)
if isinstance(image, str):
image = load_resource(image)
request_body = self.build_request_body(image, prompt, sampling_params)
response = httpx.post(self.server_url, json=request_body, timeout=self.http_timeout)
response_body = response.json()
return response_body["text"]
def batch_predict(
self,
images: List[str] | List[bytes],
prompts: Union[List[str], str] = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
max_concurrency: int = 100,
) -> List[str]:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
task = self.aio_batch_predict(
images=images,
prompts=prompts,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
max_concurrency=max_concurrency,
)
if loop is not None:
return loop.run_until_complete(task)
else:
return asyncio.run(task)
def stream_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> Iterable[str]:
prompt = self.build_prompt(prompt)
sampling_params = self.build_sampling_params(
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)
if isinstance(image, str):
image = load_resource(image)
request_body = self.build_request_body(image, prompt, sampling_params)
request_body["stream"] = True
with httpx.stream(
"POST",
self.server_url,
json=request_body,
timeout=self.http_timeout,
) as response:
pos = 0
for chunk in response.iter_lines():
if not (chunk or "").startswith("data:"):
continue
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
chunk_text = data["text"][pos:]
# meta_info = data["meta_info"]
pos += len(chunk_text)
yield chunk_text
async def aio_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
async_client: Optional[httpx.AsyncClient] = None,
) -> str:
prompt = self.build_prompt(prompt)
sampling_params = self.build_sampling_params(
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)
if isinstance(image, str):
image = await aio_load_resource(image)
request_body = self.build_request_body(image, prompt, sampling_params)
if async_client is None:
async with httpx.AsyncClient(timeout=self.http_timeout) as client:
response = await client.post(self.server_url, json=request_body)
response_body = response.json()
else:
response = await async_client.post(self.server_url, json=request_body)
response_body = response.json()
return response_body["text"]
async def aio_batch_predict(
self,
images: List[str] | List[bytes],
prompts: Union[List[str], str] = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
max_concurrency: int = 100,
) -> List[str]:
if not isinstance(prompts, list):
prompts = [prompts] * len(images)
assert len(prompts) == len(images), "Length of prompts and images must match."
semaphore = asyncio.Semaphore(max_concurrency)
outputs = [""] * len(images)
async def predict_with_semaphore(
idx: int,
image: str | bytes,
prompt: str,
async_client: httpx.AsyncClient,
):
async with semaphore:
output = await self.aio_predict(
image=image,
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
async_client=async_client,
)
outputs[idx] = output
async with httpx.AsyncClient(timeout=self.http_timeout) as client:
tasks = []
for idx, (prompt, image) in enumerate(zip(prompts, images)):
tasks.append(predict_with_semaphore(idx, image, prompt, client))
await asyncio.gather(*tasks)
return outputs
async def aio_batch_predict_as_iter(
self,
images: List[str] | List[bytes],
prompts: Union[List[str], str] = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
max_concurrency: int = 100,
) -> AsyncIterable[Tuple[int, str]]:
if not isinstance(prompts, list):
prompts = [prompts] * len(images)
assert len(prompts) == len(images), "Length of prompts and images must match."
semaphore = asyncio.Semaphore(max_concurrency)
async def predict_with_semaphore(
idx: int,
image: str | bytes,
prompt: str,
async_client: httpx.AsyncClient,
):
async with semaphore:
output = await self.aio_predict(
image=image,
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
async_client=async_client,
)
return (idx, output)
async with httpx.AsyncClient(timeout=self.http_timeout) as client:
pending: Set[asyncio.Task[Tuple[int, str]]] = set()
for idx, (prompt, image) in enumerate(zip(prompts, images)):
pending.add(
asyncio.create_task(
predict_with_semaphore(idx, image, prompt, client),
)
)
while len(pending) > 0:
done, pending = await asyncio.wait(
pending,
return_when=asyncio.FIRST_COMPLETED,
)
for task in done:
yield task.result()
async def aio_stream_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> AsyncIterable[str]:
prompt = self.build_prompt(prompt)
sampling_params = self.build_sampling_params(
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)
if isinstance(image, str):
image = await aio_load_resource(image)
request_body = self.build_request_body(image, prompt, sampling_params)
request_body["stream"] = True
async with httpx.AsyncClient(timeout=self.http_timeout) as client:
async with client.stream(
"POST",
self.server_url,
json=request_body,
) as response:
pos = 0
async for chunk in response.aiter_lines():
if not (chunk or "").startswith("data:"):
continue
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
chunk_text = data["text"][pos:]
# meta_info = data["meta_info"]
pos += len(chunk_text)
yield chunk_text
from base64 import b64encode
from typing import AsyncIterable, Iterable, List, Optional, Union
from sglang.srt.server_args import ServerArgs
from ...model.vlm_sglang_model.engine import BatchEngine
from .base_predictor import (
DEFAULT_MAX_NEW_TOKENS,
DEFAULT_NO_REPEAT_NGRAM_SIZE,
DEFAULT_PRESENCE_PENALTY,
DEFAULT_REPETITION_PENALTY,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
BasePredictor,
)
class SglangEnginePredictor(BasePredictor):
def __init__(
self,
server_args: ServerArgs,
temperature: float = DEFAULT_TEMPERATURE,
top_p: float = DEFAULT_TOP_P,
top_k: int = DEFAULT_TOP_K,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
) -> None:
super().__init__(
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)
self.engine = BatchEngine(server_args=server_args)
def load_image_string(self, image: str | bytes) -> str:
if not isinstance(image, (str, bytes)):
raise ValueError("Image must be a string or bytes.")
if isinstance(image, bytes):
return b64encode(image).decode("utf-8")
if image.startswith("file://"):
return image[len("file://") :]
return image
def predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> str:
return self.batch_predict(
[image], # type: ignore
[prompt],
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)[0]
def batch_predict(
self,
images: List[str] | List[bytes],
prompts: Union[List[str], str] = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> List[str]:
if not isinstance(prompts, list):
prompts = [prompts] * len(images)
assert len(prompts) == len(images), "Length of prompts and images must match."
prompts = [self.build_prompt(prompt) for prompt in prompts]
if temperature is None:
temperature = self.temperature
if top_p is None:
top_p = self.top_p
if top_k is None:
top_k = self.top_k
if repetition_penalty is None:
repetition_penalty = self.repetition_penalty
if presence_penalty is None:
presence_penalty = self.presence_penalty
if no_repeat_ngram_size is None:
no_repeat_ngram_size = self.no_repeat_ngram_size
if max_new_tokens is None:
max_new_tokens = self.max_new_tokens
# see SamplingParams for more details
sampling_params = {
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": repetition_penalty,
"presence_penalty": presence_penalty,
"custom_params": {
"no_repeat_ngram_size": no_repeat_ngram_size,
},
"max_new_tokens": max_new_tokens,
"skip_special_tokens": False,
}
image_strings = [self.load_image_string(img) for img in images]
output = self.engine.generate(
prompt=prompts,
image_data=image_strings,
sampling_params=sampling_params,
)
return [item["text"] for item in output]
def stream_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> Iterable[str]:
raise NotImplementedError("Streaming is not supported yet.")
async def aio_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> str:
output = await self.aio_batch_predict(
[image], # type: ignore
[prompt],
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)
return output[0]
async def aio_batch_predict(
self,
images: List[str] | List[bytes],
prompts: Union[List[str], str] = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> List[str]:
if not isinstance(prompts, list):
prompts = [prompts] * len(images)
assert len(prompts) == len(images), "Length of prompts and images must match."
prompts = [self.build_prompt(prompt) for prompt in prompts]
if temperature is None:
temperature = self.temperature
if top_p is None:
top_p = self.top_p
if top_k is None:
top_k = self.top_k
if repetition_penalty is None:
repetition_penalty = self.repetition_penalty
if presence_penalty is None:
presence_penalty = self.presence_penalty
if no_repeat_ngram_size is None:
no_repeat_ngram_size = self.no_repeat_ngram_size
if max_new_tokens is None:
max_new_tokens = self.max_new_tokens
# see SamplingParams for more details
sampling_params = {
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": repetition_penalty,
"presence_penalty": presence_penalty,
"custom_params": {
"no_repeat_ngram_size": no_repeat_ngram_size,
},
"max_new_tokens": max_new_tokens,
"skip_special_tokens": False,
}
image_strings = [self.load_image_string(img) for img in images]
output = await self.engine.async_generate(
prompt=prompts,
image_data=image_strings,
sampling_params=sampling_params,
)
ret = []
for item in output: # type: ignore
ret.append(item["text"])
return ret
async def aio_stream_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> AsyncIterable[str]:
raise NotImplementedError("Streaming is not supported yet.")
def close(self):
self.engine.shutdown()
import re
from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.enum_class import BlockType, ContentType
from mineru.utils.hash_utils import str_md5
from mineru.backend.vlm.vlm_magic_model import MagicModel
from mineru.version import __version__
def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dict:
"""将token转换为页面信息"""
# 解析token,提取坐标和类型
# 假设token格式为:<|box_start|>x0 y0 x1 y1<|box_end|><|ref_start|>type<|ref_end|><|md_start|>content<|md_end|>
# 这里需要根据实际的token格式进行解析
# 提取所有完整块,每个块从<|box_start|>开始到<|md_end|>或<|im_end|>结束
scale = image_dict["scale"]
page_pil_img = image_dict["img_pil"]
page_img_md5 = str_md5(image_dict["img_base64"])
width, height = map(int, page.get_size())
magic_model = MagicModel(token, width, height)
image_blocks = magic_model.get_image_blocks()
table_blocks = magic_model.get_table_blocks()
title_blocks = magic_model.get_title_blocks()
text_blocks = magic_model.get_text_blocks()
interline_equation_blocks = magic_model.get_interline_equation_blocks()
all_spans = magic_model.get_all_spans()
# 对image/table/interline_equation的span截图
for span in all_spans:
if span["type"] in [ContentType.IMAGE, ContentType.TABLE, ContentType.INTERLINE_EQUATION]:
span = cut_image_and_table(span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale)
page_blocks = []
page_blocks.extend([*image_blocks, *table_blocks, *title_blocks, *text_blocks, *interline_equation_blocks])
# 对page_blocks根据index的值进行排序
page_blocks.sort(key=lambda x: x["index"])
page_info = {"para_blocks": page_blocks, "discarded_blocks": [], "page_size": [width, height], "page_idx": page_index}
return page_info
def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
middle_json = {"pdf_info": [], "_backend":"vlm", "_version_name": __version__}
for index, token in enumerate(token_list):
page = pdf_doc[index]
image_dict = images_list[index]
page_info = token_to_page_info(token, image_dict, page, image_writer, index)
middle_json["pdf_info"].append(page_info)
# 关闭pdf文档
pdf_doc.close()
return middle_json
if __name__ == "__main__":
output = r"<|box_start|>088 119 472 571<|box_end|><|ref_start|>image<|ref_end|><|md_start|>![]('img_url')<|md_end|>\n<|box_start|>079 582 482 608<|box_end|><|ref_start|>image_caption<|ref_end|><|md_start|>Fig. 2. (a) Schematic of the change in the FDC over time, and (b) definition of model parameters.<|md_end|>\n<|box_start|>079 624 285 638<|box_end|><|ref_start|>title<|ref_end|><|md_start|># 2.2. Zero flow day analysis<|md_end|>\n<|box_start|>079 656 482 801<|box_end|><|ref_start|>text<|ref_end|><|md_start|>A notable feature of Fig. 1 is the increase in the number of zero flow days. A similar approach to Eq. (2), using an inverse sigmoidal function was employed to assess the impact of afforestation on the number of zero flow days per year \((N_{\mathrm{zero}})\). In this case, the left hand side of Eq. (2) is replaced by \(N_{\mathrm{zero}}\) and \(b\) and \(S\) are constrained to negative as \(N_{\mathrm{zero}}\) decreases as rainfall increases, and increases with plantation growth:<|md_end|>\n<|box_start|>076 813 368 853<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nN_{\mathrm{zero}}=a+b(\Delta P)+\frac{Y}{1+\exp\left(\frac{T-T_{\mathrm{half}}}{S}\right)}\n\]<|md_end|>\n<|box_start|>079 865 482 895<|box_end|><|ref_start|>text<|ref_end|><|md_start|>For the average pre-treatment condition \(\Delta P=0\) and \(T=0\), \(N_{\mathrm{zero}}\) approximately equals \(a\). \(Y\) gives<|md_end|>\n<|box_start|>525 119 926 215<|box_end|><|ref_start|>text<|ref_end|><|md_start|>the magnitude of change in zero flow days due to afforestation, and \(S\) describes the shape of the response. For the average climate condition \(\Delta P=0\), \(a+Y\) becomes the number of zero flow days when the new equilibrium condition under afforestation is reached.<|md_end|>\n<|box_start|>525 240 704 253<|box_end|><|ref_start|>title<|ref_end|><|md_start|># 2.3. Statistical analyses<|md_end|>\n<|box_start|>525 271 926 368<|box_end|><|ref_start|>text<|ref_end|><|md_start|>The coefficient of efficiency \((E)\) (Nash and Sutcliffe, 1970; Chiew and McMahon, 1993; Legates and McCabe, 1999) was used as the 'goodness of fit' measure to evaluate the fit between observed and predicted flow deciles (2) and zero flow days (3). \(E\) is given by:<|md_end|>\n<|box_start|>520 375 735 415<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nE=1.0-\frac{\sum_{i=1}^{N}(O_{i}-P_{i})^{2}}{\sum_{i=1}^{N}(O_{i}-\bar{O})^{2}}\n\]<|md_end|>\n<|box_start|>525 424 926 601<|box_end|><|ref_start|>text<|ref_end|><|md_start|>where \(O\) are observed data, \(P\) are predicted values, and \(\bar{O}\) is the mean for the entire period. \(E\) is unity minus the ratio of the mean square error to the variance in the observed data, and ranges from \(-\infty\) to 1.0. Higher values indicate greater agreement between observed and predicted data as per the coefficient of determination \((r^{2})\). \(E\) is used in preference to \(r^{2}\) in evaluating hydrologic modelling because it is a measure of the deviation from the 1:1 line. As \(E\) is always \(<r^{2}\) we have arbitrarily considered \(E>0.7\) to indicate adequate model fits.<|md_end|>\n<|box_start|>525 603 926 731<|box_end|><|ref_start|>text<|ref_end|><|md_start|>It is important to assess the significance of the model parameters to check the model assumptions that rainfall and forest age are driving changes in the FDC. The model (2) was split into simplified forms, where only the rainfall or time terms were included by setting \(b=0\), as shown in Eq. (5), or \(Y=0\) as shown in Eq. (6). The component models (5) and (6) were then tested against the complete model, (2).<|md_end|>\n<|box_start|>520 739 735 778<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nQ_{\%}=a+\frac{Y}{1+\exp\left(\frac{T-T_{\mathrm{half}}^{\prime}}{S}\right)}\n\]<|md_end|>\n<|box_start|>525 787 553 799<|box_end|><|ref_start|>text<|ref_end|><|md_start|>and<|md_end|>\n<|box_start|>520 807 646 825<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nQ_{\%}=a+b\Delta P\n\]<|md_end|>\n<|box_start|>525 833 926 895<|box_end|><|ref_start|>text<|ref_end|><|md_start|>For both the flow duration curve analysis and zero flow days analysis, a \(t\)-test was then performed to test whether (5) and (6) were significantly different to (2). A critical value of \(t\) exceeding the calculated \(t\)-value<|md_end|><|im_end|>"
p_info = token_to_page_info(output)
# 将blocks 转换为json文本
import json
json_str = json.dumps(p_info, ensure_ascii=False, indent=4)
print(json_str)
import os
import re
from base64 import b64decode
import httpx
_timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
_file_exts = (".png", ".jpg", ".jpeg", ".webp", ".gif", ".pdf")
_data_uri_regex = re.compile(r"^data:[^;,]+;base64,")
def load_resource(uri: str) -> bytes:
if uri.startswith("http://") or uri.startswith("https://"):
response = httpx.get(uri, timeout=_timeout)
return response.content
if uri.startswith("file://"):
with open(uri[len("file://") :], "rb") as file:
return file.read()
if uri.lower().endswith(_file_exts):
with open(uri, "rb") as file:
return file.read()
if re.match(_data_uri_regex, uri):
return b64decode(uri.split(",")[1])
return b64decode(uri)
async def aio_load_resource(uri: str) -> bytes:
if uri.startswith("http://") or uri.startswith("https://"):
async with httpx.AsyncClient(timeout=_timeout) as client:
response = await client.get(uri)
return response.content
if uri.startswith("file://"):
with open(uri[len("file://") :], "rb") as file:
return file.read()
if uri.lower().endswith(_file_exts):
with open(uri, "rb") as file:
return file.read()
if re.match(_data_uri_regex, uri):
return b64decode(uri.split(",")[1])
return b64decode(uri)
# Copyright (c) Opendatalab. All rights reserved.
import time
from loguru import logger
from ...data.data_reader_writer import DataWriter
from mineru.utils.pdf_image_tools import load_images_from_pdf
from .base_predictor import BasePredictor
from .predictor import get_predictor
from .token_to_middle_json import result_to_middle_json
from ...utils.enum_class import ModelPath
class ModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(
self,
backend: str,
model_path: str | None,
server_url: str | None,
) -> BasePredictor:
key = (backend,)
if key not in self._models:
self._models[key] = get_predictor(
backend=backend,
model_path=model_path,
server_url=server_url,
)
return self._models[key]
def doc_analyze(
pdf_bytes,
image_writer: DataWriter | None,
predictor: BasePredictor | None = None,
backend="huggingface",
model_path=ModelPath.vlm_root_hf,
server_url: str | None = None,
):
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url)
# load_images_start = time.time()
images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
# load_images_time = round(time.time() - load_images_start, 2)
# logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
infer_start = time.time()
results = predictor.batch_predict(images=images_base64_list)
infer_time = round(time.time() - infer_start, 2)
logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
return middle_json, results
async def aio_doc_analyze(
pdf_bytes,
image_writer: DataWriter | None,
predictor: BasePredictor | None = None,
backend="huggingface",
model_path=ModelPath.vlm_root_hf,
server_url: str | None = None,
):
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url)
load_images_start = time.time()
images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
load_images_time = round(time.time() - load_images_start, 2)
logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
infer_start = time.time()
results = await predictor.aio_batch_predict(images=images_base64_list)
infer_time = round(time.time() - infer_start, 2)
logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
return middle_json
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment