Commit ea5cb65a authored by myhloli

refactor: enhance document parsing by supporting multiple PDF files and improving method organization
parent 0a899f1a
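The refactor moves both entry points from a single document to lists of PDFs. A minimal usage sketch of the new multi-file flow (not part of the diff; file paths and language hints are hypothetical, the signatures are the ones introduced below):

    from pathlib import Path

    pdf_paths = [Path("a.pdf"), Path("b.pdf")]          # hypothetical inputs
    pdf_bytes_list = [read_fn(p) for p in pdf_paths]    # read_fn is defined in this module
    pdf_file_names = [p.stem for p in pdf_paths]
    p_lang_list = ["en", "ch"]                          # one language hint per PDF

    do_parse(
        output_dir="output",
        pdf_file_names=pdf_file_names,
        pdf_bytes_list=pdf_bytes_list,
        p_lang_list=p_lang_list,
        backend="pipeline",
        parse_method="auto",
    )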
# Copyright (c) Opendatalab. All rights reserved.
def result_to_middle_json(model_json, images_list, pdf_doc, image_writer):
pass
\ No newline at end of file
@@ -2,12 +2,14 @@ import os
import time
import numpy as np
import torch
from pypdfium2 import PdfDocument
from mineru.backend.pipeline.model_init import MineruPipelineModel
from .model_json_to_middle_json import result_to_middle_json
from ...utils.pdf_classify import classify
from ...utils.pdf_image_tools import pdf_page_to_image

os.environ['FLAGS_npu_jit_compile'] = '0'  # disable paddle JIT compilation
os.environ['FLAGS_use_stride_kernel'] = '0'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS to fall back when an op is unsupported
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable the albumentations update check

from loguru import logger
@@ -18,6 +20,11 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                          get_local_models_dir,
                                          get_table_recog_config)


class ModelSingleton:
    _instance = None
    _models = {}
@@ -76,117 +83,92 @@ def custom_model_init(
    return custom_model
def doc_analyze(
    pdf_bytes_list,
    lang_list,
    parse_method: str = 'auto',
    formula_enable=None,
    table_enable=None,
):
    """
    Analyze one or more PDF documents with the pipeline backend.

    Args:
        pdf_bytes_list: list of PDF files, each given as bytes
        lang_list: list of language hints, one per PDF
        parse_method: parsing method, 'auto'/'ocr'/'txt'
        formula_enable: whether to enable formula recognition
        table_enable: whether to enable table recognition

    Returns:
        (middle_json_list, infer_results), each with one entry per input PDF
    """
    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))

    # Collect page information across all PDFs
    all_pages_info = []  # stores (pdf_index, page_index, img_pil, ocr, lang, scale)

    for pdf_idx, pdf_bytes in enumerate(pdf_bytes_list):
        # Decide whether this PDF needs OCR
        _ocr = False
        if parse_method == 'auto':
            if classify(pdf_bytes) == 'ocr':
                _ocr = True
        elif parse_method == 'ocr':
            _ocr = True
        _lang = lang_list[pdf_idx]

        # Render every page of this PDF to an image
        pdf_doc = PdfDocument(pdf_bytes)
        for page_idx in range(len(pdf_doc)):
            page_data = pdf_doc[page_idx]
            img_dict = pdf_page_to_image(page_data)
            all_pages_info.append((
                pdf_idx, page_idx,
                img_dict['img_pil'], _ocr, _lang,
                img_dict['scale']
            ))

    # Prepare the batches
    images_with_extra_info = [(info[2], info[3], info[4]) for info in all_pages_info]
    batch_size = MIN_BATCH_INFERENCE_SIZE
    batch_images = [
        images_with_extra_info[i:i + batch_size]
        for i in range(0, len(images_with_extra_info), batch_size)
    ]

    # Run batched inference
    results = []
    processed_images_count = 0
    for index, batch_image in enumerate(batch_images):
        processed_images_count += len(batch_image)
        logger.info(
            f'Batch {index + 1}/{len(batch_images)}: '
            f'{processed_images_count} pages/{len(images_with_extra_info)} pages'
        )
        batch_results = may_batch_image_analyze(batch_image, formula_enable, table_enable)
        results.extend(batch_results)

    # Build the return values
    # Group the per-page results back by their source PDF
    infer_results = [[] for _ in pdf_bytes_list]
    for i, page_info in enumerate(all_pages_info):
        pdf_idx, page_idx, pil_img, _, _, _ = page_info
        result = results[i]

        page_info_dict = {'page_no': page_idx, 'width': pil_img.width, 'height': pil_img.height}
        page_dict = {'layout_dets': result, 'page_info': page_info_dict}
        infer_results[pdf_idx].append(page_dict)

    middle_json_list = []
    for model_json in infer_results:
        middle_json = result_to_middle_json(model_json)
        middle_json_list.append(middle_json)

    return middle_json_list, infer_results
def may_batch_image_analyze(
...
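In the reworked doc_analyze above, pages from all input PDFs are flattened into one list and split into chunks of MINERU_MIN_BATCH_INFERENCE_SIZE, so pages from different PDFs can share a batch. A standalone illustration of that slicing (synthetic page counts, not tied to the code above):

    # Two PDFs with 130 and 45 pages, batch size 100 (the default).
    pages = [("pdf0", i) for i in range(130)] + [("pdf1", i) for i in range(45)]
    batch_size = 100
    batches = [pages[i:i + batch_size] for i in range(0, len(pages), batch_size)]
    assert [len(b) for b in batches] == [100, 75]  # the second batch mixes both PDFs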
@@ -7,7 +7,8 @@ from pathlib import Path
import pypdfium2 as pdfium
from loguru import logger
from ..api.vlm_middle_json_mkcontent import union_make
from ..backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
from ..backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze
from ..data.data_reader_writer import FileBasedDataWriter
from ..utils.draw_bbox import draw_layout_bbox, draw_span_bbox
from ..utils.enum_class import MakeMode
@@ -28,8 +29,8 @@ def read_fn(path: Path):
        raise Exception(f"Unknown file suffix: {path.suffix}")


def prepare_env(output_dir, pdf_file_name, parse_method):
    local_parent_dir = os.path.join(output_dir, pdf_file_name, parse_method)

    local_image_dir = os.path.join(str(local_parent_dir), "images")
    local_md_dir = local_parent_dir
@@ -70,13 +71,17 @@ def convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id=0, end_page

def do_parse(
    output_dir,
    pdf_file_names: list[str],
    pdf_bytes_list: list[bytes],
    p_lang_list: list[str],
    backend="pipeline",
    model_path="jinzhenj/OEEzRkQ3RTAtMDMx-0415",  # TODO: change to formal path after release.
    parse_method="auto",
    p_formula_enable=True,
    p_table_enable=True,
    server_url=None,
    f_draw_layout_bbox=True,
    f_draw_span_bbox=True,
    f_dump_md=True,
    f_dump_middle_json=True,
    f_dump_model_output=True,
@@ -86,58 +91,114 @@ def do_parse(
    start_page_id=0,
    end_page_id=None,
):

    if backend == "pipeline":
        # Trim every PDF to the requested page range before analysis
        for idx, pdf_bytes in enumerate(pdf_bytes_list):
            pdf_bytes_list[idx] = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
        middle_json_list, infer_results = pipeline_doc_analyze(
            pdf_bytes_list, p_lang_list, parse_method=parse_method,
            formula_enable=p_formula_enable, table_enable=p_table_enable,
        )
        for idx, middle_json in enumerate(middle_json_list):
            pdf_file_name = pdf_file_names[idx]
            pdf_bytes = pdf_bytes_list[idx]
            model_json = infer_results[idx]
            local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
            image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)

            pdf_info = middle_json["pdf_info"]

            if f_draw_layout_bbox:
                draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")

            if f_draw_span_bbox:
                draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")

            if f_dump_orig_pdf:
                md_writer.write(
                    f"{pdf_file_name}_origin.pdf",
                    pdf_bytes,
                )

            if f_dump_md:
                image_dir = str(os.path.basename(local_image_dir))
                md_content_str = union_make(pdf_info, f_make_md_mode, image_dir)
                md_writer.write_string(
                    f"{pdf_file_name}.md",
                    md_content_str,
                )

            if f_dump_content_list:
                image_dir = str(os.path.basename(local_image_dir))
                content_list = union_make(pdf_info, MakeMode.STANDARD_FORMAT, image_dir)
                md_writer.write_string(
                    f"{pdf_file_name}_content_list.json",
                    json.dumps(content_list, ensure_ascii=False, indent=4),
                )

            if f_dump_middle_json:
                md_writer.write_string(
                    f"{pdf_file_name}_middle.json",
                    json.dumps(middle_json, ensure_ascii=False, indent=4),
                )

            if f_dump_model_output:
                md_writer.write_string(
                    f"{pdf_file_name}_model.json",
                    json.dumps(model_json, ensure_ascii=False, indent=4),
                )

            logger.info(f"local output dir is {local_md_dir}")
    else:
        f_draw_span_bbox = False
        parse_method = "vlm"
        for idx, pdf_bytes in enumerate(pdf_bytes_list):
            pdf_file_name = pdf_file_names[idx]
            pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
            local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
            image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
            middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, model_path=model_path, server_url=server_url)

            pdf_info = middle_json["pdf_info"]

            if f_draw_layout_bbox:
                draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")

            if f_draw_span_bbox:
                draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")

            if f_dump_orig_pdf:
                md_writer.write(
                    f"{pdf_file_name}_origin.pdf",
                    pdf_bytes,
                )

            if f_dump_md:
                image_dir = str(os.path.basename(local_image_dir))
                md_content_str = union_make(pdf_info, f_make_md_mode, image_dir)
                md_writer.write_string(
                    f"{pdf_file_name}.md",
                    md_content_str,
                )

            if f_dump_content_list:
                image_dir = str(os.path.basename(local_image_dir))
                content_list = union_make(pdf_info, MakeMode.STANDARD_FORMAT, image_dir)
                md_writer.write_string(
                    f"{pdf_file_name}_content_list.json",
                    json.dumps(content_list, ensure_ascii=False, indent=4),
                )

            if f_dump_middle_json:
                md_writer.write_string(
                    f"{pdf_file_name}_middle.json",
                    json.dumps(middle_json, ensure_ascii=False, indent=4),
                )

            if f_dump_model_output:
                model_output = ("\n" + "-" * 50 + "\n").join(infer_result)
                md_writer.write_string(
                    f"{pdf_file_name}_model_output.txt",
                    model_output,
                )

            logger.info(f"local output dir is {local_md_dir}")

    return infer_result
...
# Copyright (c) Opendatalab. All rights reserved.
import re
from io import BytesIO

import numpy as np
import pypdfium2 as pdfium
from loguru import logger
from pdfminer.high_level import extract_text
from pdfminer.layout import LAParams


def classify(pdf_bytes):
    """
    Decide whether a PDF's text can be extracted directly or needs OCR.

    Args:
        pdf_bytes: the PDF file as bytes

    Returns:
        str: 'txt' if text can be extracted directly, 'ocr' if OCR is required
    """
    try:
        # Sample up to 10 pages and load them from bytes
        sample_pdf_bytes = extract_pages(pdf_bytes)
        pdf = pdfium.PdfDocument(sample_pdf_bytes)

        # Page count of the sampled PDF
        page_count = len(pdf)

        # An empty PDF goes straight to OCR
        if page_count == 0:
            return 'ocr'

        # Total number of extracted characters
        total_chars = 0
        # Total number of characters after stripping whitespace
        cleaned_total_chars = 0
        # Number of pages to inspect (at most 10)
        pages_to_check = min(page_count, 10)

        # Extract text from the first few pages
        for i in range(pages_to_check):
            page = pdf[i]
            text_page = page.get_textpage()
            text = text_page.get_text_bounded()
            total_chars += len(text)

            # Strip all whitespace from the extracted text
            cleaned_text = re.sub(r'\s+', '', text)
            cleaned_total_chars += len(cleaned_text)

        # Average characters per page
        # avg_chars_per_page = total_chars / pages_to_check
        avg_cleaned_chars_per_page = cleaned_total_chars / pages_to_check

        # Threshold: fewer than 50 meaningful characters per page means OCR is needed
        chars_threshold = 50
        # logger.debug(f"PDF analysis: {avg_chars_per_page:.1f} chars/page, {avg_cleaned_chars_per_page:.1f} after cleaning")
        if (avg_cleaned_chars_per_page < chars_threshold) or detect_invalid_chars(sample_pdf_bytes):
            return 'ocr'
        else:
            return 'txt'
    except Exception as e:
        logger.error(f"Error while classifying the PDF: {e}")
        # Fall back to OCR on error
        return 'ocr'
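The rule above sends a document to OCR when the sampled pages average fewer than 50 non-whitespace characters, or when pdfminer reports (cid:...) garbage. A self-contained illustration of the character-count side of that rule (synthetic page texts, not real extraction output):

    import re

    # Stand-ins for text_page.get_text_bounded() on a mostly image-only PDF.
    pages_text = ["Fig. 1\n", "", "  \t "]
    cleaned_total = sum(len(re.sub(r'\s+', '', t)) for t in pages_text)
    avg_cleaned_chars_per_page = cleaned_total / len(pages_text)  # 5 / 3 ≈ 1.67
    print('ocr' if avg_cleaned_chars_per_page < 50 else 'txt')    # -> ocr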
def extract_pages(src_pdf_bytes: bytes) -> bytes:
    """
    Randomly sample up to 10 pages from the PDF bytes and return them as a new PDF.

    Args:
        src_pdf_bytes: the source PDF as bytes

    Returns:
        bytes: the sampled PDF as bytes
    """
    # Load the PDF from bytes
    pdf = pdfium.PdfDocument(src_pdf_bytes)

    # Page count
    total_page = len(pdf)
    if total_page == 0:
        # The PDF has no pages; return an empty document
        logger.warning("PDF is empty, return empty document")
        return b''

    # Select at most 10 pages
    select_page_cnt = min(10, total_page)

    # Pick the pages at random
    page_indices = np.random.choice(total_page, select_page_cnt, replace=False).tolist()

    # Create a new PDF document
    sample_docs = pdfium.PdfDocument.new()
    try:
        # Import the selected pages into the new document
        sample_docs.import_pages(pdf, page_indices)

        # Save the new PDF to an in-memory buffer
        output_buffer = BytesIO()
        sample_docs.save(output_buffer)

        # Return the raw bytes
        return output_buffer.getvalue()
    except Exception as e:
        logger.exception(e)
        return b''  # return empty bytes on error
def detect_invalid_chars(sample_pdf_bytes: bytes) -> bool:
    """
    Detect whether the PDF contains invalid (garbled) characters.
    """
    '''pdfminer is slow, so a sample of roughly 10 pages should be extracted first'''
    # sample_pdf_bytes = extract_pages(src_pdf_bytes)
    sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
    laparams = LAParams(
        line_overlap=0.5,
        char_margin=2.0,
        line_margin=0.5,
        word_margin=0.1,
        boxes_flow=None,
        detect_vertical=False,
        all_texts=False,
    )
    text = extract_text(pdf_file=sample_pdf_file_like_object, laparams=laparams)
    text = text.replace("\n", "")
    # logger.info(text)
    '''garbled text extracted by pdfminer shows up as (cid:xxx) tokens'''
    cid_pattern = re.compile(r'\(cid:\d+\)')
    matches = cid_pattern.findall(text)
    cid_count = len(matches)
    cid_len = sum(len(match) for match in matches)
    text_len = len(text)
    if text_len == 0:
        cid_chars_radio = 0
    else:
        cid_chars_radio = cid_count / (cid_count + text_len - cid_len)
    # logger.debug(f"cid_count: {cid_count}, text_len: {text_len}, cid_chars_radio: {cid_chars_radio}")
    '''if more than 5% of the text is garbled, treat the document as garbled'''
    if cid_chars_radio > 0.05:
        return True  # garbled document
    else:
        return False  # normal document
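The ratio above counts each (cid:NNN) token as a single unreadable character against the remaining readable text. A worked example on a synthetic pdfminer-style string (not taken from a real PDF):

    import re

    text = "Intro(cid:12)(cid:345)duction"                  # 29 characters in total
    matches = re.compile(r'\(cid:\d+\)').findall(text)       # ['(cid:12)', '(cid:345)']
    cid_count = len(matches)                                  # 2
    cid_len = sum(len(m) for m in matches)                    # 8 + 9 = 17
    ratio = cid_count / (cid_count + len(text) - cid_len)     # 2 / 14 ≈ 0.14
    print(ratio > 0.05)                                       # True -> treated as garbled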
if __name__ == '__main__':
    with open('/Users/myhloli/pdf/luanma2x10.pdf', 'rb') as f:
        p_bytes = f.read()
        logger.info(f"PDF classification result: {classify(p_bytes)}")
\ No newline at end of file