Commit e52bd023 authored by myhloli

Merge remote-tracking branch 'origin/dev' into dev

# Conflicts:
#	magic_pdf/model/pdf_extract_kit.py
parents b2e37a2d ea2f8ea0
""" """span维度自定义字段."""
span维度自定义字段
"""
# span是否是跨页合并的 # span是否是跨页合并的
CROSS_PAGE = "cross_page" CROSS_PAGE = 'cross_page'
""" """
block维度自定义字段 block维度自定义字段
""" """
# block中lines是否被删除 # block中lines是否被删除
LINES_DELETED = "lines_deleted" LINES_DELETED = 'lines_deleted'
# table recognition max time default value # table recognition max time default value
TABLE_MAX_TIME_VALUE = 400 TABLE_MAX_TIME_VALUE = 400
...@@ -17,39 +15,39 @@ TABLE_MAX_TIME_VALUE = 400 ...@@ -17,39 +15,39 @@ TABLE_MAX_TIME_VALUE = 400
TABLE_MAX_LEN = 480 TABLE_MAX_LEN = 480
# table master structure dict # table master structure dict
TABLE_MASTER_DICT = "table_master_structure_dict.txt" TABLE_MASTER_DICT = 'table_master_structure_dict.txt'
# table master dir # table master dir
TABLE_MASTER_DIR = "table_structure_tablemaster_infer/" TABLE_MASTER_DIR = 'table_structure_tablemaster_infer/'
# pp detect model dir # pp detect model dir
DETECT_MODEL_DIR = "ch_PP-OCRv4_det_infer" DETECT_MODEL_DIR = 'ch_PP-OCRv4_det_infer'
# pp rec model dir # pp rec model dir
REC_MODEL_DIR = "ch_PP-OCRv4_rec_infer" REC_MODEL_DIR = 'ch_PP-OCRv4_rec_infer'
# pp rec char dict path # pp rec char dict path
REC_CHAR_DICT = "ppocr_keys_v1.txt" REC_CHAR_DICT = 'ppocr_keys_v1.txt'
# pp rec copy rec directory # pp rec copy rec directory
PP_REC_DIRECTORY = ".paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer" PP_REC_DIRECTORY = '.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer'
# pp rec copy det directory # pp rec copy det directory
PP_DET_DIRECTORY = ".paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer" PP_DET_DIRECTORY = '.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer'
class MODEL_NAME: class MODEL_NAME:
# pp table structure algorithm # pp table structure algorithm
TABLE_MASTER = "tablemaster" TABLE_MASTER = 'tablemaster'
# struct eqtable # struct eqtable
STRUCT_EQTABLE = "struct_eqtable" STRUCT_EQTABLE = 'struct_eqtable'
DocLayout_YOLO = "doclayout_yolo" DocLayout_YOLO = 'doclayout_yolo'
LAYOUTLMv3 = "layoutlmv3" LAYOUTLMv3 = 'layoutlmv3'
YOLO_V8_MFD = "yolo_v8_mfd" YOLO_V8_MFD = 'yolo_v8_mfd'
UniMerNet_v2_Small = "unimernet_small" UniMerNet_v2_Small = 'unimernet_small'
RAPID_TABLE = "rapid_table" RAPID_TABLE = 'rapid_table'
\ No newline at end of file
class DropReason:
    TEXT_BLCOK_HOR_OVERLAP = 'text_block_horizontal_overlap'  # 文字块有水平互相覆盖,导致无法准确定位文字顺序
    USEFUL_BLOCK_HOR_OVERLAP = (
        'useful_block_horizontal_overlap'  # 需保留的block水平覆盖
    )
    COMPLICATED_LAYOUT = 'complicated_layout'  # 复杂的布局,暂时不支持
    TOO_MANY_LAYOUT_COLUMNS = 'too_many_layout_columns'  # 目前不支持分栏超过2列的
    COLOR_BACKGROUND_TEXT_BOX = 'color_background_text_box'  # 含有带色块的PDF,色块会改变阅读顺序,目前不支持带底色文字块的PDF。
    HIGH_COMPUTATIONAL_lOAD_BY_IMGS = (
        'high_computational_load_by_imgs'  # 含特殊图片,计算量太大,从而丢弃
    )
    HIGH_COMPUTATIONAL_lOAD_BY_SVGS = (
        'high_computational_load_by_svgs'  # 特殊的SVG图,计算量太大,从而丢弃
    )
    HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES = 'high_computational_load_by_total_pages'  # 计算量超过负荷,当前方法下计算量消耗过大
    MISS_DOC_LAYOUT_RESULT = 'missing doc_layout_result'  # 版面分析失败
    Exception = '_exception'  # 解析中发生异常
    ENCRYPTED = 'encrypted'  # PDF是加密的
    EMPTY_PDF = 'total_page=0'  # PDF页面总数为0
    NOT_IS_TEXT_PDF = 'not_is_text_pdf'  # 不是文字版PDF,无法直接解析
    DENSE_SINGLE_LINE_BLOCK = 'dense_single_line_block'  # 无法清晰的分段
    TITLE_DETECTION_FAILED = 'title_detection_failed'  # 探测标题失败
    TITLE_LEVEL_FAILED = (
        'title_level_failed'  # 分析标题级别失败(例如一级、二级、三级标题)
    )
    PARA_SPLIT_FAILED = 'para_split_failed'  # 识别段落失败
    PARA_MERGE_FAILED = 'para_merge_failed'  # 段落合并失败
    NOT_ALLOW_LANGUAGE = 'not_allow_language'  # 不支持的语种
    SPECIAL_PDF = 'special_pdf'
    PSEUDO_SINGLE_COLUMN = 'pseudo_single_column'  # 无法精确判断文字分栏
    CAN_NOT_DETECT_PAGE_LAYOUT = 'can_not_detect_page_layout'  # 无法分析页面的版面
    NEGATIVE_BBOX_AREA = 'negative_bbox_area'  # 缩放导致 bbox 面积为负
    OVERLAP_BLOCKS_CAN_NOT_SEPARATION = (
        'overlap_blocks_can_t_separation'  # 无法分离重叠的block
    )


COLOR_BG_HEADER_TXT_BLOCK = 'color_background_header_txt_block'
PAGE_NO = 'page-no'  # 页码
CONTENT_IN_FOOT_OR_HEADER = 'in-foot-header-area'  # 页眉页脚内的文本
VERTICAL_TEXT = 'vertical-text'  # 垂直文本
ROTATE_TEXT = 'rotate-text'  # 旋转文本
EMPTY_SIDE_BLOCK = 'empty-side-block'  # 边缘上的空白没有任何内容的block
ON_IMAGE_TEXT = 'on-image-text'  # 文本在图片上
ON_TABLE_TEXT = 'on-table-text'  # 文本在表格上


class DropTag:
    PAGE_NUMBER = 'page_no'
    HEADER = 'header'
    FOOTER = 'footer'
    FOOTNOTE = 'footnote'
    NOT_IN_LAYOUT = 'not_in_layout'
    SPAN_OVERLAP = 'span_overlap'
    BLOCK_OVERLAP = 'block_overlap'


class MakeMode:
    MM_MD = 'mm_markdown'
    NLP_MD = 'nlp_markdown'
    STANDARD_FORMAT = 'standard_format'


class DropMode:
    WHOLE_PDF = 'whole_pdf'
    SINGLE_PAGE = 'single_page'
    NONE = 'none'
    NONE_WITH_REASON = 'none_with_reason'
from enum import Enum


class ModelBlockTypeEnum(Enum):
    TITLE = 0
    PLAIN_TEXT = 1
...
@@ -35,7 +35,7 @@ def read_jsonl(
    jsonl_d = [
        json.loads(line) for line in jsonl_bits.decode().split('\n') if line.strip()
    ]
    for d in jsonl_d:
        pdf_path = d.get('file_location', '') or d.get('path', '')
        if len(pdf_path) == 0:
            raise EmptyData('pdf file location is empty')
...
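For reference, read_jsonl above expects one JSON object per line, with the PDF location under 'file_location' (or the older 'path' key); the bucket and file names below are purely illustrative:

{"file_location": "s3://example-bucket/pdfs/sample-0001.pdf"}
{"path": "s3://example-bucket/pdfs/sample-0002.pdf"}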
@@ -2,17 +2,16 @@ import re

from loguru import logger

from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.libs.commons import join_path
from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
from magic_pdf.para.para_split_v3 import ListLineTag


def __is_hyphen_at_line_end(line):
    """Check if a line ends with one or more letters followed by a hyphen.

    Args:
        line (str): The line of text to check.

@@ -163,7 +162,7 @@ def merge_para_with_text(para_block):
            if span_type in [ContentType.Text, ContentType.InterlineEquation]:
                para_text += content  # 中文/日语/韩文语境下,content间不需要空格分隔
            elif span_type == ContentType.InlineEquation:
                para_text += f' {content} '
            else:
                if span_type in [ContentType.Text, ContentType.InlineEquation]:
                    # 如果span是line的最后一个且末尾带有-连字符,那么末尾不应该加空格,同时应该把-删除
@@ -172,7 +171,7 @@ def merge_para_with_text(para_block):
                    elif len(content) == 1 and content not in ['A', 'I', 'a', 'i'] and not content.isdigit():
                        para_text += content
                    else:  # 西方文本语境下 content间需要空格分隔
                        para_text += f'{content} '
                elif span_type == ContentType.InterlineEquation:
                    para_text += content
                else:
...
""" """输入: s3路径,每行一个 输出: pdf文件元信息,包括每一页上的所有图片的长宽高,bbox位置."""
输入: s3路径,每行一个
输出: pdf文件元信息,包括每一页上的所有图片的长宽高,bbox位置
"""
import sys import sys
import click from collections import Counter
from magic_pdf.libs.commons import read_file, mymax, get_top_percent_list import click
from magic_pdf.libs.commons import fitz
from loguru import logger from loguru import logger
from collections import Counter
from magic_pdf.libs.drop_reason import DropReason from magic_pdf.config.drop_reason import DropReason
from magic_pdf.libs.commons import fitz, get_top_percent_list, mymax, read_file
from magic_pdf.libs.language import detect_lang from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.pdf_check import detect_invalid_chars from magic_pdf.libs.pdf_check import detect_invalid_chars
...@@ -19,8 +16,10 @@ junk_limit_min = 10 ...@@ -19,8 +16,10 @@ junk_limit_min = 10
def calculate_max_image_area_per_page(result: list, page_width_pts, page_height_pts): def calculate_max_image_area_per_page(result: list, page_width_pts, page_height_pts):
max_image_area_per_page = [mymax([(x1 - x0) * (y1 - y0) for x0, y0, x1, y1, _ in page_img_sz]) for page_img_sz in max_image_area_per_page = [
result] mymax([(x1 - x0) * (y1 - y0) for x0, y0, x1, y1, _ in page_img_sz])
for page_img_sz in result
]
page_area = int(page_width_pts) * int(page_height_pts) page_area = int(page_width_pts) * int(page_height_pts)
max_image_area_per_page = [area / page_area for area in max_image_area_per_page] max_image_area_per_page = [area / page_area for area in max_image_area_per_page]
max_image_area_per_page = [area for area in max_image_area_per_page if area > 0.6] max_image_area_per_page = [area for area in max_image_area_per_page if area > 0.6]
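As a rough sense of the 0.6 cut-off above: on a 612 x 792 pt page (484,704 pt², a typical Letter-size example), only pages whose largest displayed image covers more than about 290,800 pt² survive the filter, so max_image_area_per_page ends up counting only the sampled pages dominated by a single large image.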
@@ -33,7 +32,9 @@ def process_image(page, junk_img_bojids=[]):
    dedup = set()
    for img in items:
        # 这里返回的是图片在page上的实际展示的大小。返回一个数组,每个元素第一部分是
        img_bojid = img[
            0
        ]  # 在pdf文件中是全局唯一的,如果这个图反复出现在pdf里那么就可能是垃圾信息,例如水印、页眉页脚等
        if img_bojid in junk_img_bojids:  # 如果是垃圾图像,就跳过
            continue
        recs = page.get_image_rects(img, transform=True)
@@ -42,9 +43,17 @@ def process_image(page, junk_img_bojids=[]):
            x0, y0, x1, y1 = map(int, rec)
            width = x1 - x0
            height = y1 - y0
            if (
                x0,
                y0,
                x1,
                y1,
                img_bojid,
            ) in dedup:  # 这里面会出现一些重复的bbox,无需重复出现,需要去掉
                continue
            if not all(
                [width, height]
            ):  # 长和宽任何一个都不能是0,否则这个图片不可见,没有实际意义
                continue
            dedup.add((x0, y0, x1, y1, img_bojid))
            page_result.append([x0, y0, x1, y1, img_bojid])
@@ -52,8 +61,8 @@ def process_image(page, junk_img_bojids=[]):
def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
    """返回每个页面里的图片的四元组,每个页面多个图片。

    :param doc:
    :return:
    """
@@ -63,13 +72,17 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
    junk_limit = max(len(doc) * 0.5, junk_limit_min)  # 对一些页数比较少的进行豁免

    junk_img_bojids = [
        img_bojid
        for img_bojid, count in img_bojid_counter.items()
        if count >= junk_limit
    ]

    # todo 加个判断,用前十页就行,这些垃圾图片需要满足两个条件,不止出现的次数要足够多,而且图片占书页面积的比例要足够大,且图与图大小都差不多
    # 有两种扫描版,一种文字版,这里可能会有误判
    # 扫描版1:每页都有所有扫描页图片,特点是图占比大,每页展示1张
    # 扫描版2,每页存储的扫描页图片数量递增,特点是图占比大,每页展示1张,需要清空junklist跑前50页图片信息用于分类判断
    # 文字版1.每页存储所有图片,特点是图片占页面比例不大,每页展示可能为0也可能不止1张 这种pdf需要拿前10页抽样检测img大小和个数,如果符合需要清空junklist
    imgs_len_list = [len(page.get_images()) for page in doc]
    special_limit_pages = 10
@@ -82,12 +95,18 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
            break
        if i >= special_limit_pages:
            break
        page_result = process_image(
            page
        )  # 这里不传junk_img_bojids,拿前十页所有图片信息用于后续分析
        result.append(page_result)
        for item in result:
            if not any(
                item
            ):  # 如果任何一页没有图片,说明是个文字版,需要判断是否为特殊文字版
                if (
                    max(imgs_len_list) == min(imgs_len_list)
                    and max(imgs_len_list) >= junk_limit_min
                ):  # 如果是特殊文字版,就把junklist置空并break
                    junk_img_bojids = []
                else:  # 不是特殊文字版,是个普通文字版,但是存在垃圾图片,不置空junklist
                    pass
@@ -98,20 +117,23 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
        top_eighty_percent = get_top_percent_list(imgs_len_list, 0.8)
        # 检查前80%的元素是否都相等
        if len(set(top_eighty_percent)) == 1 and max(imgs_len_list) >= junk_limit_min:

            # # 如果前10页跑完都有图,根据每页图片数量是否相等判断是否需要清除junklist
            # if max(imgs_len_list) == min(imgs_len_list) and max(imgs_len_list) >= junk_limit_min:

            # 前10页都有图,且每页数量一致,需要检测图片大小占页面的比例判断是否需要清除junklist
            max_image_area_per_page = calculate_max_image_area_per_page(
                result, page_width_pts, page_height_pts
            )
            if (
                len(max_image_area_per_page) < 0.8 * special_limit_pages
            ):  # 前10页不全是大图,说明可能是个文字版pdf,把垃圾图片list置空
                junk_img_bojids = []
            else:  # 前10页都有图,而且80%都是大图,且每页图片数量一致并都很多,说明是扫描版1,不需要清空junklist
                pass
        else:  # 每页图片数量不一致,需要清掉junklist全量跑前50页图片
            junk_img_bojids = []

    # 正式进入取前50页图片的信息流程
    result = []
    for i, page in enumerate(doc):
        if i >= scan_max_page:
@@ -126,7 +148,7 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:

def get_pdf_page_size_pts(doc: fitz.Document):
    page_cnt = len(doc)
    l: int = min(page_cnt, 50)
    # 把所有宽度和高度塞到两个list 分别取中位数(中间遇到了个在纵页里塞横页的pdf,导致宽高互换了)
    page_width_list = []
    page_height_list = []
    for i in range(l):
@@ -152,8 +174,8 @@ def get_pdf_textlen_per_page(doc: fitz.Document):
        # 拿所有text的blocks
        # text_block = page.get_text("words")
        # text_block_len = sum([len(t[4]) for t in text_block])
        # 拿所有text的str
        text_block = page.get_text('text')
        text_block_len = len(text_block)
        # logger.info(f"page {page.number} text_block_len: {text_block_len}")
        text_len_lst.append(text_block_len)
@@ -162,15 +184,13 @@ def get_pdf_textlen_per_page(doc: fitz.Document):

def get_pdf_text_layout_per_page(doc: fitz.Document):
    """根据PDF文档的每一页文本布局,判断该页的文本布局是横向、纵向还是未知。

    Args:
        doc (fitz.Document): PDF文档对象。

    Returns:
        List[str]: 每一页的文本布局(横向、纵向、未知)。
    """
    text_layout_list = []
@@ -180,11 +200,11 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
        # 创建每一页的纵向和横向的文本行数计数器
        vertical_count = 0
        horizontal_count = 0
        text_dict = page.get_text('dict')
        if 'blocks' in text_dict:
            for block in text_dict['blocks']:
                if 'lines' in block:
                    for line in block['lines']:
                        # 获取line的bbox顶点坐标
                        x0, y0, x1, y1 = line['bbox']
                        # 计算bbox的宽高
@@ -199,8 +219,12 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
                        if len(font_sizes) > 0:
                            average_font_size = sum(font_sizes) / len(font_sizes)
                        else:
                            average_font_size = (
                                10  # 有的line拿不到font_size,先定一个阈值100
                            )
                        if (
                            area <= average_font_size**2
                        ):  # 判断bbox的面积是否小于平均字体大小的平方,单字无法计算是横向还是纵向
                            continue
                        else:
                            if 'wmode' in line:  # 通过wmode判断文本方向
@@ -228,22 +252,22 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
        # print(f"page_id: {page_id}, vertical_count: {vertical_count}, horizontal_count: {horizontal_count}")
        # 判断每一页的文本布局
        if vertical_count == 0 and horizontal_count == 0:  # 该页没有文本,无法判断
            text_layout_list.append('unknow')
            continue
        else:
            if vertical_count > horizontal_count:  # 该页的文本纵向行数大于横向的
                text_layout_list.append('vertical')
            else:  # 该页的文本横向行数大于纵向的
                text_layout_list.append('horizontal')
        # logger.info(f"page_id: {page_id}, vertical_count: {vertical_count}, horizontal_count: {horizontal_count}")
    return text_layout_list


"""定义一个自定义异常用来抛出单页svg太多的pdf"""


class PageSvgsTooManyError(Exception):
    def __init__(self, message='Page SVGs are too many'):
        self.message = message
        super().__init__(self.message)
@@ -285,7 +309,7 @@ def get_language(doc: fitz.Document):
        if page_id >= scan_max_page:
            break
        # 拿所有text的str
        text_block = page.get_text('text')
        page_language = detect_lang(text_block)
        language_lst.append(page_language)
@@ -299,9 +323,7 @@ def get_language(doc: fitz.Document):

def check_invalid_chars(pdf_bytes):
    """乱码检测."""
    return detect_invalid_chars(pdf_bytes)
@@ -311,13 +333,13 @@ def pdf_meta_scan(pdf_bytes: bytes):
    :param pdf_bytes: pdf文件的二进制数据
    几个维度来评价:是否加密,是否需要密码,纸张大小,总页数,是否文字可提取
    """
    doc = fitz.open('pdf', pdf_bytes)
    is_needs_password = doc.needs_pass
    is_encrypted = doc.is_encrypted
    total_page = len(doc)
    if total_page == 0:
        logger.warning(f'drop this pdf, drop_reason: {DropReason.EMPTY_PDF}')
        result = {'_need_drop': True, '_drop_reason': DropReason.EMPTY_PDF}
        return result
    else:
        page_width_pts, page_height_pts = get_pdf_page_size_pts(doc)
@@ -328,7 +350,9 @@ def pdf_meta_scan(pdf_bytes: bytes):
        imgs_per_page = get_imgs_per_page(doc)
        # logger.info(f"imgs_per_page: {imgs_per_page}")

        image_info_per_page, junk_img_bojids = get_image_info(
            doc, page_width_pts, page_height_pts
        )
        # logger.info(f"image_info_per_page: {image_info_per_page}, junk_img_bojids: {junk_img_bojids}")
        text_len_per_page = get_pdf_textlen_per_page(doc)
        # logger.info(f"text_len_per_page: {text_len_per_page}")
@@ -341,20 +365,20 @@ def pdf_meta_scan(pdf_bytes: bytes):
        # 最后输出一条json
        res = {
            'is_needs_password': is_needs_password,
            'is_encrypted': is_encrypted,
            'total_page': total_page,
            'page_width_pts': int(page_width_pts),
            'page_height_pts': int(page_height_pts),
            'image_info_per_page': image_info_per_page,
            'text_len_per_page': text_len_per_page,
            'text_layout_per_page': text_layout_per_page,
            'text_language': text_language,
            # "svgs_per_page": svgs_per_page,
            'imgs_per_page': imgs_per_page,  # 增加每页img数量list
            'junk_img_bojids': junk_img_bojids,  # 增加垃圾图片的bojid list
            'invalid_chars': invalid_chars,
            'metadata': doc.metadata,
        }
        # logger.info(json.dumps(res, ensure_ascii=False))
        return res
@@ -364,14 +388,12 @@ def pdf_meta_scan(pdf_bytes: bytes):
@click.option('--s3-pdf-path', help='s3上pdf文件的路径')
@click.option('--s3-profile', help='s3上的profile')
def main(s3_pdf_path: str, s3_profile: str):
    """"""
    try:
        file_content = read_file(s3_pdf_path, s3_profile)
        pdf_meta_scan(file_content)
    except Exception as e:
        print(f'ERROR: {s3_pdf_path}, {e}', file=sys.stderr)
        logger.exception(e)
@@ -381,7 +403,7 @@ if __name__ == '__main__':
    # "D:\project/20231108code-clean\pdf_cost_time\竖排例子\三国演义_繁体竖排版.pdf"
    # "D:\project/20231108code-clean\pdf_cost_time\scihub\scihub_86800000\libgen.scimag86880000-86880999.zip_10.1021/acsami.1c03109.s002.pdf"
    # "D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_18600000/libgen.scimag18645000-18645999.zip_10.1021/om3006239.pdf"
    # file_content = read_file("D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_31000000/libgen.scimag31098000-31098999.zip_10.1109/isit.2006.261791.pdf","")  # noqa: E501
    # file_content = read_file("D:\project/20231108code-clean\pdf_cost_time\竖排例子\净空法师_大乘无量寿.pdf","")
    # doc = fitz.open("pdf", file_content)
    # text_layout_lst = get_pdf_text_layout_per_page(doc)
...
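For a quick local check, pdf_meta_scan only needs the raw PDF bytes; a minimal sketch (the import path and file name here are assumptions for illustration, not part of this diff):

from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan  # assumed module path

with open('sample.pdf', 'rb') as f:  # hypothetical local file
    meta = pdf_meta_scan(f.read())
print(meta['total_page'], meta['text_language'], meta['text_layout_per_page'])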
@@ -5,14 +5,13 @@ from pathlib import Path

from loguru import logger

import magic_pdf.model as model_config
from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.data.data_reader_writer import FileBasedDataReader
from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
from magic_pdf.integrations.rag.type import (CategoryType, ContentObject,
                                             ElementRelation, ElementRelType,
                                             LayoutElements,
                                             LayoutElementsExtra, PageInfo)
from magic_pdf.tools.common import do_parse, prepare_env
@@ -224,8 +223,8 @@ def inference(path, output_dir, method):
                str(Path(path).stem), method)

    def read_fn(path):
        disk_rw = FileBasedDataReader(os.path.dirname(path))
        return disk_rw.read(os.path.basename(path))

    def parse_doc(doc_path: str):
        try:
...
class MakeMode:
    MM_MD = "mm_markdown"
    NLP_MD = "nlp_markdown"
    STANDARD_FORMAT = "standard_format"


class DropMode:
    WHOLE_PDF = "whole_pdf"
    SINGLE_PAGE = "single_page"
    NONE = "none"
    NONE_WITH_REASON = "none_with_reason"
@@ -5,7 +5,7 @@ import os

from loguru import logger

from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.libs.commons import parse_bucket_key

# 定义配置文件名常量
@@ -99,7 +99,7 @@ def get_table_recog_config():

def get_layout_config():
    config = read_config()
    layout_config = config.get('layout-config')
    if layout_config is None:
        logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
        return json.loads(f'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}')
@@ -109,7 +109,7 @@ def get_layout_config():

def get_formula_config():
    config = read_config()
    formula_config = config.get('formula-config')
    if formula_config is None:
        logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
        return json.loads(f'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}')
@@ -117,5 +117,5 @@ def get_formula_config():
    return formula_config


if __name__ == '__main__':
    ak, sk, endpoint = get_s3_config('llm-raw')
from magic_pdf.config.constants import CROSS_PAGE
from magic_pdf.config.ocr_content_type import (BlockType, CategoryId,
                                               ContentType)
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.libs.commons import fitz  # PyMuPDF
from magic_pdf.model.magic_model import MagicModel
...
class DropReason:
    TEXT_BLCOK_HOR_OVERLAP = "text_block_horizontal_overlap"  # 文字块有水平互相覆盖,导致无法准确定位文字顺序
    USEFUL_BLOCK_HOR_OVERLAP = "useful_block_horizontal_overlap"  # 需保留的block水平覆盖
    COMPLICATED_LAYOUT = "complicated_layout"  # 复杂的布局,暂时不支持
    TOO_MANY_LAYOUT_COLUMNS = "too_many_layout_columns"  # 目前不支持分栏超过2列的
    COLOR_BACKGROUND_TEXT_BOX = "color_background_text_box"  # 含有带色块的PDF,色块会改变阅读顺序,目前不支持带底色文字块的PDF。
    HIGH_COMPUTATIONAL_lOAD_BY_IMGS = "high_computational_load_by_imgs"  # 含特殊图片,计算量太大,从而丢弃
    HIGH_COMPUTATIONAL_lOAD_BY_SVGS = "high_computational_load_by_svgs"  # 特殊的SVG图,计算量太大,从而丢弃
    HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES = "high_computational_load_by_total_pages"  # 计算量超过负荷,当前方法下计算量消耗过大
    MISS_DOC_LAYOUT_RESULT = "missing doc_layout_result"  # 版面分析失败
    Exception = "_exception"  # 解析中发生异常
    ENCRYPTED = "encrypted"  # PDF是加密的
    EMPTY_PDF = "total_page=0"  # PDF页面总数为0
    NOT_IS_TEXT_PDF = "not_is_text_pdf"  # 不是文字版PDF,无法直接解析
    DENSE_SINGLE_LINE_BLOCK = "dense_single_line_block"  # 无法清晰的分段
    TITLE_DETECTION_FAILED = "title_detection_failed"  # 探测标题失败
    TITLE_LEVEL_FAILED = "title_level_failed"  # 分析标题级别失败(例如一级、二级、三级标题)
    PARA_SPLIT_FAILED = "para_split_failed"  # 识别段落失败
    PARA_MERGE_FAILED = "para_merge_failed"  # 段落合并失败
    NOT_ALLOW_LANGUAGE = "not_allow_language"  # 不支持的语种
    SPECIAL_PDF = "special_pdf"
    PSEUDO_SINGLE_COLUMN = "pseudo_single_column"  # 无法精确判断文字分栏
    CAN_NOT_DETECT_PAGE_LAYOUT = "can_not_detect_page_layout"  # 无法分析页面的版面
    NEGATIVE_BBOX_AREA = "negative_bbox_area"  # 缩放导致 bbox 面积为负
    OVERLAP_BLOCKS_CAN_NOT_SEPARATION = "overlap_blocks_can_t_separation"  # 无法分离重叠的block


COLOR_BG_HEADER_TXT_BLOCK = "color_background_header_txt_block"
PAGE_NO = "page-no"  # 页码
CONTENT_IN_FOOT_OR_HEADER = 'in-foot-header-area'  # 页眉页脚内的文本
VERTICAL_TEXT = 'vertical-text'  # 垂直文本
ROTATE_TEXT = 'rotate-text'  # 旋转文本
EMPTY_SIDE_BLOCK = 'empty-side-block'  # 边缘上的空白没有任何内容的block
ON_IMAGE_TEXT = 'on-image-text'  # 文本在图片上
ON_TABLE_TEXT = 'on-table-text'  # 文本在表格上


class DropTag:
    PAGE_NUMBER = "page_no"
    HEADER = "header"
    FOOTER = "footer"
    FOOTNOTE = "footnote"
    NOT_IN_LAYOUT = "not_in_layout"
    SPAN_OVERLAP = "span_overlap"
    BLOCK_OVERLAP = "block_overlap"
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.libs.commons import fitz, join_path
from magic_pdf.libs.hash_utils import compute_sha256


def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWriter: DataWriter):
    """从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径 save_path:需要同时支持s3和本地,
    图片存放在save_path下,文件名是:
    {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。"""
    # 拼接文件名
    filename = f'{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}'
    # 老版本返回不带bucket的路径
    img_path = join_path(return_path, filename) if return_path is not None else None
    # 新版本生成平铺路径
    img_hash256_path = f'{compute_sha256(img_path)}.jpg'
    # 将坐标转换为fitz.Rect对象
    rect = fitz.Rect(*bbox)
@@ -28,6 +26,6 @@ def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWri
    byte_data = pix.tobytes(output='jpeg', jpg_quality=95)

    imageWriter.write(img_hash256_path, byte_data)

    return img_hash256_path
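A small standalone sketch of the flat naming scheme used above, assuming compute_sha256 returns a hex digest of the joined path (hashlib is used directly here so the snippet runs on its own):

import hashlib

page_num, bbox = 3, (100.2, 200.7, 300.0, 400.5)
filename = f'{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}'
img_path = 'images/' + filename  # stand-in for join_path(return_path, filename)
img_hash256_path = hashlib.sha256(img_path.encode()).hexdigest() + '.jpg'
print(img_hash256_path)  # 64 hex characters followed by '.jpg'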
import enum
import json

from magic_pdf.config.model_block_type import ModelBlockTypeEnum
from magic_pdf.config.ocr_content_type import CategoryId, ContentType
from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
                                               FileBasedDataWriter)
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
                                    bbox_relative_pos, box_area, calculate_iou,
@@ -9,11 +13,7 @@ from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
from magic_pdf.libs.commons import fitz, join_path
from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.local_math import float_gt
from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox

CAPATION_OVERLAP_AREA_RATIO = 0.6
MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
@@ -1050,27 +1050,27 @@ class MagicModel:

if __name__ == '__main__':
    drw = FileBasedDataReader(r'D:/project/20231108code-clean')
    if 0:
        pdf_file_path = r'linshixuqiu\19983-00.pdf'
        model_file_path = r'linshixuqiu\19983-00_new.json'
        pdf_bytes = drw.read(pdf_file_path)
        model_json_txt = drw.read(model_file_path).decode()
        model_list = json.loads(model_json_txt)
        write_path = r'D:\project\20231108code-clean\linshixuqiu\19983-00'
        img_bucket_path = 'imgs'
        img_writer = FileBasedDataWriter(join_path(write_path, img_bucket_path))
        pdf_docs = fitz.open('pdf', pdf_bytes)
        magic_model = MagicModel(model_list, pdf_docs)

    if 1:
        from magic_pdf.data.dataset import PymuDocDataset

        model_list = json.loads(
            drw.read('/opt/data/pdf/20240418/j.chroma.2009.03.042.json')
        )
        pdf_bytes = drw.read('/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf')

        magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))

        for i in range(7):
            print(magic_model.get_imgs(i))
# flake8: noqa
import os
import time

import cv2
import numpy as np
import torch
import yaml
from loguru import logger
from PIL import Image

os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
@@ -13,16 +15,18 @@ os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger

try:
    import torchtext

    if torchtext.__version__ >= '0.18.0':
        torchtext.disable_torchtext_deprecation_warning()
except ImportError:
    pass

from magic_pdf.config.constants import *
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import (
    clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
    get_adjusted_mfdetrec_res, get_ocr_result_list)
class CustomPEKModel:
@@ -41,42 +45,54 @@ class CustomPEKModel:
        model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
        # 构建 model_configs.yaml 文件的完整路径
        config_path = os.path.join(model_config_dir, 'model_configs.yaml')
        with open(config_path, 'r', encoding='utf-8') as f:
            self.configs = yaml.load(f, Loader=yaml.FullLoader)

        # 初始化解析配置

        # layout config
        self.layout_config = kwargs.get('layout_config')
        self.layout_model_name = self.layout_config.get(
            'model', MODEL_NAME.DocLayout_YOLO
        )

        # formula config
        self.formula_config = kwargs.get('formula_config')
        self.mfd_model_name = self.formula_config.get(
            'mfd_model', MODEL_NAME.YOLO_V8_MFD
        )
        self.mfr_model_name = self.formula_config.get(
            'mfr_model', MODEL_NAME.UniMerNet_v2_Small
        )
        self.apply_formula = self.formula_config.get('enable', True)

        # table config
        self.table_config = kwargs.get('table_config')
        self.apply_table = self.table_config.get('enable', False)
        self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
        self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)

        # ocr config
        self.apply_ocr = ocr
        self.lang = kwargs.get('lang', None)

        logger.info(
            'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
            'apply_table: {}, table_model: {}, lang: {}'.format(
                self.layout_model_name,
                self.apply_formula,
                self.apply_ocr,
                self.apply_table,
                self.table_model_name,
                self.lang,
            )
        )
        # 初始化解析方案
        self.device = kwargs.get('device', 'cpu')
        logger.info('using device: {}'.format(self.device))
        models_dir = kwargs.get(
            'models_dir', os.path.join(root_dir, 'resources', 'models')
        )
        logger.info('using models_dir: {}'.format(models_dir))

        atom_model_manager = AtomModelSingleton()
@@ -85,18 +101,24 @@ class CustomPEKModel:
            # 初始化公式检测模型
            self.mfd_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFD,
                mfd_weights=str(
                    os.path.join(
                        models_dir, self.configs['weights'][self.mfd_model_name]
                    )
                ),
                device=self.device,
            )

            # 初始化公式解析模型
            mfr_weight_dir = str(
                os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
            )
            mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
            self.mfr_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFR,
                mfr_weight_dir=mfr_weight_dir,
                mfr_cfg_path=mfr_cfg_path,
                device=self.device,
            )

        # 初始化layout模型
@@ -104,16 +126,28 @@ class CustomPEKModel:
            self.layout_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                layout_model_name=MODEL_NAME.LAYOUTLMv3,
                layout_weights=str(
                    os.path.join(
                        models_dir, self.configs['weights'][self.layout_model_name]
                    )
                ),
                layout_config_file=str(
                    os.path.join(
                        model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
                    )
                ),
                device=self.device,
            )
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            self.layout_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                layout_model_name=MODEL_NAME.DocLayout_YOLO,
                doclayout_yolo_weights=str(
                    os.path.join(
                        models_dir, self.configs['weights'][self.layout_model_name]
                    )
                ),
                device=self.device,
            )
        # 初始化ocr
        self.ocr_model = atom_model_manager.get_atom_model(
@@ -124,19 +158,18 @@ class CustomPEKModel:
        )
        # init table model
        if self.apply_table:
            table_model_dir = self.configs['weights'][self.table_model_name]
            self.table_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Table,
                table_model_name=self.table_model_name,
                table_model_path=str(os.path.join(models_dir, table_model_dir)),
                table_max_time=self.table_max_time,
                device=self.device,
            )

        logger.info('DocAnalysis init done!')
    def __call__(self, image):

        page_start = time.time()

        # layout检测
@@ -149,7 +182,7 @@ class CustomPEKModel:
            # doclayout_yolo
            layout_res = self.layout_model.predict(image)

        layout_cost = round(time.time() - layout_start, 2)
        logger.info(f'layout detection time: {layout_cost}')

        pil_img = Image.fromarray(image)
@@ -157,20 +190,22 @@ class CustomPEKModel:
            # 公式检测
            mfd_start = time.time()
            mfd_res = self.mfd_model.predict(image)
            logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')

            # 公式识别
            mfr_start = time.time()
            formula_list = self.mfr_model.predict(mfd_res, image)
            layout_res.extend(formula_list)
            mfr_cost = round(time.time() - mfr_start, 2)
            logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')

        # 清理显存
        clean_vram(self.device, vram_threshold=8)

        # 从layout_res中获取ocr区域、表格区域、公式区域
        ocr_res_list, table_res_list, single_page_mfdetrec_res = (
            get_res_list_from_layout_res(layout_res)
        )

        # ocr识别
        ocr_start = time.time()
@@ -206,27 +241,37 @@ class CustomPEKModel:
                html_code = None
                if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                    with torch.no_grad():
                        table_result = self.table_model.predict(new_image, 'html')
                        if len(table_result) > 0:
                            html_code = table_result[0]
                elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
                    html_code = self.table_model.img2html(new_image)
                elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
                    html_code, table_cell_bboxes, elapse = self.table_model.predict(
                        new_image
                    )
                run_time = time.time() - single_table_start_time
                if run_time > self.table_max_time:
                    logger.warning(
                        f'table recognition processing exceeds max time {self.table_max_time}s'
                    )
                # 判断是否返回正常
                if html_code:
                    expected_ending = html_code.strip().endswith(
                        '</html>'
                    ) or html_code.strip().endswith('</table>')
                    if expected_ending:
                        res['html'] = html_code
                    else:
                        logger.warning(
                            'table recognition processing fails, not found expected HTML table end'
                        )
                else:
                    logger.warning(
                        'table recognition processing fails, not get html return'
                    )
            logger.info(f'table time: {round(time.time() - table_start, 2)}')

        logger.info(f'-----page total time: {round(time.time() - page_start, 2)}-----')
        return layout_res
from loguru import logger

from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
    DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
    Layoutlmv3_Predictor
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
    ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
    RapidTableModel
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
    StructTableModel
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
    TableMasterPaddleModel


def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
@@ -19,14 +24,14 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
        table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
    elif table_model_type == MODEL_NAME.TABLE_MASTER:
        config = {
            'model_dir': model_path,
            'device': _device_
        }
        table_model = TableMasterPaddleModel(config)
    elif table_model_type == MODEL_NAME.RAPID_TABLE:
        table_model = RapidTableModel()
    else:
        logger.error('table model type not allow')
        exit(1)

    return table_model
@@ -87,8 +92,8 @@ class AtomModelSingleton:
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        lang = kwargs.get('lang', None)
        layout_model_name = kwargs.get('layout_model_name', None)
        key = (atom_model_name, layout_model_name, lang)
        if key not in self._models:
            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
@@ -98,47 +103,47 @@ class AtomModelSingleton:

def atom_model_init(model_name: str, **kwargs):
    atom_model = None
    if model_name == AtomicModel.Layout:
        if kwargs.get('layout_model_name') == MODEL_NAME.LAYOUTLMv3:
            atom_model = layout_model_init(
                kwargs.get('layout_weights'),
                kwargs.get('layout_config_file'),
                kwargs.get('device')
            )
        elif kwargs.get('layout_model_name') == MODEL_NAME.DocLayout_YOLO:
            atom_model = doclayout_yolo_model_init(
                kwargs.get('doclayout_yolo_weights'),
                kwargs.get('device')
            )
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get('mfd_weights'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get('mfr_weight_dir'),
            kwargs.get('mfr_cfg_path'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.OCR:
        atom_model = ocr_model_init(
            kwargs.get('ocr_show_log'),
            kwargs.get('det_db_box_thresh'),
            kwargs.get('lang')
        )
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get('table_model_name'),
            kwargs.get('table_model_path'),
            kwargs.get('table_max_time'),
            kwargs.get('device')
        )
    else:
        logger.error('model name not allow')
        exit(1)

    if atom_model is None:
        logger.error('model init failed')
        exit(1)
    else:
        return atom_model
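A hedged usage sketch of the singleton above: models are cached by the (atom_model_name, layout_model_name, lang) key, so repeated requests reuse one instance. The keyword names follow the AtomicModel.OCR branch of atom_model_init; the argument values are illustrative only:

atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
    atom_model_name=AtomicModel.OCR,
    ocr_show_log=False,       # illustrative settings
    det_db_box_thresh=0.3,
    lang=None,
)
# a second call with the same key returns the cached instance instead of re-initializing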