Commit 0fc1daac authored by luopl's avatar luopl
Browse files

Initial commit

parents
import cv2
from loguru import logger
from tqdm import tqdm
from collections import defaultdict
import numpy as np
from .model_init import AtomModelSingleton
from ...utils.config_reader import get_formula_enable, get_table_enable
from ...utils.model_utils import crop_img, get_res_list_from_layout_res
from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
YOLO_LAYOUT_BASE_BATCH_SIZE = 8
MFD_BASE_BATCH_SIZE = 1
MFR_BASE_BATCH_SIZE = 16
class BatchAnalyze:
def __init__(self, model_manager, batch_ratio: int, formula_enable, table_enable, enable_ocr_det_batch: bool = True):
self.batch_ratio = batch_ratio
self.formula_enable = get_formula_enable(formula_enable)
self.table_enable = get_table_enable(table_enable)
self.model_manager = model_manager
self.enable_ocr_det_batch = enable_ocr_det_batch
def __call__(self, images_with_extra_info: list) -> list:
if len(images_with_extra_info) == 0:
return []
images_layout_res = []
self.model = self.model_manager.get_model(
lang=None,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
atom_model_manager = AtomModelSingleton()
images = [image for image, _, _ in images_with_extra_info]
# doclayout_yolo
layout_images = []
for image_index, image in enumerate(images):
layout_images.append(image)
images_layout_res += self.model.layout_model.batch_predict(
layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
)
if self.formula_enable:
# 公式检测
images_mfd_res = self.model.mfd_model.batch_predict(
images, MFD_BASE_BATCH_SIZE
)
# 公式识别
images_formula_list = self.model.mfr_model.batch_predict(
images_mfd_res,
images,
batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
)
mfr_count = 0
for image_index in range(len(images)):
images_layout_res[image_index] += images_formula_list[image_index]
mfr_count += len(images_formula_list[image_index])
# 清理显存
# clean_vram(self.model.device, vram_threshold=8)
ocr_res_list_all_page = []
table_res_list_all_page = []
for index in range(len(images)):
_, ocr_enable, _lang = images_with_extra_info[index]
layout_res = images_layout_res[index]
pil_img = images[index]
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
get_res_list_from_layout_res(layout_res)
)
ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
'lang':_lang,
'ocr_enable':ocr_enable,
'pil_img':pil_img,
'single_page_mfdetrec_res':single_page_mfdetrec_res,
'layout_res':layout_res,
})
for table_res in table_res_list:
table_img, _ = crop_img(table_res, pil_img)
table_res_list_all_page.append({'table_res':table_res,
'lang':_lang,
'table_img':table_img,
})
# OCR检测处理
if self.enable_ocr_det_batch:
# 批处理模式 - 按语言和分辨率分组
# 收集所有需要OCR检测的裁剪图像
all_cropped_images_info = []
for ocr_res_list_dict in ocr_res_list_all_page:
_lang = ocr_res_list_dict['lang']
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# BGR转换
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
all_cropped_images_info.append((
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
))
# 按语言分组
lang_groups = defaultdict(list)
for crop_info in all_cropped_images_info:
lang = crop_info[5]
lang_groups[lang].append(crop_info)
# 对每种语言按分辨率分组并批处理
for lang, lang_crop_list in lang_groups.items():
if not lang_crop_list:
continue
# logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")
# 获取OCR模型
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
det_db_box_thresh=0.3,
lang=lang
)
# 按分辨率分组并同时完成padding
resolution_groups = defaultdict(list)
for crop_info in lang_crop_list:
cropped_img = crop_info[0]
h, w = cropped_img.shape[:2]
# 使用更大的分组容差,减少分组数量
# 将尺寸标准化到32的倍数
normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数
normalized_w = ((w + 32) // 32) * 32
group_key = (normalized_h, normalized_w)
resolution_groups[group_key].append(crop_info)
# 对每个分辨率组进行批处理
for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
# 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
max_h = max(crop_info[0].shape[0] for crop_info in group_crops)
max_w = max(crop_info[0].shape[1] for crop_info in group_crops)
target_h = ((max_h + 32 - 1) // 32) * 32
target_w = ((max_w + 32 - 1) // 32) * 32
# 对所有图像进行padding到统一尺寸
batch_images = []
for crop_info in group_crops:
img = crop_info[0]
h, w = img.shape[:2]
# 创建目标尺寸的白色背景
padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
# 将原图像粘贴到左上角
padded_img[:h, :w] = img
batch_images.append(padded_img)
# 批处理检测
batch_size = min(len(batch_images), self.batch_ratio * 16) # 增加批处理大小
# logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)
# 处理批处理结果
for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
if dt_boxes is not None and len(dt_boxes) > 0:
# 直接应用原始OCR流程中的关键处理步骤
from mineru.utils.ocr_utils import (
merge_det_boxes, update_det_boxes, sorted_boxes
)
# 1. 排序检测框
if len(dt_boxes) > 0:
dt_boxes_sorted = sorted_boxes(dt_boxes)
else:
dt_boxes_sorted = []
# 2. 合并相邻检测框
if dt_boxes_sorted:
dt_boxes_merged = merge_det_boxes(dt_boxes_sorted)
else:
dt_boxes_merged = []
# 3. 根据公式位置更新检测框(关键步骤!)
if dt_boxes_merged and adjusted_mfdetrec_res:
dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
else:
dt_boxes_final = dt_boxes_merged
# 构造OCR结果格式
ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]
if ocr_res:
ocr_result_list = get_ocr_result_list(
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
)
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
else:
# 原始单张处理模式
for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
# Process each area that requires OCR processing
_lang = ocr_res_list_dict['lang']
# Get OCR results for this language's images
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=_lang
)
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# OCR-det
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
ocr_res = ocr_model.ocr(
new_image, mfd_res=adjusted_mfdetrec_res, rec=False
)[0]
# Integration results
if ocr_res:
ocr_result_list = get_ocr_result_list(
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],new_image, _lang
)
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
# 表格识别 table recognition
if self.table_enable:
for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
_lang = table_res_dict['lang']
table_model = atom_model_manager.get_atom_model(
atom_model_name='table',
lang=_lang,
)
html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict['table_img'])
# 判断是否返回正常
if html_code:
expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
if expected_ending:
table_res_dict['table_res']['html'] = html_code
else:
logger.warning(
'table recognition processing fails, not found expected HTML table end'
)
else:
logger.warning(
'table recognition processing fails, not get html return'
)
# Create dictionaries to store items by language
need_ocr_lists_by_lang = {} # Dict of lists for each language
img_crop_lists_by_lang = {} # Dict of lists for each language
for layout_res in images_layout_res:
for layout_res_item in layout_res:
if layout_res_item['category_id'] in [15]:
if 'np_img' in layout_res_item and 'lang' in layout_res_item:
lang = layout_res_item['lang']
# Initialize lists for this language if not exist
if lang not in need_ocr_lists_by_lang:
need_ocr_lists_by_lang[lang] = []
img_crop_lists_by_lang[lang] = []
# Add to the appropriate language-specific lists
need_ocr_lists_by_lang[lang].append(layout_res_item)
img_crop_lists_by_lang[lang].append(layout_res_item['np_img'])
# Remove the fields after adding to lists
layout_res_item.pop('np_img')
layout_res_item.pop('lang')
if len(img_crop_lists_by_lang) > 0:
# Process OCR by language
total_processed = 0
# Process each language separately
for lang, img_crop_list in img_crop_lists_by_lang.items():
if len(img_crop_list) > 0:
# Get OCR results for this language's images
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
det_db_box_thresh=0.3,
lang=lang
)
ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
# Verify we have matching counts
assert len(ocr_res_list) == len(
need_ocr_lists_by_lang[lang]), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_lists_by_lang[lang])} for lang: {lang}'
# Process OCR results for this language
for index, layout_res_item in enumerate(need_ocr_lists_by_lang[lang]):
ocr_text, ocr_score = ocr_res_list[index]
layout_res_item['text'] = ocr_text
layout_res_item['score'] = float(f"{ocr_score:.3f}")
if ocr_score < OcrConfidence.min_confidence:
layout_res_item['category_id'] = 16
total_processed += len(img_crop_list)
return images_layout_res
import os
import torch
from loguru import logger
from .model_list import AtomicModel
from ...model.layout.doclayout_yolo import DocLayoutYOLOModel
from ...model.mfd.yolo_v8 import YOLOv8MFDModel
from ...model.mfr.unimernet.Unimernet import UnimernetModel
from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
from ...model.table.rapid_table import RapidTableModel
from ...utils.enum_class import ModelPath
from ...utils.models_download_utils import auto_download_and_get_model_root_path
def table_model_init(lang=None):
atom_model_manager = AtomModelSingleton()
ocr_engine = atom_model_manager.get_atom_model(
atom_model_name='ocr',
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
lang=lang
)
table_model = RapidTableModel(ocr_engine)
return table_model
def mfd_model_init(weight, device='cpu'):
if str(device).startswith('npu'):
device = torch.device(device)
mfd_model = YOLOv8MFDModel(weight, device)
return mfd_model
def mfr_model_init(weight_dir, device='cpu'):
mfr_model = UnimernetModel(weight_dir, device)
return mfr_model
def doclayout_yolo_model_init(weight, device='cpu'):
if str(device).startswith('npu'):
device = torch.device(device)
model = DocLayoutYOLOModel(weight, device)
return model
def ocr_model_init(det_db_box_thresh=0.3,
lang=None,
use_dilation=True,
det_db_unclip_ratio=1.8,
):
if lang is not None and lang != '':
model = PytorchPaddleOCR(
det_db_box_thresh=det_db_box_thresh,
lang=lang,
use_dilation=use_dilation,
det_db_unclip_ratio=det_db_unclip_ratio,
)
else:
model = PytorchPaddleOCR(
det_db_box_thresh=det_db_box_thresh,
use_dilation=use_dilation,
det_db_unclip_ratio=det_db_unclip_ratio,
)
return model
class AtomModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_atom_model(self, atom_model_name: str, **kwargs):
lang = kwargs.get('lang', None)
table_model_name = kwargs.get('table_model_name', None)
if atom_model_name in [AtomicModel.OCR]:
key = (atom_model_name, lang)
elif atom_model_name in [AtomicModel.Table]:
key = (atom_model_name, table_model_name, lang)
else:
key = atom_model_name
if key not in self._models:
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
return self._models[key]
def atom_model_init(model_name: str, **kwargs):
atom_model = None
if model_name == AtomicModel.Layout:
atom_model = doclayout_yolo_model_init(
kwargs.get('doclayout_yolo_weights'),
kwargs.get('device')
)
elif model_name == AtomicModel.MFD:
atom_model = mfd_model_init(
kwargs.get('mfd_weights'),
kwargs.get('device')
)
elif model_name == AtomicModel.MFR:
atom_model = mfr_model_init(
kwargs.get('mfr_weight_dir'),
kwargs.get('device')
)
elif model_name == AtomicModel.OCR:
atom_model = ocr_model_init(
kwargs.get('det_db_box_thresh'),
kwargs.get('lang'),
)
elif model_name == AtomicModel.Table:
atom_model = table_model_init(
kwargs.get('lang'),
)
else:
logger.error('model name not allow')
exit(1)
if atom_model is None:
logger.error('model init failed')
exit(1)
else:
return atom_model
class MineruPipelineModel:
def __init__(self, **kwargs):
self.formula_config = kwargs.get('formula_config')
self.apply_formula = self.formula_config.get('enable', True)
self.table_config = kwargs.get('table_config')
self.apply_table = self.table_config.get('enable', True)
self.lang = kwargs.get('lang', None)
self.device = kwargs.get('device', 'cpu')
logger.info(
'DocAnalysis init, this may take some times......'
)
atom_model_manager = AtomModelSingleton()
if self.apply_formula:
# 初始化公式检测模型
self.mfd_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFD,
mfd_weights=str(
os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd), ModelPath.yolo_v8_mfd)
),
device=self.device,
)
# 初始化公式解析模型
mfr_weight_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.unimernet_small), ModelPath.unimernet_small)
self.mfr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFR,
mfr_weight_dir=mfr_weight_dir,
device=self.device,
)
# 初始化layout模型
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
doclayout_yolo_weights=str(
os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
),
device=self.device,
)
# 初始化ocr
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
det_db_box_thresh=0.3,
lang=self.lang
)
# init table model
if self.apply_table:
self.table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Table,
lang=self.lang,
)
logger.info('DocAnalysis init done!')
\ No newline at end of file
# Copyright (c) Opendatalab. All rights reserved.
import time
from loguru import logger
from tqdm import tqdm
from mineru.utils.config_reader import get_device, get_llm_aided_config, get_formula_enable
from mineru.backend.pipeline.model_init import AtomModelSingleton
from mineru.backend.pipeline.para_split import para_split
from mineru.utils.block_pre_proc import prepare_block_bboxes, process_groups
from mineru.utils.block_sort import sort_blocks_by_bbox
from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.enum_class import ContentType
from mineru.utils.llm_aided import llm_aided_title
from mineru.utils.model_utils import clean_memory
from mineru.backend.pipeline.pipeline_magic_model import MagicModel
from mineru.utils.ocr_utils import OcrConfidence
from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans
from mineru.utils.span_pre_proc import remove_outside_spans, remove_overlaps_low_confidence_spans, \
remove_overlaps_min_spans, txt_spans_extract
from mineru.version import __version__
from mineru.utils.hash_utils import str_md5
def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, ocr_enable=False, formula_enabled=True):
scale = image_dict["scale"]
page_pil_img = image_dict["img_pil"]
page_img_md5 = str_md5(image_dict["img_base64"])
page_w, page_h = map(int, page.get_size())
magic_model = MagicModel(page_model_info, scale)
"""从magic_model对象中获取后面会用到的区块信息"""
discarded_blocks = magic_model.get_discarded()
text_blocks = magic_model.get_text_blocks()
title_blocks = magic_model.get_title_blocks()
inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations()
img_groups = magic_model.get_imgs()
table_groups = magic_model.get_tables()
"""对image和table的区块分组"""
img_body_blocks, img_caption_blocks, img_footnote_blocks, maybe_text_image_blocks = process_groups(
img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
)
table_body_blocks, table_caption_blocks, table_footnote_blocks, _ = process_groups(
table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
)
"""获取所有的spans信息"""
spans = magic_model.get_all_spans()
"""某些图可能是文本块,通过简单的规则判断一下"""
if len(maybe_text_image_blocks) > 0:
for block in maybe_text_image_blocks:
span_in_block_list = []
for span in spans:
if span['type'] == 'text' and calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block['bbox']) > 0.7:
span_in_block_list.append(span)
if len(span_in_block_list) > 0:
# span_in_block_list中所有bbox的面积之和
spans_area = sum((span['bbox'][2] - span['bbox'][0]) * (span['bbox'][3] - span['bbox'][1]) for span in span_in_block_list)
# 求ocr_res_area和res的面积的比值
block_area = (block['bbox'][2] - block['bbox'][0]) * (block['bbox'][3] - block['bbox'][1])
if block_area > 0:
ratio = spans_area / block_area
if ratio > 0.25 and ocr_enable:
# 移除block的group_id
block.pop('group_id', None)
# 符合文本图的条件就把块加入到文本块列表中
text_blocks.append(block)
else:
# 如果不符合文本图的条件,就把块加回到图片块列表中
img_body_blocks.append(block)
else:
img_body_blocks.append(block)
"""将所有区块的bbox整理到一起"""
if formula_enabled:
interline_equation_blocks = []
if len(interline_equation_blocks) > 0:
for block in interline_equation_blocks:
spans.append({
"type": ContentType.INTERLINE_EQUATION,
'score': block['score'],
"bbox": block['bbox'],
})
all_bboxes, all_discarded_blocks, footnote_blocks = prepare_block_bboxes(
img_body_blocks, img_caption_blocks, img_footnote_blocks,
table_body_blocks, table_caption_blocks, table_footnote_blocks,
discarded_blocks,
text_blocks,
title_blocks,
interline_equation_blocks,
page_w,
page_h,
)
else:
all_bboxes, all_discarded_blocks, footnote_blocks = prepare_block_bboxes(
img_body_blocks, img_caption_blocks, img_footnote_blocks,
table_body_blocks, table_caption_blocks, table_footnote_blocks,
discarded_blocks,
text_blocks,
title_blocks,
interline_equations,
page_w,
page_h,
)
"""在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span"""
"""顺便删除大水印并保留abandon的span"""
spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks)
"""删除重叠spans中置信度较低的那些"""
spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
"""删除重叠spans中较小的那些"""
spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
"""根据parse_mode,构造spans,主要是文本类的字符填充"""
if ocr_enable:
pass
else:
"""使用新版本的混合ocr方案."""
spans = txt_spans_extract(page, spans, page_pil_img, scale, all_bboxes, all_discarded_blocks)
"""先处理不需要排版的discarded_blocks"""
discarded_block_with_spans, spans = fill_spans_in_blocks(
all_discarded_blocks, spans, 0.4
)
fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
"""如果当前页面没有有效的bbox则跳过"""
if len(all_bboxes) == 0:
return None
"""对image/table/interline_equation截图"""
for span in spans:
if span['type'] in [ContentType.IMAGE, ContentType.TABLE, ContentType.INTERLINE_EQUATION]:
span = cut_image_and_table(
span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale
)
"""span填充进block"""
block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
"""对block进行fix操作"""
fix_blocks = fix_block_spans(block_with_spans)
"""同一行被断开的titile合并"""
# merge_title_blocks(fix_blocks)
"""对block进行排序"""
sorted_blocks = sort_blocks_by_bbox(fix_blocks, page_w, page_h, footnote_blocks)
"""构造page_info"""
page_info = make_page_info_dict(sorted_blocks, page_index, page_w, page_h, fix_discarded_blocks)
return page_info
def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=None, ocr_enable=False, formula_enabled=True):
middle_json = {"pdf_info": [], "_backend":"pipeline", "_version_name": __version__}
formula_enabled = get_formula_enable(formula_enabled)
for page_index, page_model_info in tqdm(enumerate(model_list), total=len(model_list), desc="Processing pages"):
page = pdf_doc[page_index]
image_dict = images_list[page_index]
page_info = page_model_info_to_page_info(
page_model_info, image_dict, page, image_writer, page_index, ocr_enable=ocr_enable, formula_enabled=formula_enabled
)
if page_info is None:
page_w, page_h = map(int, page.get_size())
page_info = make_page_info_dict([], page_index, page_w, page_h, [])
middle_json["pdf_info"].append(page_info)
"""后置ocr处理"""
need_ocr_list = []
img_crop_list = []
text_block_list = []
for page_info in middle_json["pdf_info"]:
for block in page_info['preproc_blocks']:
if block['type'] in ['table', 'image']:
for sub_block in block['blocks']:
if sub_block['type'] in ['image_caption', 'image_footnote', 'table_caption', 'table_footnote']:
text_block_list.append(sub_block)
elif block['type'] in ['text', 'title']:
text_block_list.append(block)
for block in page_info['discarded_blocks']:
text_block_list.append(block)
for block in text_block_list:
for line in block['lines']:
for span in line['spans']:
if 'np_img' in span:
need_ocr_list.append(span)
img_crop_list.append(span['np_img'])
span.pop('np_img')
if len(img_crop_list) > 0:
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang
)
ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
assert len(ocr_res_list) == len(
need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
for index, span in enumerate(need_ocr_list):
ocr_text, ocr_score = ocr_res_list[index]
if ocr_score > OcrConfidence.min_confidence:
span['content'] = ocr_text
span['score'] = float(f"{ocr_score:.3f}")
else:
span['content'] = ''
span['score'] = 0.0
"""分段"""
para_split(middle_json["pdf_info"])
"""llm优化"""
llm_aided_config = get_llm_aided_config()
if llm_aided_config is not None:
"""标题优化"""
title_aided_config = llm_aided_config.get('title_aided', None)
if title_aided_config is not None:
if title_aided_config.get('enable', False):
llm_aided_title_start_time = time.time()
llm_aided_title(middle_json["pdf_info"], title_aided_config)
logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
"""清理内存"""
pdf_doc.close()
clean_memory(get_device())
return middle_json
def make_page_info_dict(blocks, page_id, page_w, page_h, discarded_blocks):
return_dict = {
'preproc_blocks': blocks,
'page_idx': page_id,
'page_size': [page_w, page_h],
'discarded_blocks': discarded_blocks,
}
return return_dict
\ No newline at end of file
class AtomicModel:
Layout = "layout"
MFD = "mfd"
MFR = "mfr"
OCR = "ocr"
Table = "table"
import copy
from loguru import logger
from mineru.utils.enum_class import ContentType, BlockType, SplitFlag
from mineru.utils.language import detect_lang
LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';')
LIST_END_FLAG = ('.', '。', ';', ';')
class ListLineTag:
IS_LIST_START_LINE = 'is_list_start_line'
IS_LIST_END_LINE = 'is_list_end_line'
def __process_blocks(blocks):
# 对所有block预处理
# 1.通过title和interline_equation将block分组
# 2.bbox边界根据line信息重置
result = []
current_group = []
for i in range(len(blocks)):
current_block = blocks[i]
# 如果当前块是 text 类型
if current_block['type'] == 'text':
current_block['bbox_fs'] = copy.deepcopy(current_block['bbox'])
if 'lines' in current_block and len(current_block['lines']) > 0:
current_block['bbox_fs'] = [
min([line['bbox'][0] for line in current_block['lines']]),
min([line['bbox'][1] for line in current_block['lines']]),
max([line['bbox'][2] for line in current_block['lines']]),
max([line['bbox'][3] for line in current_block['lines']]),
]
current_group.append(current_block)
# 检查下一个块是否存在
if i + 1 < len(blocks):
next_block = blocks[i + 1]
# 如果下一个块不是 text 类型且是 title 或 interline_equation 类型
if next_block['type'] in ['title', 'interline_equation']:
result.append(current_group)
current_group = []
# 处理最后一个 group
if current_group:
result.append(current_group)
return result
def __is_list_or_index_block(block):
# 一个block如果是list block 应该同时满足以下特征
# 1.block内有多个line 2.block 内有多个line左侧顶格写 3.block内有多个line 右侧不顶格(狗牙状)
# 1.block内有多个line 2.block 内有多个line左侧顶格写 3.多个line以endflag结尾
# 1.block内有多个line 2.block 内有多个line左侧顶格写 3.block内有多个line 左侧不顶格
# index block 是一种特殊的list block
# 一个block如果是index block 应该同时满足以下特征
# 1.block内有多个line 2.block 内有多个line两侧均顶格写 3.line的开头或者结尾均为数字
if len(block['lines']) >= 2:
first_line = block['lines'][0]
line_height = first_line['bbox'][3] - first_line['bbox'][1]
block_weight = block['bbox_fs'][2] - block['bbox_fs'][0]
block_height = block['bbox_fs'][3] - block['bbox_fs'][1]
page_weight, page_height = block['page_size']
left_close_num = 0
left_not_close_num = 0
right_not_close_num = 0
right_close_num = 0
lines_text_list = []
center_close_num = 0
external_sides_not_close_num = 0
multiple_para_flag = False
last_line = block['lines'][-1]
if page_weight == 0:
block_weight_radio = 0
else:
block_weight_radio = block_weight / page_weight
# logger.info(f"block_weight_radio: {block_weight_radio}")
# 如果首行左边不顶格而右边顶格,末行左边顶格而右边不顶格 (第一行可能可以右边不顶格)
if (
first_line['bbox'][0] - block['bbox_fs'][0] > line_height / 2
and abs(last_line['bbox'][0] - block['bbox_fs'][0]) < line_height / 2
and block['bbox_fs'][2] - last_line['bbox'][2] > line_height
):
multiple_para_flag = True
block_text = ''
for line in block['lines']:
line_text = ''
for span in line['spans']:
span_type = span['type']
if span_type == ContentType.TEXT:
line_text += span['content'].strip()
# 添加所有文本,包括空行,保持与block['lines']长度一致
lines_text_list.append(line_text)
block_text = ''.join(lines_text_list)
block_lang = detect_lang(block_text)
# logger.info(f"block_lang: {block_lang}")
for line in block['lines']:
line_mid_x = (line['bbox'][0] + line['bbox'][2]) / 2
block_mid_x = (block['bbox_fs'][0] + block['bbox_fs'][2]) / 2
if (
line['bbox'][0] - block['bbox_fs'][0] > 0.7 * line_height
and block['bbox_fs'][2] - line['bbox'][2] > 0.7 * line_height
):
external_sides_not_close_num += 1
if abs(line_mid_x - block_mid_x) < line_height / 2:
center_close_num += 1
# 计算line左侧顶格数量是否大于2,是否顶格用abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height/2 来判断
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
left_close_num += 1
elif line['bbox'][0] - block['bbox_fs'][0] > line_height:
left_not_close_num += 1
# 计算右侧是否顶格
if abs(block['bbox_fs'][2] - line['bbox'][2]) < line_height:
right_close_num += 1
else:
# 类中文没有超长单词的情况,可以用统一的阈值
if block_lang in ['zh', 'ja', 'ko']:
closed_area = 0.26 * block_weight
else:
# 右侧不顶格情况下是否有一段距离,拍脑袋用0.3block宽度做阈值
# block宽的阈值可以小些,block窄的阈值要大
if block_weight_radio >= 0.5:
closed_area = 0.26 * block_weight
else:
closed_area = 0.36 * block_weight
if block['bbox_fs'][2] - line['bbox'][2] > closed_area:
right_not_close_num += 1
# 判断lines_text_list中的元素是否有超过80%都以LIST_END_FLAG结尾
line_end_flag = False
# 判断lines_text_list中的元素是否有超过80%都以数字开头或都以数字结尾
line_num_flag = False
num_start_count = 0
num_end_count = 0
flag_end_count = 0
if len(lines_text_list) > 0:
for line_text in lines_text_list:
if len(line_text) > 0:
if line_text[-1] in LIST_END_FLAG:
flag_end_count += 1
if line_text[0].isdigit():
num_start_count += 1
if line_text[-1].isdigit():
num_end_count += 1
if (
num_start_count / len(lines_text_list) >= 0.8
or num_end_count / len(lines_text_list) >= 0.8
):
line_num_flag = True
if flag_end_count / len(lines_text_list) >= 0.8:
line_end_flag = True
# 有的目录右侧不贴边, 目前认为左边或者右边有一边全贴边,且符合数字规则极为index
if (
left_close_num / len(block['lines']) >= 0.8
or right_close_num / len(block['lines']) >= 0.8
) and line_num_flag:
for line in block['lines']:
line[ListLineTag.IS_LIST_START_LINE] = True
return BlockType.INDEX
# 全部line都居中的特殊list识别,每行都需要换行,特征是多行,且大多数行都前后not_close,每line中点x坐标接近
# 补充条件block的长宽比有要求
elif (
external_sides_not_close_num >= 2
and center_close_num == len(block['lines'])
and external_sides_not_close_num / len(block['lines']) >= 0.5
and block_height / block_weight > 0.4
):
for line in block['lines']:
line[ListLineTag.IS_LIST_START_LINE] = True
return BlockType.LIST
elif (
left_close_num >= 2
and (right_not_close_num >= 2 or line_end_flag or left_not_close_num >= 2)
and not multiple_para_flag
# and block_weight_radio > 0.27
):
# 处理一种特殊的没有缩进的list,所有行都贴左边,通过右边的空隙判断是否是item尾
if left_close_num / len(block['lines']) > 0.8:
# 这种是每个item只有一行,且左边都贴边的短item list
if flag_end_count == 0 and right_close_num / len(block['lines']) < 0.5:
for line in block['lines']:
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
line[ListLineTag.IS_LIST_START_LINE] = True
# 这种是大部分line item 都有结束标识符的情况,按结束标识符区分不同item
elif line_end_flag:
for i, line in enumerate(block['lines']):
if (
len(lines_text_list[i]) > 0
and lines_text_list[i][-1] in LIST_END_FLAG
):
line[ListLineTag.IS_LIST_END_LINE] = True
if i + 1 < len(block['lines']):
block['lines'][i + 1][
ListLineTag.IS_LIST_START_LINE
] = True
# line item基本没有结束标识符,而且也没有缩进,按右侧空隙判断哪些是item end
else:
line_start_flag = False
for i, line in enumerate(block['lines']):
if line_start_flag:
line[ListLineTag.IS_LIST_START_LINE] = True
line_start_flag = False
if (
abs(block['bbox_fs'][2] - line['bbox'][2])
> 0.1 * block_weight
):
line[ListLineTag.IS_LIST_END_LINE] = True
line_start_flag = True
# 一种有缩进的特殊有序list,start line 左侧不贴边且以数字开头,end line 以 IS_LIST_END_FLAG 结尾且数量和start line 一致
elif num_start_count >= 2 and num_start_count == flag_end_count:
for i, line in enumerate(block['lines']):
if len(lines_text_list[i]) > 0:
if lines_text_list[i][0].isdigit():
line[ListLineTag.IS_LIST_START_LINE] = True
if lines_text_list[i][-1] in LIST_END_FLAG:
line[ListLineTag.IS_LIST_END_LINE] = True
else:
# 正常有缩进的list处理
for line in block['lines']:
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
line[ListLineTag.IS_LIST_START_LINE] = True
if abs(block['bbox_fs'][2] - line['bbox'][2]) > line_height:
line[ListLineTag.IS_LIST_END_LINE] = True
return BlockType.LIST
else:
return BlockType.TEXT
else:
return BlockType.TEXT
def __merge_2_text_blocks(block1, block2):
if len(block1['lines']) > 0:
first_line = block1['lines'][0]
line_height = first_line['bbox'][3] - first_line['bbox'][1]
block1_weight = block1['bbox'][2] - block1['bbox'][0]
block2_weight = block2['bbox'][2] - block2['bbox'][0]
min_block_weight = min(block1_weight, block2_weight)
if abs(block1['bbox_fs'][0] - first_line['bbox'][0]) < line_height / 2:
last_line = block2['lines'][-1]
if len(last_line['spans']) > 0:
last_span = last_line['spans'][-1]
line_height = last_line['bbox'][3] - last_line['bbox'][1]
if len(first_line['spans']) > 0:
first_span = first_line['spans'][0]
if len(first_span['content']) > 0:
span_start_with_num = first_span['content'][0].isdigit()
span_start_with_big_char = first_span['content'][0].isupper()
if (
# 上一个block的最后一个line的右边界和block的右边界差距不超过line_height
abs(block2['bbox_fs'][2] - last_line['bbox'][2]) < line_height
# 上一个block的最后一个span不是以特定符号结尾
and not last_span['content'].endswith(LINE_STOP_FLAG)
# 两个block宽度差距超过2倍也不合并
and abs(block1_weight - block2_weight) < min_block_weight
# 下一个block的第一个字符是数字
and not span_start_with_num
# 下一个block的第一个字符是大写字母
and not span_start_with_big_char
):
if block1['page_num'] != block2['page_num']:
for line in block1['lines']:
for span in line['spans']:
span[SplitFlag.CROSS_PAGE] = True
block2['lines'].extend(block1['lines'])
block1['lines'] = []
block1[SplitFlag.LINES_DELETED] = True
return block1, block2
def __merge_2_list_blocks(block1, block2):
if block1['page_num'] != block2['page_num']:
for line in block1['lines']:
for span in line['spans']:
span[SplitFlag.CROSS_PAGE] = True
block2['lines'].extend(block1['lines'])
block1['lines'] = []
block1[SplitFlag.LINES_DELETED] = True
return block1, block2
def __is_list_group(text_blocks_group):
# list group的特征是一个group内的所有block都满足以下条件
# 1.每个block都不超过3行 2. 每个block 的左边界都比较接近(逻辑简单点先不加这个规则)
for block in text_blocks_group:
if len(block['lines']) > 3:
return False
return True
def __para_merge_page(blocks):
page_text_blocks_groups = __process_blocks(blocks)
for text_blocks_group in page_text_blocks_groups:
if len(text_blocks_group) > 0:
# 需要先在合并前对所有block判断是否为list or index block
for block in text_blocks_group:
block_type = __is_list_or_index_block(block)
block['type'] = block_type
# logger.info(f"{block['type']}:{block}")
if len(text_blocks_group) > 1:
# 在合并前判断这个group 是否是一个 list group
is_list_group = __is_list_group(text_blocks_group)
# 倒序遍历
for i in range(len(text_blocks_group) - 1, -1, -1):
current_block = text_blocks_group[i]
# 检查是否有前一个块
if i - 1 >= 0:
prev_block = text_blocks_group[i - 1]
if (
current_block['type'] == 'text'
and prev_block['type'] == 'text'
and not is_list_group
):
__merge_2_text_blocks(current_block, prev_block)
elif (
current_block['type'] == BlockType.LIST
and prev_block['type'] == BlockType.LIST
) or (
current_block['type'] == BlockType.INDEX
and prev_block['type'] == BlockType.INDEX
):
__merge_2_list_blocks(current_block, prev_block)
else:
continue
def para_split(page_info_list):
all_blocks = []
for page_info in page_info_list:
blocks = copy.deepcopy(page_info['preproc_blocks'])
for block in blocks:
block['page_num'] = page_info['page_idx']
block['page_size'] = page_info['page_size']
all_blocks.extend(blocks)
__para_merge_page(all_blocks)
for page_info in page_info_list:
page_info['para_blocks'] = []
for block in all_blocks:
if block['page_num'] == page_info['page_idx']:
page_info['para_blocks'].append(block)
if __name__ == '__main__':
input_blocks = []
# 调用函数
groups = __process_blocks(input_blocks)
for group_index, group in enumerate(groups):
print(f'Group {group_index}: {group}')
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
# Copyright (c) Opendatalab. All rights reserved.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import re
from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.enum_class import BlockType, ContentType
from mineru.utils.hash_utils import str_md5
from mineru.backend.vlm.vlm_magic_model import MagicModel
from mineru.version import __version__
def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dict:
"""将token转换为页面信息"""
# 解析token,提取坐标和类型
# 假设token格式为:<|box_start|>x0 y0 x1 y1<|box_end|><|ref_start|>type<|ref_end|><|md_start|>content<|md_end|>
# 这里需要根据实际的token格式进行解析
# 提取所有完整块,每个块从<|box_start|>开始到<|md_end|>或<|im_end|>结束
scale = image_dict["scale"]
page_pil_img = image_dict["img_pil"]
page_img_md5 = str_md5(image_dict["img_base64"])
width, height = map(int, page.get_size())
magic_model = MagicModel(token, width, height)
image_blocks = magic_model.get_image_blocks()
table_blocks = magic_model.get_table_blocks()
title_blocks = magic_model.get_title_blocks()
text_blocks = magic_model.get_text_blocks()
interline_equation_blocks = magic_model.get_interline_equation_blocks()
all_spans = magic_model.get_all_spans()
# 对image/table/interline_equation的span截图
for span in all_spans:
if span["type"] in [ContentType.IMAGE, ContentType.TABLE, ContentType.INTERLINE_EQUATION]:
span = cut_image_and_table(span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale)
page_blocks = []
page_blocks.extend([*image_blocks, *table_blocks, *title_blocks, *text_blocks, *interline_equation_blocks])
# 对page_blocks根据index的值进行排序
page_blocks.sort(key=lambda x: x["index"])
page_info = {"para_blocks": page_blocks, "discarded_blocks": [], "page_size": [width, height], "page_idx": page_index}
return page_info
def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
middle_json = {"pdf_info": [], "_backend":"vlm", "_version_name": __version__}
for index, token in enumerate(token_list):
page = pdf_doc[index]
image_dict = images_list[index]
page_info = token_to_page_info(token, image_dict, page, image_writer, index)
middle_json["pdf_info"].append(page_info)
# 关闭pdf文档
pdf_doc.close()
return middle_json
if __name__ == "__main__":
output = r"<|box_start|>088 119 472 571<|box_end|><|ref_start|>image<|ref_end|><|md_start|>![]('img_url')<|md_end|>\n<|box_start|>079 582 482 608<|box_end|><|ref_start|>image_caption<|ref_end|><|md_start|>Fig. 2. (a) Schematic of the change in the FDC over time, and (b) definition of model parameters.<|md_end|>\n<|box_start|>079 624 285 638<|box_end|><|ref_start|>title<|ref_end|><|md_start|># 2.2. Zero flow day analysis<|md_end|>\n<|box_start|>079 656 482 801<|box_end|><|ref_start|>text<|ref_end|><|md_start|>A notable feature of Fig. 1 is the increase in the number of zero flow days. A similar approach to Eq. (2), using an inverse sigmoidal function was employed to assess the impact of afforestation on the number of zero flow days per year \((N_{\mathrm{zero}})\). In this case, the left hand side of Eq. (2) is replaced by \(N_{\mathrm{zero}}\) and \(b\) and \(S\) are constrained to negative as \(N_{\mathrm{zero}}\) decreases as rainfall increases, and increases with plantation growth:<|md_end|>\n<|box_start|>076 813 368 853<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nN_{\mathrm{zero}}=a+b(\Delta P)+\frac{Y}{1+\exp\left(\frac{T-T_{\mathrm{half}}}{S}\right)}\n\]<|md_end|>\n<|box_start|>079 865 482 895<|box_end|><|ref_start|>text<|ref_end|><|md_start|>For the average pre-treatment condition \(\Delta P=0\) and \(T=0\), \(N_{\mathrm{zero}}\) approximately equals \(a\). \(Y\) gives<|md_end|>\n<|box_start|>525 119 926 215<|box_end|><|ref_start|>text<|ref_end|><|md_start|>the magnitude of change in zero flow days due to afforestation, and \(S\) describes the shape of the response. For the average climate condition \(\Delta P=0\), \(a+Y\) becomes the number of zero flow days when the new equilibrium condition under afforestation is reached.<|md_end|>\n<|box_start|>525 240 704 253<|box_end|><|ref_start|>title<|ref_end|><|md_start|># 2.3. Statistical analyses<|md_end|>\n<|box_start|>525 271 926 368<|box_end|><|ref_start|>text<|ref_end|><|md_start|>The coefficient of efficiency \((E)\) (Nash and Sutcliffe, 1970; Chiew and McMahon, 1993; Legates and McCabe, 1999) was used as the 'goodness of fit' measure to evaluate the fit between observed and predicted flow deciles (2) and zero flow days (3). \(E\) is given by:<|md_end|>\n<|box_start|>520 375 735 415<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nE=1.0-\frac{\sum_{i=1}^{N}(O_{i}-P_{i})^{2}}{\sum_{i=1}^{N}(O_{i}-\bar{O})^{2}}\n\]<|md_end|>\n<|box_start|>525 424 926 601<|box_end|><|ref_start|>text<|ref_end|><|md_start|>where \(O\) are observed data, \(P\) are predicted values, and \(\bar{O}\) is the mean for the entire period. \(E\) is unity minus the ratio of the mean square error to the variance in the observed data, and ranges from \(-\infty\) to 1.0. Higher values indicate greater agreement between observed and predicted data as per the coefficient of determination \((r^{2})\). \(E\) is used in preference to \(r^{2}\) in evaluating hydrologic modelling because it is a measure of the deviation from the 1:1 line. As \(E\) is always \(<r^{2}\) we have arbitrarily considered \(E>0.7\) to indicate adequate model fits.<|md_end|>\n<|box_start|>525 603 926 731<|box_end|><|ref_start|>text<|ref_end|><|md_start|>It is important to assess the significance of the model parameters to check the model assumptions that rainfall and forest age are driving changes in the FDC. The model (2) was split into simplified forms, where only the rainfall or time terms were included by setting \(b=0\), as shown in Eq. (5), or \(Y=0\) as shown in Eq. (6). The component models (5) and (6) were then tested against the complete model, (2).<|md_end|>\n<|box_start|>520 739 735 778<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nQ_{\%}=a+\frac{Y}{1+\exp\left(\frac{T-T_{\mathrm{half}}^{\prime}}{S}\right)}\n\]<|md_end|>\n<|box_start|>525 787 553 799<|box_end|><|ref_start|>text<|ref_end|><|md_start|>and<|md_end|>\n<|box_start|>520 807 646 825<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nQ_{\%}=a+b\Delta P\n\]<|md_end|>\n<|box_start|>525 833 926 895<|box_end|><|ref_start|>text<|ref_end|><|md_start|>For both the flow duration curve analysis and zero flow days analysis, a \(t\)-test was then performed to test whether (5) and (6) were significantly different to (2). A critical value of \(t\) exceeding the calculated \(t\)-value<|md_end|><|im_end|>"
p_info = token_to_page_info(output)
# 将blocks 转换为json文本
import json
json_str = json.dumps(p_info, ensure_ascii=False, indent=4)
print(json_str)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
# Copyright (c) Opendatalab. All rights reserved.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment