Commit bd927919 authored by myhloli's avatar myhloli
Browse files

refactor: rename init file and update app.py to enable parsing method

parent f5016508
def dict_to_list(input_dict):
items_list = []
for _, item in input_dict.items():
items_list.append(item)
return items_list
def get_scale_ratio(model_page_info, page):
pix = page.get_pixmap(dpi=72)
pymu_width = int(pix.w)
pymu_height = int(pix.h)
width_from_json = model_page_info['page_info']['width']
height_from_json = model_page_info['page_info']['height']
horizontal_scale_ratio = width_from_json / pymu_width
vertical_scale_ratio = height_from_json / pymu_height
return horizontal_scale_ratio, vertical_scale_ratio
import fitz
from magic_pdf.config.constants import CROSS_PAGE
from magic_pdf.config.ocr_content_type import (BlockType, CategoryId,
ContentType)
from magic_pdf.data.dataset import Dataset
from magic_pdf.model.magic_model import MagicModel
def draw_bbox_without_number(i, bbox_list, page, rgb_config, fill_config):
new_rgb = []
for item in rgb_config:
item = float(item) / 255
new_rgb.append(item)
page_data = bbox_list[i]
for bbox in page_data:
x0, y0, x1, y1 = bbox
rect_coords = fitz.Rect(x0, y0, x1, y1) # Define the rectangle
if fill_config:
page.draw_rect(
rect_coords,
color=None,
fill=new_rgb,
fill_opacity=0.3,
width=0.5,
overlay=True,
) # Draw the rectangle
else:
page.draw_rect(
rect_coords,
color=new_rgb,
fill=None,
fill_opacity=1,
width=0.5,
overlay=True,
) # Draw the rectangle
def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config, draw_bbox=True):
new_rgb = []
for item in rgb_config:
item = float(item) / 255
new_rgb.append(item)
page_data = bbox_list[i]
for j, bbox in enumerate(page_data):
x0, y0, x1, y1 = bbox
rect_coords = fitz.Rect(x0, y0, x1, y1) # Define the rectangle
if draw_bbox:
if fill_config:
page.draw_rect(
rect_coords,
color=None,
fill=new_rgb,
fill_opacity=0.3,
width=0.5,
overlay=True,
) # Draw the rectangle
else:
page.draw_rect(
rect_coords,
color=new_rgb,
fill=None,
fill_opacity=1,
width=0.5,
overlay=True,
) # Draw the rectangle
page.insert_text(
(x1 + 2, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
) # Insert the index in the top left corner of the rectangle
def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
dropped_bbox_list = []
tables_list, tables_body_list = [], []
tables_caption_list, tables_footnote_list = [], []
imgs_list, imgs_body_list, imgs_caption_list = [], [], []
imgs_footnote_list = []
titles_list = []
texts_list = []
interequations_list = []
lists_list = []
indexs_list = []
for page in pdf_info:
page_dropped_list = []
tables, tables_body, tables_caption, tables_footnote = [], [], [], []
imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], []
titles = []
texts = []
interequations = []
lists = []
indices = []
for dropped_bbox in page['discarded_blocks']:
page_dropped_list.append(dropped_bbox['bbox'])
dropped_bbox_list.append(page_dropped_list)
for block in page['para_blocks']:
bbox = block['bbox']
if block['type'] == BlockType.Table:
tables.append(bbox)
for nested_block in block['blocks']:
bbox = nested_block['bbox']
if nested_block['type'] == BlockType.TableBody:
tables_body.append(bbox)
elif nested_block['type'] == BlockType.TableCaption:
tables_caption.append(bbox)
elif nested_block['type'] == BlockType.TableFootnote:
tables_footnote.append(bbox)
elif block['type'] == BlockType.Image:
imgs.append(bbox)
for nested_block in block['blocks']:
bbox = nested_block['bbox']
if nested_block['type'] == BlockType.ImageBody:
imgs_body.append(bbox)
elif nested_block['type'] == BlockType.ImageCaption:
imgs_caption.append(bbox)
elif nested_block['type'] == BlockType.ImageFootnote:
imgs_footnote.append(bbox)
elif block['type'] == BlockType.Title:
titles.append(bbox)
elif block['type'] == BlockType.Text:
texts.append(bbox)
elif block['type'] == BlockType.InterlineEquation:
interequations.append(bbox)
elif block['type'] == BlockType.List:
lists.append(bbox)
elif block['type'] == BlockType.Index:
indices.append(bbox)
tables_list.append(tables)
tables_body_list.append(tables_body)
tables_caption_list.append(tables_caption)
tables_footnote_list.append(tables_footnote)
imgs_list.append(imgs)
imgs_body_list.append(imgs_body)
imgs_caption_list.append(imgs_caption)
imgs_footnote_list.append(imgs_footnote)
titles_list.append(titles)
texts_list.append(texts)
interequations_list.append(interequations)
lists_list.append(lists)
indexs_list.append(indices)
layout_bbox_list = []
table_type_order = {
'table_caption': 1,
'table_body': 2,
'table_footnote': 3
}
for page in pdf_info:
page_block_list = []
for block in page['para_blocks']:
if block['type'] in [
BlockType.Text,
BlockType.Title,
BlockType.InterlineEquation,
BlockType.List,
BlockType.Index,
]:
bbox = block['bbox']
page_block_list.append(bbox)
elif block['type'] in [BlockType.Image]:
for sub_block in block['blocks']:
bbox = sub_block['bbox']
page_block_list.append(bbox)
elif block['type'] in [BlockType.Table]:
sorted_blocks = sorted(block['blocks'], key=lambda x: table_type_order[x['type']])
for sub_block in sorted_blocks:
bbox = sub_block['bbox']
page_block_list.append(bbox)
layout_bbox_list.append(page_block_list)
pdf_docs = fitz.open('pdf', pdf_bytes)
for i, page in enumerate(pdf_docs):
draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158], True)
# draw_bbox_without_number(i, tables_list, page, [153, 153, 0], True) # color !
draw_bbox_without_number(i, tables_body_list, page, [204, 204, 0], True)
draw_bbox_without_number(i, tables_caption_list, page, [255, 255, 102], True)
draw_bbox_without_number(i, tables_footnote_list, page, [229, 255, 204], True)
# draw_bbox_without_number(i, imgs_list, page, [51, 102, 0], True)
draw_bbox_without_number(i, imgs_body_list, page, [153, 255, 51], True)
draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255], True)
draw_bbox_without_number(i, imgs_footnote_list, page, [255, 178, 102], True),
draw_bbox_without_number(i, titles_list, page, [102, 102, 255], True)
draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True)
draw_bbox_without_number(i, interequations_list, page, [0, 255, 0], True)
draw_bbox_without_number(i, lists_list, page, [40, 169, 92], True)
draw_bbox_without_number(i, indexs_list, page, [40, 169, 92], True)
draw_bbox_with_number(
i, layout_bbox_list, page, [255, 0, 0], False, draw_bbox=False
)
# Save the PDF
pdf_docs.save(f'{out_path}/{filename}')
def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
text_list = []
inline_equation_list = []
interline_equation_list = []
image_list = []
table_list = []
dropped_list = []
next_page_text_list = []
next_page_inline_equation_list = []
def get_span_info(span):
if span['type'] == ContentType.Text:
if span.get(CROSS_PAGE, False):
next_page_text_list.append(span['bbox'])
else:
page_text_list.append(span['bbox'])
elif span['type'] == ContentType.InlineEquation:
if span.get(CROSS_PAGE, False):
next_page_inline_equation_list.append(span['bbox'])
else:
page_inline_equation_list.append(span['bbox'])
elif span['type'] == ContentType.InterlineEquation:
page_interline_equation_list.append(span['bbox'])
elif span['type'] == ContentType.Image:
page_image_list.append(span['bbox'])
elif span['type'] == ContentType.Table:
page_table_list.append(span['bbox'])
for page in pdf_info:
page_text_list = []
page_inline_equation_list = []
page_interline_equation_list = []
page_image_list = []
page_table_list = []
page_dropped_list = []
# 将跨页的span放到移动到下一页的列表中
if len(next_page_text_list) > 0:
page_text_list.extend(next_page_text_list)
next_page_text_list.clear()
if len(next_page_inline_equation_list) > 0:
page_inline_equation_list.extend(next_page_inline_equation_list)
next_page_inline_equation_list.clear()
# 构造dropped_list
for block in page['discarded_blocks']:
if block['type'] == BlockType.Discarded:
for line in block['lines']:
for span in line['spans']:
page_dropped_list.append(span['bbox'])
dropped_list.append(page_dropped_list)
# 构造其余useful_list
# for block in page['para_blocks']: # span直接用分段合并前的结果就可以
for block in page['preproc_blocks']:
if block['type'] in [
BlockType.Text,
BlockType.Title,
BlockType.InterlineEquation,
BlockType.List,
BlockType.Index,
]:
for line in block['lines']:
for span in line['spans']:
get_span_info(span)
elif block['type'] in [BlockType.Image, BlockType.Table]:
for sub_block in block['blocks']:
for line in sub_block['lines']:
for span in line['spans']:
get_span_info(span)
text_list.append(page_text_list)
inline_equation_list.append(page_inline_equation_list)
interline_equation_list.append(page_interline_equation_list)
image_list.append(page_image_list)
table_list.append(page_table_list)
pdf_docs = fitz.open('pdf', pdf_bytes)
for i, page in enumerate(pdf_docs):
# 获取当前页面的数据
draw_bbox_without_number(i, text_list, page, [255, 0, 0], False)
draw_bbox_without_number(i, inline_equation_list, page, [0, 255, 0], False)
draw_bbox_without_number(i, interline_equation_list, page, [0, 0, 255], False)
draw_bbox_without_number(i, image_list, page, [255, 204, 0], False)
draw_bbox_without_number(i, table_list, page, [204, 0, 255], False)
draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False)
# Save the PDF
pdf_docs.save(f'{out_path}/{filename}')
def draw_model_bbox(model_list, dataset: Dataset, out_path, filename):
dropped_bbox_list = []
tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
titles_list = []
texts_list = []
interequations_list = []
magic_model = MagicModel(model_list, dataset)
for i in range(len(model_list)):
page_dropped_list = []
tables_body, tables_caption, tables_footnote = [], [], []
imgs_body, imgs_caption, imgs_footnote = [], [], []
titles = []
texts = []
interequations = []
page_info = magic_model.get_model_list(i)
layout_dets = page_info['layout_dets']
for layout_det in layout_dets:
bbox = layout_det['bbox']
if layout_det['category_id'] == CategoryId.Text:
texts.append(bbox)
elif layout_det['category_id'] == CategoryId.Title:
titles.append(bbox)
elif layout_det['category_id'] == CategoryId.TableBody:
tables_body.append(bbox)
elif layout_det['category_id'] == CategoryId.TableCaption:
tables_caption.append(bbox)
elif layout_det['category_id'] == CategoryId.TableFootnote:
tables_footnote.append(bbox)
elif layout_det['category_id'] == CategoryId.ImageBody:
imgs_body.append(bbox)
elif layout_det['category_id'] == CategoryId.ImageCaption:
imgs_caption.append(bbox)
elif layout_det['category_id'] == CategoryId.InterlineEquation_YOLO:
interequations.append(bbox)
elif layout_det['category_id'] == CategoryId.Abandon:
page_dropped_list.append(bbox)
elif layout_det['category_id'] == CategoryId.ImageFootnote:
imgs_footnote.append(bbox)
tables_body_list.append(tables_body)
tables_caption_list.append(tables_caption)
tables_footnote_list.append(tables_footnote)
imgs_body_list.append(imgs_body)
imgs_caption_list.append(imgs_caption)
titles_list.append(titles)
texts_list.append(texts)
interequations_list.append(interequations)
dropped_bbox_list.append(page_dropped_list)
imgs_footnote_list.append(imgs_footnote)
for i in range(len(dataset)):
page = dataset.get_page(i)
draw_bbox_with_number(
i, dropped_bbox_list, page, [158, 158, 158], True
) # color !
draw_bbox_with_number(i, tables_body_list, page, [204, 204, 0], True)
draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102], True)
draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204], True)
draw_bbox_with_number(i, imgs_body_list, page, [153, 255, 51], True)
draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255], True)
draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102], True)
draw_bbox_with_number(i, titles_list, page, [102, 102, 255], True)
draw_bbox_with_number(i, texts_list, page, [153, 0, 76], True)
draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)
# Save the PDF
dataset.dump_to_file(f'{out_path}/{filename}')
def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
layout_bbox_list = []
for page in pdf_info:
page_line_list = []
for block in page['preproc_blocks']:
if block['type'] in [BlockType.Text]:
for line in block['lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
elif block['type'] in [BlockType.Title, BlockType.InterlineEquation]:
if 'virtual_lines' in block:
if len(block['virtual_lines']) > 0 and block['virtual_lines'][0].get('index', None) is not None:
for line in block['virtual_lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
else:
for line in block['lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
elif block['type'] in [BlockType.Image, BlockType.Table]:
for sub_block in block['blocks']:
if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
for line in sub_block['virtual_lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
else:
for line in sub_block['lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
elif sub_block['type'] in [BlockType.ImageCaption, BlockType.TableCaption, BlockType.ImageFootnote, BlockType.TableFootnote]:
for line in sub_block['lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
sorted_bboxes = sorted(page_line_list, key=lambda x: x['index'])
layout_bbox_list.append(sorted_bbox['bbox'] for sorted_bbox in sorted_bboxes)
pdf_docs = fitz.open('pdf', pdf_bytes)
for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)
pdf_docs.save(f'{out_path}/{filename}')
def draw_char_bbox(pdf_bytes, out_path, filename):
pdf_docs = fitz.open('pdf', pdf_bytes)
for i, page in enumerate(pdf_docs):
for block in page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']:
for line in block['lines']:
for span in line['spans']:
for char in span['chars']:
char_bbox = char['bbox']
page.draw_rect(char_bbox, color=[1, 0, 0], fill=None, fill_opacity=1, width=0.3, overlay=True,)
pdf_docs.save(f'{out_path}/{filename}')
import hashlib
def compute_md5(file_bytes):
hasher = hashlib.md5()
hasher.update(file_bytes)
return hasher.hexdigest().upper()
def compute_sha256(input_string):
hasher = hashlib.sha256()
# 在Python3中,需要将字符串转化为字节对象才能被哈希函数处理
input_bytes = input_string.encode('utf-8')
hasher.update(input_bytes)
return hasher.hexdigest()
import json
import brotli
import base64
class JsonCompressor:
@staticmethod
def compress_json(data):
"""
Compress a json object and encode it with base64
"""
json_str = json.dumps(data)
json_bytes = json_str.encode('utf-8')
compressed = brotli.compress(json_bytes, quality=6)
compressed_str = base64.b64encode(compressed).decode('utf-8') # convert bytes to string
return compressed_str
@staticmethod
def decompress_json(compressed_str):
"""
Decode the base64 string and decompress the json object
"""
compressed = base64.b64decode(compressed_str.encode('utf-8')) # convert string to bytes
decompressed_bytes = brotli.decompress(compressed)
json_str = decompressed_bytes.decode('utf-8')
data = json.loads(json_str)
return data
import os
import unicodedata
if not os.getenv("FTLANG_CACHE"):
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
root_dir = os.path.dirname(current_dir)
ftlang_cache_dir = os.path.join(root_dir, 'resources', 'fasttext-langdetect')
os.environ["FTLANG_CACHE"] = str(ftlang_cache_dir)
# print(os.getenv("FTLANG_CACHE"))
from fast_langdetect import detect_language
def remove_invalid_surrogates(text):
# 移除无效的 UTF-16 代理对
return ''.join(c for c in text if not (0xD800 <= ord(c) <= 0xDFFF))
def detect_lang(text: str) -> str:
if len(text) == 0:
return ""
text = text.replace("\n", "")
text = remove_invalid_surrogates(text)
# print(text)
try:
lang_upper = detect_language(text)
except:
html_no_ctrl_chars = ''.join([l for l in text if unicodedata.category(l)[0] not in ['C', ]])
lang_upper = detect_language(html_no_ctrl_chars)
try:
lang = lang_upper.lower()
except:
lang = ""
return lang
if __name__ == '__main__':
print(os.getenv("FTLANG_CACHE"))
print(detect_lang("This is a test."))
print(detect_lang("<html>This is a test</html>"))
print(detect_lang("这个是中文测试。"))
print(detect_lang("<html>这个是中文测试。</html>"))
print(detect_lang("〖\ud835\udc46\ud835〗这是个包含utf-16的中文测试"))
\ No newline at end of file
def float_gt(a, b):
if 0.0001 >= abs(a -b):
return False
return a > b
def float_equal(a, b):
if 0.0001 >= abs(a-b):
return True
return False
\ No newline at end of file
def ocr_escape_special_markdown_char(content):
"""
转义正文里对markdown语法有特殊意义的字符
"""
special_chars = ["*", "`", "~", "$"]
for char in special_chars:
content = content.replace(char, "\\" + char)
return content
def remove_non_official_s3_args(s3path):
"""
example: s3://abc/xxxx.json?bytes=0,81350 ==> s3://abc/xxxx.json
"""
arr = s3path.split("?")
return arr[0]
def parse_s3path(s3path: str):
# from s3pathlib import S3Path
# p = S3Path(remove_non_official_s3_args(s3path))
# return p.bucket, p.key
s3path = remove_non_official_s3_args(s3path).strip()
if s3path.startswith(('s3://', 's3a://')):
prefix, path = s3path.split('://', 1)
bucket_name, key = path.split('/', 1)
return bucket_name, key
elif s3path.startswith('/'):
raise ValueError("The provided path starts with '/'. This does not conform to a valid S3 path format.")
else:
raise ValueError("Invalid S3 path format. Expected 's3://bucket-name/key' or 's3a://bucket-name/key'.")
def parse_s3_range_params(s3path: str):
"""
example: s3://abc/xxxx.json?bytes=0,81350 ==> [0, 81350]
"""
arr = s3path.split("?bytes=")
if len(arr) == 1:
return None
return arr[1].split(",")
import fitz
import numpy as np
from loguru import logger
import re
from io import BytesIO
from pdfminer.high_level import extract_text
from pdfminer.layout import LAParams
def calculate_sample_count(total_page: int):
"""
根据总页数和采样率计算采样页面的数量。
"""
select_page_cnt = min(10, total_page)
return select_page_cnt
def extract_pages(src_pdf_bytes: bytes) -> fitz.Document:
pdf_docs = fitz.open("pdf", src_pdf_bytes)
total_page = len(pdf_docs)
if total_page == 0:
# 如果PDF没有页面,直接返回空文档
logger.warning("PDF is empty, return empty document")
return fitz.Document()
select_page_cnt = calculate_sample_count(total_page)
page_num = np.random.choice(total_page, select_page_cnt, replace=False)
sample_docs = fitz.Document()
try:
for index in page_num:
sample_docs.insert_pdf(pdf_docs, from_page=int(index), to_page=int(index))
except Exception as e:
logger.exception(e)
return sample_docs
def detect_invalid_chars(src_pdf_bytes: bytes) -> bool:
""""
检测PDF中是否包含非法字符
"""
'''pdfminer比较慢,需要先随机抽取10页左右的sample'''
sample_docs = extract_pages(src_pdf_bytes)
sample_pdf_bytes = sample_docs.tobytes()
sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
laparams = LAParams(
line_overlap=0.5,
char_margin=2.0,
line_margin=0.5,
word_margin=0.1,
boxes_flow=None,
detect_vertical=False,
all_texts=False,
)
text = extract_text(pdf_file=sample_pdf_file_like_object, laparams=laparams)
text = text.replace("\n", "")
# logger.info(text)
'''乱码文本用pdfminer提取出来的文本特征是(cid:xxx)'''
cid_pattern = re.compile(r'\(cid:\d+\)')
matches = cid_pattern.findall(text)
cid_count = len(matches)
cid_len = sum(len(match) for match in matches)
text_len = len(text)
if text_len == 0:
cid_chars_radio = 0
else:
cid_chars_radio = cid_count/(cid_count + text_len - cid_len)
logger.info(f"cid_count: {cid_count}, text_len: {text_len}, cid_chars_radio: {cid_chars_radio}")
'''当一篇文章存在5%以上的文本是乱码时,认为该文档为乱码文档'''
if cid_chars_radio > 0.05:
return False # 乱码文档
else:
return True # 正常文档
def count_replacement_characters(text: str) -> int:
"""
统计字符串中 0xfffd 字符的数量。
"""
return text.count('\ufffd')
def detect_invalid_chars_by_pymupdf(src_pdf_bytes: bytes) -> bool:
sample_docs = extract_pages(src_pdf_bytes)
doc_text = ""
for page in sample_docs:
page_text = page.get_text('text', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)
doc_text += page_text
text_len = len(doc_text)
uffd_count = count_replacement_characters(doc_text)
if text_len == 0:
uffd_chars_radio = 0
else:
uffd_chars_radio = uffd_count / text_len
logger.info(f"uffd_count: {uffd_count}, text_len: {text_len}, uffd_chars_radio: {uffd_chars_radio}")
'''当一篇文章存在1%以上的文本是乱码时,认为该文档为乱码文档'''
if uffd_chars_radio > 0.01:
return False # 乱码文档
else:
return True # 正常文档
\ No newline at end of file
from io import BytesIO
import cv2
import fitz
import numpy as np
from PIL import Image
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.libs.commons import join_path
from magic_pdf.libs.hash_utils import compute_sha256
def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWriter: DataWriter):
"""从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径 save_path:需要同时支持s3和本地,
图片存放在save_path下,文件名是:
{page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。"""
# 拼接文件名
filename = f'{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}'
# 老版本返回不带bucket的路径
img_path = join_path(return_path, filename) if return_path is not None else None
# 新版本生成平铺路径
img_hash256_path = f'{compute_sha256(img_path)}.jpg'
# 将坐标转换为fitz.Rect对象
rect = fitz.Rect(*bbox)
# 配置缩放倍数为3倍
zoom = fitz.Matrix(3, 3)
# 截取图片
pix = page.get_pixmap(clip=rect, matrix=zoom)
byte_data = pix.tobytes(output='jpeg', jpg_quality=95)
imageWriter.write(img_hash256_path, byte_data)
return img_hash256_path
def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
# 将坐标转换为fitz.Rect对象
rect = fitz.Rect(*bbox)
# 配置缩放倍数为3倍
zoom = fitz.Matrix(3, 3)
# 截取图片
pix = page.get_pixmap(clip=rect, matrix=zoom)
if mode == "cv2":
# 直接转换为numpy数组供cv2使用
img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
# PyMuPDF使用RGB顺序,而cv2使用BGR顺序
if pix.n == 3 or pix.n == 4:
image_result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
else:
image_result = img_array
elif mode == "pillow":
# 将字节数据转换为文件对象
image_file = BytesIO(pix.tobytes(output='png'))
# 使用 Pillow 打开图像
image_result = Image.open(image_file)
else:
raise ValueError(f"mode: {mode} is not supported.")
return image_result
\ No newline at end of file
import time
import functools
from collections import defaultdict
from typing import Dict, List
class PerformanceStats:
"""性能统计类,用于收集和展示方法执行时间"""
_stats: Dict[str, List[float]] = defaultdict(list)
@classmethod
def add_execution_time(cls, func_name: str, execution_time: float):
"""添加执行时间记录"""
cls._stats[func_name].append(execution_time)
@classmethod
def get_stats(cls) -> Dict[str, dict]:
"""获取统计结果"""
results = {}
for func_name, times in cls._stats.items():
results[func_name] = {
'count': len(times),
'total_time': sum(times),
'avg_time': sum(times) / len(times),
'min_time': min(times),
'max_time': max(times)
}
return results
@classmethod
def print_stats(cls):
"""打印统计结果"""
stats = cls.get_stats()
print("\n性能统计结果:")
print("-" * 80)
print(f"{'方法名':<40} {'调用次数':>8} {'总时间(s)':>12} {'平均时间(s)':>12}")
print("-" * 80)
for func_name, data in stats.items():
print(f"{func_name:<40} {data['count']:8d} {data['total_time']:12.6f} {data['avg_time']:12.6f}")
def measure_time(func):
"""测量方法执行时间的装饰器"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
execution_time = time.time() - start_time
# 获取更详细的函数标识
if hasattr(func, "__self__"): # 实例方法
class_name = func.__self__.__class__.__name__
full_name = f"{class_name}.{func.__name__}"
elif hasattr(func, "__qualname__"): # 类方法或静态方法
full_name = func.__qualname__
else:
module_name = func.__module__
full_name = f"{module_name}.{func.__name__}"
PerformanceStats.add_execution_time(full_name, execution_time)
return result
return wrapper
\ No newline at end of file
import os
def sanitize_filename(filename, replacement="_"):
if os.name == 'nt':
invalid_chars = '<>:"|?*'
for char in invalid_chars:
filename = filename.replace(char, replacement)
return filename
__use_inside_model__ = True
__model_mode__ = 'full'
\ No newline at end of file
import time
import cv2
from loguru import logger
from tqdm import tqdm
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res, get_coords_and_area)
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list)
YOLO_LAYOUT_BASE_BATCH_SIZE = 1
MFD_BASE_BATCH_SIZE = 1
MFR_BASE_BATCH_SIZE = 16
class BatchAnalyze:
def __init__(self, model_manager, batch_ratio: int, show_log, layout_model, formula_enable, table_enable):
self.model_manager = model_manager
self.batch_ratio = batch_ratio
self.show_log = show_log
self.layout_model = layout_model
self.formula_enable = formula_enable
self.table_enable = table_enable
def __call__(self, images_with_extra_info: list) -> list:
if len(images_with_extra_info) == 0:
return []
images_layout_res = []
layout_start_time = time.time()
self.model = self.model_manager.get_model(
ocr=True,
show_log=self.show_log,
lang = None,
layout_model = self.layout_model,
formula_enable = self.formula_enable,
table_enable = self.table_enable,
)
images = [image for image, _, _ in images_with_extra_info]
if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
# layoutlmv3
for image in images:
layout_res = self.model.layout_model(image, ignore_catids=[])
images_layout_res.append(layout_res)
elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo
layout_images = []
for image_index, image in enumerate(images):
layout_images.append(image)
images_layout_res += self.model.layout_model.batch_predict(
# layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
)
# logger.info(
# f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
# )
if self.model.apply_formula:
# 公式检测
mfd_start_time = time.time()
images_mfd_res = self.model.mfd_model.batch_predict(
# images, self.batch_ratio * MFD_BASE_BATCH_SIZE
images, MFD_BASE_BATCH_SIZE
)
# logger.info(
# f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
# )
# 公式识别
mfr_start_time = time.time()
images_formula_list = self.model.mfr_model.batch_predict(
images_mfd_res,
images,
batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
)
mfr_count = 0
for image_index in range(len(images)):
images_layout_res[image_index] += images_formula_list[image_index]
mfr_count += len(images_formula_list[image_index])
# logger.info(
# f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
# )
# 清理显存
# clean_vram(self.model.device, vram_threshold=8)
ocr_res_list_all_page = []
table_res_list_all_page = []
for index in range(len(images)):
_, ocr_enable, _lang = images_with_extra_info[index]
layout_res = images_layout_res[index]
np_array_img = images[index]
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
get_res_list_from_layout_res(layout_res)
)
ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
'lang':_lang,
'ocr_enable':ocr_enable,
'np_array_img':np_array_img,
'single_page_mfdetrec_res':single_page_mfdetrec_res,
'layout_res':layout_res,
})
for table_res in table_res_list:
table_img, _ = crop_img(table_res, np_array_img)
table_res_list_all_page.append({'table_res':table_res,
'lang':_lang,
'table_img':table_img,
})
# 文本框检测
det_start = time.time()
det_count = 0
# for ocr_res_list_dict in ocr_res_list_all_page:
for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
# Process each area that requires OCR processing
_lang = ocr_res_list_dict['lang']
# Get OCR results for this language's images
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=_lang
)
for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
# OCR-det
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
ocr_res = ocr_model.ocr(
new_image, mfd_res=adjusted_mfdetrec_res, rec=False
)[0]
# Integration results
if ocr_res:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang)
if res["category_id"] == 3:
# ocr_result_list中所有bbox的面积之和
ocr_res_area = sum(get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
# 求ocr_res_area和res的面积的比值
res_area = get_coords_and_area(res)[4]
if res_area > 0:
ratio = ocr_res_area / res_area
if ratio > 0.25:
res["category_id"] = 1
else:
continue
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
# det_count += len(ocr_res_list_dict['ocr_res_list'])
# logger.info(f'ocr-det time: {round(time.time()-det_start, 2)}, image num: {det_count}')
# 表格识别 table recognition
if self.model.apply_table:
table_start = time.time()
# for table_res_list_dict in table_res_list_all_page:
for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
_lang = table_res_dict['lang']
atom_model_manager = AtomModelSingleton()
table_model = atom_model_manager.get_atom_model(
atom_model_name='table',
table_model_name='rapid_table',
table_model_path='',
table_max_time=400,
device='cpu',
lang=_lang,
table_sub_model_name='slanet_plus'
)
html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict['table_img'])
# 判断是否返回正常
if html_code:
expected_ending = html_code.strip().endswith(
'</html>'
) or html_code.strip().endswith('</table>')
if expected_ending:
table_res_dict['table_res']['html'] = html_code
else:
logger.warning(
'table recognition processing fails, not found expected HTML table end'
)
else:
logger.warning(
'table recognition processing fails, not get html return'
)
# logger.info(f'table time: {round(time.time() - table_start, 2)}, image num: {len(table_res_list_all_page)}')
# Create dictionaries to store items by language
need_ocr_lists_by_lang = {} # Dict of lists for each language
img_crop_lists_by_lang = {} # Dict of lists for each language
for layout_res in images_layout_res:
for layout_res_item in layout_res:
if layout_res_item['category_id'] in [15]:
if 'np_img' in layout_res_item and 'lang' in layout_res_item:
lang = layout_res_item['lang']
# Initialize lists for this language if not exist
if lang not in need_ocr_lists_by_lang:
need_ocr_lists_by_lang[lang] = []
img_crop_lists_by_lang[lang] = []
# Add to the appropriate language-specific lists
need_ocr_lists_by_lang[lang].append(layout_res_item)
img_crop_lists_by_lang[lang].append(layout_res_item['np_img'])
# Remove the fields after adding to lists
layout_res_item.pop('np_img')
layout_res_item.pop('lang')
if len(img_crop_lists_by_lang) > 0:
# Process OCR by language
rec_time = 0
rec_start = time.time()
total_processed = 0
# Process each language separately
for lang, img_crop_list in img_crop_lists_by_lang.items():
if len(img_crop_list) > 0:
# Get OCR results for this language's images
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang
)
ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
# Verify we have matching counts
assert len(ocr_res_list) == len(
need_ocr_lists_by_lang[lang]), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_lists_by_lang[lang])} for lang: {lang}'
# Process OCR results for this language
for index, layout_res_item in enumerate(need_ocr_lists_by_lang[lang]):
ocr_text, ocr_score = ocr_res_list[index]
layout_res_item['text'] = ocr_text
layout_res_item['score'] = float(f"{ocr_score:.3f}")
total_processed += len(img_crop_list)
rec_time += time.time() - rec_start
# logger.info(f'ocr-rec time: {round(rec_time, 2)}, total images processed: {total_processed}')
return images_layout_res
import os
import time
import numpy as np
import torch
os.environ['FLAGS_npu_jit_compile'] = '0' # 关闭paddle的jit编译
os.environ['FLAGS_use_stride_kernel'] = '0'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 让mps可以fallback
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
from loguru import logger
from magic_pdf.model.sub_modules.model_utils import get_vram
from magic_pdf.config.enums import SupportedPdfParseMethod
import magic_pdf.model as model_config
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import (get_device, get_formula_config,
get_layout_config,
get_local_models_dir,
get_table_recog_config)
from magic_pdf.model.model_list import MODEL
class ModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(
self,
ocr: bool,
show_log: bool,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
if key not in self._models:
self._models[key] = custom_model_init(
ocr=ocr,
show_log=show_log,
lang=lang,
layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
return self._models[key]
def custom_model_init(
ocr: bool = False,
show_log: bool = False,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
model = None
if model_config.__model_mode__ == 'lite':
logger.warning(
'The Lite mode is provided for developers to conduct testing only, and the output quality is '
'not guaranteed to be reliable.'
)
model = MODEL.Paddle
elif model_config.__model_mode__ == 'full':
model = MODEL.PEK
if model_config.__use_inside_model__:
model_init_start = time.time()
if model == MODEL.Paddle:
from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
elif model == MODEL.PEK:
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
# 从配置文件读取model-dir和device
local_models_dir = get_local_models_dir()
device = get_device()
layout_config = get_layout_config()
if layout_model is not None:
layout_config['model'] = layout_model
formula_config = get_formula_config()
if formula_enable is not None:
formula_config['enable'] = formula_enable
table_config = get_table_recog_config()
if table_enable is not None:
table_config['enable'] = table_enable
model_input = {
'ocr': ocr,
'show_log': show_log,
'models_dir': local_models_dir,
'device': device,
'table_config': table_config,
'layout_config': layout_config,
'formula_config': formula_config,
'lang': lang,
}
custom_model = CustomPEKModel(**model_input)
else:
logger.error('Not allow model_name!')
exit(1)
model_init_cost = time.time() - model_init_start
logger.info(f'model init cost: {model_init_cost}')
else:
logger.error('use_inside_model is False, not allow to use inside model')
exit(1)
return custom_model
def doc_analyze(
dataset: Dataset,
ocr: bool = False,
show_log: bool = False,
start_page_id=0,
end_page_id=None,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else len(dataset) - 1
)
MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
images = []
page_wh_list = []
for index in range(len(dataset)):
if start_page_id <= index <= end_page_id:
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
images.append(img_dict['img'])
page_wh_list.append((img_dict['width'], img_dict['height']))
images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(images))]
if len(images) >= MIN_BATCH_INFERENCE_SIZE:
batch_size = MIN_BATCH_INFERENCE_SIZE
batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
else:
batch_images = [images_with_extra_info]
results = []
processed_images_count = 0
for index, batch_image in enumerate(batch_images):
processed_images_count += len(batch_image)
logger.info(f'Batch {index + 1}/{len(batch_images)}: {processed_images_count} pages/{len(images_with_extra_info)} pages')
result = may_batch_image_analyze(batch_image, ocr, show_log,layout_model, formula_enable, table_enable)
results.extend(result)
model_json = []
for index in range(len(dataset)):
if start_page_id <= index <= end_page_id:
result = results.pop(0)
page_width, page_height = page_wh_list.pop(0)
else:
result = []
page_height = 0
page_width = 0
page_info = {'page_no': index, 'width': page_width, 'height': page_height}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
from magic_pdf.operators.models import InferenceResult
return InferenceResult(model_json, dataset)
def batch_doc_analyze(
datasets: list[Dataset],
parse_method: str = 'auto',
show_log: bool = False,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
batch_size = MIN_BATCH_INFERENCE_SIZE
page_wh_list = []
images_with_extra_info = []
for dataset in datasets:
ocr = False
if parse_method == 'auto':
if dataset.classify() == SupportedPdfParseMethod.TXT:
ocr = False
elif dataset.classify() == SupportedPdfParseMethod.OCR:
ocr = True
elif parse_method == 'ocr':
ocr = True
elif parse_method == 'txt':
ocr = False
_lang = dataset._lang
for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
page_wh_list.append((img_dict['width'], img_dict['height']))
images_with_extra_info.append((img_dict['img'], ocr, _lang))
batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
results = []
processed_images_count = 0
for index, batch_image in enumerate(batch_images):
processed_images_count += len(batch_image)
logger.info(f'Batch {index + 1}/{len(batch_images)}: {processed_images_count} pages/{len(images_with_extra_info)} pages')
result = may_batch_image_analyze(batch_image, True, show_log, layout_model, formula_enable, table_enable)
results.extend(result)
infer_results = []
from magic_pdf.operators.models import InferenceResult
for index in range(len(datasets)):
dataset = datasets[index]
model_json = []
for i in range(len(dataset)):
result = results.pop(0)
page_width, page_height = page_wh_list.pop(0)
page_info = {'page_no': i, 'width': page_width, 'height': page_height}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
infer_results.append(InferenceResult(model_json, dataset))
return infer_results
def may_batch_image_analyze(
images_with_extra_info: list[(np.ndarray, bool, str)],
ocr: bool,
show_log: bool = False,
layout_model=None,
formula_enable=None,
table_enable=None):
# os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
from magic_pdf.model.batch_analyze import BatchAnalyze
model_manager = ModelSingleton()
# images = [image for image, _, _ in images_with_extra_info]
batch_ratio = 1
device = get_device()
if str(device).startswith('npu'):
import torch_npu
if torch_npu.npu.is_available():
torch.npu.set_compile_mode(jit_compile=False)
if str(device).startswith('npu') or str(device).startswith('cuda'):
vram = get_vram(device)
if vram is not None:
gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(vram)))
if gpu_memory >= 16:
batch_ratio = 16
elif gpu_memory >= 12:
batch_ratio = 8
elif gpu_memory >= 8:
batch_ratio = 4
elif gpu_memory >= 6:
batch_ratio = 2
else:
batch_ratio = 1
logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
else:
# Default batch_ratio when VRAM can't be determined
batch_ratio = 1
logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
# doc_analyze_start = time.time()
batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
results = batch_model(images_with_extra_info)
# gc_start = time.time()
clean_memory(get_device())
# gc_time = round(time.time() - gc_start, 2)
# logger.debug(f'gc time: {gc_time}')
# doc_analyze_time = round(time.time() - doc_analyze_start, 2)
# doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
# logger.debug(
# f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
# f' speed: {doc_analyze_speed} pages/second'
# )
return results
\ No newline at end of file
import enum
from magic_pdf.config.model_block_type import ModelBlockTypeEnum
from magic_pdf.config.ocr_content_type import CategoryId, ContentType
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.boxbase import (_is_in, bbox_distance, bbox_relative_pos,
calculate_iou)
from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
CAPATION_OVERLAP_AREA_RATIO = 0.6
MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
class PosRelationEnum(enum.Enum):
LEFT = 'left'
RIGHT = 'right'
UP = 'up'
BOTTOM = 'bottom'
ALL = 'all'
class MagicModel:
"""每个函数没有得到元素的时候返回空list."""
def __fix_axis(self):
for model_page_info in self.__model_list:
need_remove_list = []
page_no = model_page_info['page_info']['page_no']
horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
model_page_info, self.__docs.get_page(page_no)
)
layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets:
if layout_det.get('bbox') is not None:
# 兼容直接输出bbox的模型数据,如paddle
x0, y0, x1, y1 = layout_det['bbox']
else:
# 兼容直接输出poly的模型数据,如xxx
x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
bbox = [
int(x0 / horizontal_scale_ratio),
int(y0 / vertical_scale_ratio),
int(x1 / horizontal_scale_ratio),
int(y1 / vertical_scale_ratio),
]
layout_det['bbox'] = bbox
# 删除高度或者宽度小于等于0的spans
if bbox[2] - bbox[0] <= 0 or bbox[3] - bbox[1] <= 0:
need_remove_list.append(layout_det)
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
def __fix_by_remove_low_confidence(self):
for model_page_info in self.__model_list:
need_remove_list = []
layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets:
if layout_det['score'] <= 0.05:
need_remove_list.append(layout_det)
else:
continue
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
def __fix_by_remove_high_iou_and_low_confidence(self):
for model_page_info in self.__model_list:
need_remove_list = []
layout_dets = model_page_info['layout_dets']
for layout_det1 in layout_dets:
for layout_det2 in layout_dets:
if layout_det1 == layout_det2:
continue
if layout_det1['category_id'] in [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
] and layout_det2['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
if (
calculate_iou(layout_det1['bbox'], layout_det2['bbox'])
> 0.9
):
if layout_det1['score'] < layout_det2['score']:
layout_det_need_remove = layout_det1
else:
layout_det_need_remove = layout_det2
if layout_det_need_remove not in need_remove_list:
need_remove_list.append(layout_det_need_remove)
else:
continue
else:
continue
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
def __init__(self, model_list: list, docs: Dataset):
self.__model_list = model_list
self.__docs = docs
"""为所有模型数据添加bbox信息(缩放,poly->bbox)"""
self.__fix_axis()
"""删除置信度特别低的模型数据(<0.05),提高质量"""
self.__fix_by_remove_low_confidence()
"""删除高iou(>0.9)数据中置信度较低的那个"""
self.__fix_by_remove_high_iou_and_low_confidence()
self.__fix_footnote()
def _bbox_distance(self, bbox1, bbox2):
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
flags = [left, right, bottom, top]
count = sum([1 if v else 0 for v in flags])
if count > 1:
return float('inf')
if left or right:
l1 = bbox1[3] - bbox1[1]
l2 = bbox2[3] - bbox2[1]
else:
l1 = bbox1[2] - bbox1[0]
l2 = bbox2[2] - bbox2[0]
if l2 > l1 and (l2 - l1) / l1 > 0.3:
return float('inf')
return bbox_distance(bbox1, bbox2)
def __fix_footnote(self):
# 3: figure, 5: table, 7: footnote
for model_page_info in self.__model_list:
footnotes = []
figures = []
tables = []
for obj in model_page_info['layout_dets']:
if obj['category_id'] == 7:
footnotes.append(obj)
elif obj['category_id'] == 3:
figures.append(obj)
elif obj['category_id'] == 5:
tables.append(obj)
if len(footnotes) * len(figures) == 0:
continue
dis_figure_footnote = {}
dis_table_footnote = {}
for i in range(len(footnotes)):
for j in range(len(figures)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], figures[j]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
dis_figure_footnote[i] = min(
self._bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
dis_figure_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
for j in range(len(tables)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], tables[j]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
dis_table_footnote[i] = min(
self._bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
dis_table_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
if i not in dis_figure_footnote:
continue
if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
footnotes[i]['category_id'] = CategoryId.ImageFootnote
def __reduct_overlap(self, bboxes):
N = len(bboxes)
keep = [True] * N
for i in range(N):
for j in range(N):
if i == j:
continue
if _is_in(bboxes[i]['bbox'], bboxes[j]['bbox']):
keep[i] = False
return [bboxes[i] for i in range(N) if keep[i]]
def __tie_up_category_by_distance_v2(
self,
page_no: int,
subject_category_id: int,
object_category_id: int,
priority_pos: PosRelationEnum,
):
"""_summary_
Args:
page_no (int): _description_
subject_category_id (int): _description_
object_category_id (int): _description_
priority_pos (PosRelationEnum): _description_
Returns:
_type_: _description_
"""
AXIS_MULPLICITY = 0.5
subjects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == subject_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
objects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == object_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
M = len(objects)
subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
sub_obj_map_h = {i: [] for i in range(len(subjects))}
dis_by_directions = {
'top': [[-1, float('inf')]] * M,
'bottom': [[-1, float('inf')]] * M,
'left': [[-1, float('inf')]] * M,
'right': [[-1, float('inf')]] * M,
}
for i, obj in enumerate(objects):
l_x_axis, l_y_axis = (
obj['bbox'][2] - obj['bbox'][0],
obj['bbox'][3] - obj['bbox'][1],
)
axis_unit = min(l_x_axis, l_y_axis)
for j, sub in enumerate(subjects):
bbox1, bbox2, _ = _remove_overlap_between_bbox(
objects[i]['bbox'], subjects[j]['bbox']
)
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
flags = [left, right, bottom, top]
if sum([1 if v else 0 for v in flags]) > 1:
continue
if left:
if dis_by_directions['left'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['left'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if right:
if dis_by_directions['right'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['right'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if bottom:
if dis_by_directions['bottom'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['bottom'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if top:
if dis_by_directions['top'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['top'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if (
dis_by_directions['top'][i][1] != float('inf')
and dis_by_directions['bottom'][i][1] != float('inf')
and priority_pos in (PosRelationEnum.BOTTOM, PosRelationEnum.UP)
):
RATIO = 3
if (
abs(
dis_by_directions['top'][i][1]
- dis_by_directions['bottom'][i][1]
)
< RATIO * axis_unit
):
if priority_pos == PosRelationEnum.BOTTOM:
sub_obj_map_h[dis_by_directions['bottom'][i][0]].append(i)
else:
sub_obj_map_h[dis_by_directions['top'][i][0]].append(i)
continue
if dis_by_directions['left'][i][1] != float('inf') or dis_by_directions[
'right'
][i][1] != float('inf'):
if dis_by_directions['left'][i][1] != float(
'inf'
) and dis_by_directions['right'][i][1] != float('inf'):
if AXIS_MULPLICITY * axis_unit >= abs(
dis_by_directions['left'][i][1]
- dis_by_directions['right'][i][1]
):
left_sub_bbox = subjects[dis_by_directions['left'][i][0]][
'bbox'
]
right_sub_bbox = subjects[dis_by_directions['right'][i][0]][
'bbox'
]
left_sub_bbox_y_axis = left_sub_bbox[3] - left_sub_bbox[1]
right_sub_bbox_y_axis = right_sub_bbox[3] - right_sub_bbox[1]
if (
abs(left_sub_bbox_y_axis - l_y_axis)
+ dis_by_directions['left'][i][0]
> abs(right_sub_bbox_y_axis - l_y_axis)
+ dis_by_directions['right'][i][0]
):
left_or_right = dis_by_directions['right'][i]
else:
left_or_right = dis_by_directions['left'][i]
else:
left_or_right = dis_by_directions['left'][i]
if left_or_right[1] > dis_by_directions['right'][i][1]:
left_or_right = dis_by_directions['right'][i]
else:
left_or_right = dis_by_directions['left'][i]
if left_or_right[1] == float('inf'):
left_or_right = dis_by_directions['right'][i]
else:
left_or_right = [-1, float('inf')]
if dis_by_directions['top'][i][1] != float('inf') or dis_by_directions[
'bottom'
][i][1] != float('inf'):
if dis_by_directions['top'][i][1] != float('inf') and dis_by_directions[
'bottom'
][i][1] != float('inf'):
if AXIS_MULPLICITY * axis_unit >= abs(
dis_by_directions['top'][i][1]
- dis_by_directions['bottom'][i][1]
):
top_bottom = subjects[dis_by_directions['bottom'][i][0]]['bbox']
bottom_top = subjects[dis_by_directions['top'][i][0]]['bbox']
top_bottom_x_axis = top_bottom[2] - top_bottom[0]
bottom_top_x_axis = bottom_top[2] - bottom_top[0]
if (
abs(top_bottom_x_axis - l_x_axis)
+ dis_by_directions['bottom'][i][1]
> abs(bottom_top_x_axis - l_x_axis)
+ dis_by_directions['top'][i][1]
):
top_or_bottom = dis_by_directions['top'][i]
else:
top_or_bottom = dis_by_directions['bottom'][i]
else:
top_or_bottom = dis_by_directions['top'][i]
if top_or_bottom[1] > dis_by_directions['bottom'][i][1]:
top_or_bottom = dis_by_directions['bottom'][i]
else:
top_or_bottom = dis_by_directions['top'][i]
if top_or_bottom[1] == float('inf'):
top_or_bottom = dis_by_directions['bottom'][i]
else:
top_or_bottom = [-1, float('inf')]
if left_or_right[1] != float('inf') or top_or_bottom[1] != float('inf'):
if left_or_right[1] != float('inf') and top_or_bottom[1] != float(
'inf'
):
if AXIS_MULPLICITY * axis_unit >= abs(
left_or_right[1] - top_or_bottom[1]
):
y_axis_bbox = subjects[left_or_right[0]]['bbox']
x_axis_bbox = subjects[top_or_bottom[0]]['bbox']
if (
abs((x_axis_bbox[2] - x_axis_bbox[0]) - l_x_axis) / l_x_axis
> abs((y_axis_bbox[3] - y_axis_bbox[1]) - l_y_axis)
/ l_y_axis
):
sub_obj_map_h[left_or_right[0]].append(i)
else:
sub_obj_map_h[top_or_bottom[0]].append(i)
else:
if left_or_right[1] > top_or_bottom[1]:
sub_obj_map_h[top_or_bottom[0]].append(i)
else:
sub_obj_map_h[left_or_right[0]].append(i)
else:
if left_or_right[1] != float('inf'):
sub_obj_map_h[left_or_right[0]].append(i)
else:
sub_obj_map_h[top_or_bottom[0]].append(i)
ret = []
for i in sub_obj_map_h.keys():
ret.append(
{
'sub_bbox': {
'bbox': subjects[i]['bbox'],
'score': subjects[i]['score'],
},
'obj_bboxes': [
{'score': objects[j]['score'], 'bbox': objects[j]['bbox']}
for j in sub_obj_map_h[i]
],
'sub_idx': i,
}
)
return ret
def __tie_up_category_by_distance_v3(
self,
page_no: int,
subject_category_id: int,
object_category_id: int,
priority_pos: PosRelationEnum,
):
subjects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == subject_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
objects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == object_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
ret = []
N, M = len(subjects), len(objects)
subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
OBJ_IDX_OFFSET = 10000
SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1
all_boxes_with_idx = [(i, SUB_BIT_KIND, sub['bbox'][0], sub['bbox'][1]) for i, sub in enumerate(subjects)] + [(i + OBJ_IDX_OFFSET , OBJ_BIT_KIND, obj['bbox'][0], obj['bbox'][1]) for i, obj in enumerate(objects)]
seen_idx = set()
seen_sub_idx = set()
while N > len(seen_sub_idx):
candidates = []
for idx, kind, x0, y0 in all_boxes_with_idx:
if idx in seen_idx:
continue
candidates.append((idx, kind, x0, y0))
if len(candidates) == 0:
break
left_x = min([v[2] for v in candidates])
top_y = min([v[3] for v in candidates])
candidates.sort(key=lambda x: (x[2]-left_x) ** 2 + (x[3] - top_y) ** 2)
fst_idx, fst_kind, left_x, top_y = candidates[0]
candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y)**2)
nxt = None
for i in range(1, len(candidates)):
if candidates[i][1] ^ fst_kind == 1:
nxt = candidates[i]
break
if nxt is None:
break
if fst_kind == SUB_BIT_KIND:
sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
else:
sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET
pair_dis = bbox_distance(subjects[sub_idx]['bbox'], objects[obj_idx]['bbox'])
nearest_dis = float('inf')
for i in range(N):
if i in seen_idx or i == sub_idx:continue
nearest_dis = min(nearest_dis, bbox_distance(subjects[i]['bbox'], objects[obj_idx]['bbox']))
if pair_dis >= 3*nearest_dis:
seen_idx.add(sub_idx)
continue
seen_idx.add(sub_idx)
seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
seen_sub_idx.add(sub_idx)
ret.append(
{
'sub_bbox': {
'bbox': subjects[sub_idx]['bbox'],
'score': subjects[sub_idx]['score'],
},
'obj_bboxes': [
{'score': objects[obj_idx]['score'], 'bbox': objects[obj_idx]['bbox']}
],
'sub_idx': sub_idx,
}
)
for i in range(len(objects)):
j = i + OBJ_IDX_OFFSET
if j in seen_idx:
continue
seen_idx.add(j)
nearest_dis, nearest_sub_idx = float('inf'), -1
for k in range(len(subjects)):
dis = bbox_distance(objects[i]['bbox'], subjects[k]['bbox'])
if dis < nearest_dis:
nearest_dis = dis
nearest_sub_idx = k
for k in range(len(subjects)):
if k != nearest_sub_idx: continue
if k in seen_sub_idx:
for kk in range(len(ret)):
if ret[kk]['sub_idx'] == k:
ret[kk]['obj_bboxes'].append({'score': objects[i]['score'], 'bbox': objects[i]['bbox']})
break
else:
ret.append(
{
'sub_bbox': {
'bbox': subjects[k]['bbox'],
'score': subjects[k]['score'],
},
'obj_bboxes': [
{'score': objects[i]['score'], 'bbox': objects[i]['bbox']}
],
'sub_idx': k,
}
)
seen_sub_idx.add(k)
seen_idx.add(k)
for i in range(len(subjects)):
if i in seen_sub_idx:
continue
ret.append(
{
'sub_bbox': {
'bbox': subjects[i]['bbox'],
'score': subjects[i]['score'],
},
'obj_bboxes': [],
'sub_idx': i,
}
)
return ret
def get_imgs_v2(self, page_no: int):
with_captions = self.__tie_up_category_by_distance_v3(
page_no, 3, 4, PosRelationEnum.BOTTOM
)
with_footnotes = self.__tie_up_category_by_distance_v3(
page_no, 3, CategoryId.ImageFootnote, PosRelationEnum.ALL
)
ret = []
for v in with_captions:
record = {
'image_body': v['sub_bbox'],
'image_caption_list': v['obj_bboxes'],
}
filter_idx = v['sub_idx']
d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
record['image_footnote_list'] = d['obj_bboxes']
ret.append(record)
return ret
def get_tables_v2(self, page_no: int) -> list:
with_captions = self.__tie_up_category_by_distance_v3(
page_no, 5, 6, PosRelationEnum.UP
)
with_footnotes = self.__tie_up_category_by_distance_v3(
page_no, 5, 7, PosRelationEnum.ALL
)
ret = []
for v in with_captions:
record = {
'table_body': v['sub_bbox'],
'table_caption_list': v['obj_bboxes'],
}
filter_idx = v['sub_idx']
d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
record['table_footnote_list'] = d['obj_bboxes']
ret.append(record)
return ret
def get_imgs(self, page_no: int):
return self.get_imgs_v2(page_no)
def get_tables(
self, page_no: int
) -> list: # 3个坐标, caption, table主体,table-note
return self.get_tables_v2(page_no)
def get_equations(self, page_no: int) -> list: # 有坐标,也有字
inline_equations = self.__get_blocks_by_type(
ModelBlockTypeEnum.EMBEDDING.value, page_no, ['latex']
)
interline_equations = self.__get_blocks_by_type(
ModelBlockTypeEnum.ISOLATED.value, page_no, ['latex']
)
interline_equations_blocks = self.__get_blocks_by_type(
ModelBlockTypeEnum.ISOLATE_FORMULA.value, page_no
)
return inline_equations, interline_equations, interline_equations_blocks
def get_discarded(self, page_no: int) -> list: # 自研模型,只有坐标
blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.ABANDON.value, page_no)
return blocks
def get_text_blocks(self, page_no: int) -> list: # 自研模型搞的,只有坐标,没有字
blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.PLAIN_TEXT.value, page_no)
return blocks
def get_title_blocks(self, page_no: int) -> list: # 自研模型,只有坐标,没字
blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.TITLE.value, page_no)
return blocks
def get_ocr_text(self, page_no: int) -> list: # paddle 搞的,有字也有坐标
text_spans = []
model_page_info = self.__model_list[page_no]
layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets:
if layout_det['category_id'] == '15':
span = {
'bbox': layout_det['bbox'],
'content': layout_det['text'],
}
text_spans.append(span)
return text_spans
def get_all_spans(self, page_no: int) -> list:
def remove_duplicate_spans(spans):
new_spans = []
for span in spans:
if not any(span == existing_span for existing_span in new_spans):
new_spans.append(span)
return new_spans
all_spans = []
model_page_info = self.__model_list[page_no]
layout_dets = model_page_info['layout_dets']
allow_category_id_list = [3, 5, 13, 14, 15]
"""当成span拼接的"""
# 3: 'image', # 图片
# 5: 'table', # 表格
# 13: 'inline_equation', # 行内公式
# 14: 'interline_equation', # 行间公式
# 15: 'text', # ocr识别文本
for layout_det in layout_dets:
category_id = layout_det['category_id']
if category_id in allow_category_id_list:
span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
if category_id == 3:
span['type'] = ContentType.Image
elif category_id == 5:
# 获取table模型结果
latex = layout_det.get('latex', None)
html = layout_det.get('html', None)
if latex:
span['latex'] = latex
elif html:
span['html'] = html
span['type'] = ContentType.Table
elif category_id == 13:
span['content'] = layout_det['latex']
span['type'] = ContentType.InlineEquation
elif category_id == 14:
span['content'] = layout_det['latex']
span['type'] = ContentType.InterlineEquation
elif category_id == 15:
span['content'] = layout_det['text']
span['type'] = ContentType.Text
all_spans.append(span)
return remove_duplicate_spans(all_spans)
def get_page_size(self, page_no: int): # 获取页面宽高
# 获取当前页的page对象
page = self.__docs.get_page(page_no).get_page_info()
# 获取当前页的宽高
page_w = page.w
page_h = page.h
return page_w, page_h
def __get_blocks_by_type(
self, type: int, page_no: int, extra_col: list[str] = []
) -> list:
blocks = []
for page_dict in self.__model_list:
layout_dets = page_dict.get('layout_dets', [])
page_info = page_dict.get('page_info', {})
page_number = page_info.get('page_no', -1)
if page_no != page_number:
continue
for item in layout_dets:
category_id = item.get('category_id', -1)
bbox = item.get('bbox', None)
if category_id == type:
block = {
'bbox': bbox,
'score': item.get('score'),
}
for col in extra_col:
block[col] = item.get(col, None)
blocks.append(block)
return blocks
def get_model_list(self, page_no):
return self.__model_list[page_no]
class MODEL:
Paddle = "pp_structure_v2"
PEK = "pdf_extract_kit"
class AtomicModel:
Layout = "layout"
MFD = "mfd"
MFR = "mfr"
OCR = "ocr"
Table = "table"
LangDetect = "langdetect"
# flake8: noqa
import os
import time
import cv2
import torch
import yaml
from loguru import logger
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
from magic_pdf.config.constants import *
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list)
class CustomPEKModel:
def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
"""
======== model init ========
"""
# 获取当前文件(即 pdf_extract_kit.py)的绝对路径
current_file_path = os.path.abspath(__file__)
# 获取当前文件所在的目录(model)
current_dir = os.path.dirname(current_file_path)
# 上一级目录(magic_pdf)
root_dir = os.path.dirname(current_dir)
# model_config目录
model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
# 构建 model_configs.yaml 文件的完整路径
config_path = os.path.join(model_config_dir, 'model_configs.yaml')
with open(config_path, 'r', encoding='utf-8') as f:
self.configs = yaml.load(f, Loader=yaml.FullLoader)
# 初始化解析配置
# layout config
self.layout_config = kwargs.get('layout_config')
self.layout_model_name = self.layout_config.get(
'model', MODEL_NAME.DocLayout_YOLO
)
# formula config
self.formula_config = kwargs.get('formula_config')
self.mfd_model_name = self.formula_config.get(
'mfd_model', MODEL_NAME.YOLO_V8_MFD
)
self.mfr_model_name = self.formula_config.get(
'mfr_model', MODEL_NAME.UniMerNet_v2_Small
)
self.apply_formula = self.formula_config.get('enable', True)
# table config
self.table_config = kwargs.get('table_config')
self.apply_table = self.table_config.get('enable', False)
self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
self.table_sub_model_name = self.table_config.get('sub_model', None)
# ocr config
self.apply_ocr = ocr
self.lang = kwargs.get('lang', None)
logger.info(
'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
'apply_table: {}, table_model: {}, lang: {}'.format(
self.layout_model_name,
self.apply_formula,
self.apply_ocr,
self.apply_table,
self.table_model_name,
self.lang,
)
)
# 初始化解析方案
self.device = kwargs.get('device', 'cpu')
logger.info('using device: {}'.format(self.device))
models_dir = kwargs.get(
'models_dir', os.path.join(root_dir, 'resources', 'models')
)
logger.info('using models_dir: {}'.format(models_dir))
atom_model_manager = AtomModelSingleton()
# 初始化公式识别
if self.apply_formula:
# 初始化公式检测模型
self.mfd_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFD,
mfd_weights=str(
os.path.join(
models_dir, self.configs['weights'][self.mfd_model_name]
)
),
device=self.device,
)
# 初始化公式解析模型
mfr_weight_dir = str(
os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
)
mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
self.mfr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFR,
mfr_weight_dir=mfr_weight_dir,
mfr_cfg_path=mfr_cfg_path,
device=self.device,
)
# 初始化layout模型
if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
layout_model_name=MODEL_NAME.LAYOUTLMv3,
layout_weights=str(
os.path.join(
models_dir, self.configs['weights'][self.layout_model_name]
)
),
layout_config_file=str(
os.path.join(
model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
)
),
device='cpu' if str(self.device).startswith("mps") else self.device,
)
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
layout_model_name=MODEL_NAME.DocLayout_YOLO,
doclayout_yolo_weights=str(
os.path.join(
models_dir, self.configs['weights'][self.layout_model_name]
)
),
device=self.device,
)
# 初始化ocr
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log,
det_db_box_thresh=0.3,
lang=self.lang
)
# init table model
if self.apply_table:
table_model_dir = self.configs['weights'][self.table_model_name]
self.table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Table,
table_model_name=self.table_model_name,
table_model_path=str(os.path.join(models_dir, table_model_dir)),
table_max_time=self.table_max_time,
device=self.device,
ocr_engine=self.ocr_model,
table_sub_model_name=self.table_sub_model_name
)
logger.info('DocAnalysis init done!')
def __call__(self, image):
# layout检测
layout_start = time.time()
layout_res = []
if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
# layoutlmv3
layout_res = self.layout_model(image, ignore_catids=[])
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
layout_res = self.layout_model.predict(image)
layout_cost = round(time.time() - layout_start, 2)
logger.info(f'layout detection time: {layout_cost}')
if self.apply_formula:
# 公式检测
mfd_start = time.time()
mfd_res = self.mfd_model.predict(image)
logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')
# 公式识别
mfr_start = time.time()
formula_list = self.mfr_model.predict(mfd_res, image)
layout_res.extend(formula_list)
mfr_cost = round(time.time() - mfr_start, 2)
logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
# 清理显存
clean_vram(self.device, vram_threshold=6)
# 从layout_res中获取ocr区域、表格区域、公式区域
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
get_res_list_from_layout_res(layout_res)
)
# ocr识别
ocr_start = time.time()
# Process each area that requires OCR processing
for res in ocr_res_list:
new_image, useful_list = crop_img(res, image, crop_paste_x=50, crop_paste_y=50)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
# OCR recognition
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
if self.apply_ocr:
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
else:
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
# Integration results
if ocr_res:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
layout_res.extend(ocr_result_list)
ocr_cost = round(time.time() - ocr_start, 2)
if self.apply_ocr:
logger.info(f"ocr time: {ocr_cost}")
else:
logger.info(f"det time: {ocr_cost}")
# 表格识别 table recognition
if self.apply_table:
table_start = time.time()
for res in table_res_list:
new_image, _ = crop_img(res, image)
single_table_start_time = time.time()
html_code = None
if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
with torch.no_grad():
table_result = self.table_model.predict(new_image, 'html')
if len(table_result) > 0:
html_code = table_result[0]
elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.table_model.img2html(new_image)
elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
html_code, table_cell_bboxes, logic_points, elapse = self.table_model.predict(
new_image
)
run_time = time.time() - single_table_start_time
if run_time > self.table_max_time:
logger.warning(
f'table recognition processing exceeds max time {self.table_max_time}s'
)
# 判断是否返回正常
if html_code:
expected_ending = html_code.strip().endswith(
'</html>'
) or html_code.strip().endswith('</table>')
if expected_ending:
res['html'] = html_code
else:
logger.warning(
'table recognition processing fails, not found expected HTML table end'
)
else:
logger.warning(
'table recognition processing fails, not get html return'
)
logger.info(f'table time: {round(time.time() - table_start, 2)}')
return layout_res
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment