Unverified Commit 6c8f5638 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #1027 from icecraft/refactor/move_defs

refactor: move some constants or enums defs to config folder
parents bc992433 b492c19c
""" """span维度自定义字段."""
span维度自定义字段
"""
# span是否是跨页合并的 # span是否是跨页合并的
CROSS_PAGE = "cross_page" CROSS_PAGE = 'cross_page'
""" """
block维度自定义字段 block维度自定义字段
""" """
# block中lines是否被删除 # block中lines是否被删除
LINES_DELETED = "lines_deleted" LINES_DELETED = 'lines_deleted'
# table recognition max time default value # table recognition max time default value
TABLE_MAX_TIME_VALUE = 400 TABLE_MAX_TIME_VALUE = 400
...@@ -17,39 +15,39 @@ TABLE_MAX_TIME_VALUE = 400 ...@@ -17,39 +15,39 @@ TABLE_MAX_TIME_VALUE = 400
TABLE_MAX_LEN = 480 TABLE_MAX_LEN = 480
# table master structure dict # table master structure dict
TABLE_MASTER_DICT = "table_master_structure_dict.txt" TABLE_MASTER_DICT = 'table_master_structure_dict.txt'
# table master dir # table master dir
TABLE_MASTER_DIR = "table_structure_tablemaster_infer/" TABLE_MASTER_DIR = 'table_structure_tablemaster_infer/'
# pp detect model dir # pp detect model dir
DETECT_MODEL_DIR = "ch_PP-OCRv4_det_infer" DETECT_MODEL_DIR = 'ch_PP-OCRv4_det_infer'
# pp rec model dir # pp rec model dir
REC_MODEL_DIR = "ch_PP-OCRv4_rec_infer" REC_MODEL_DIR = 'ch_PP-OCRv4_rec_infer'
# pp rec char dict path # pp rec char dict path
REC_CHAR_DICT = "ppocr_keys_v1.txt" REC_CHAR_DICT = 'ppocr_keys_v1.txt'
# pp rec copy rec directory # pp rec copy rec directory
PP_REC_DIRECTORY = ".paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer" PP_REC_DIRECTORY = '.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer'
# pp rec copy det directory # pp rec copy det directory
PP_DET_DIRECTORY = ".paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer" PP_DET_DIRECTORY = '.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer'
class MODEL_NAME: class MODEL_NAME:
# pp table structure algorithm # pp table structure algorithm
TABLE_MASTER = "tablemaster" TABLE_MASTER = 'tablemaster'
# struct eqtable # struct eqtable
STRUCT_EQTABLE = "struct_eqtable" STRUCT_EQTABLE = 'struct_eqtable'
DocLayout_YOLO = "doclayout_yolo" DocLayout_YOLO = 'doclayout_yolo'
LAYOUTLMv3 = "layoutlmv3" LAYOUTLMv3 = 'layoutlmv3'
YOLO_V8_MFD = "yolo_v8_mfd" YOLO_V8_MFD = 'yolo_v8_mfd'
UniMerNet_v2_Small = "unimernet_small" UniMerNet_v2_Small = 'unimernet_small'
RAPID_TABLE = "rapid_table" RAPID_TABLE = 'rapid_table'
\ No newline at end of file
class DropReason:
TEXT_BLCOK_HOR_OVERLAP = 'text_block_horizontal_overlap' # 文字块有水平互相覆盖,导致无法准确定位文字顺序
USEFUL_BLOCK_HOR_OVERLAP = (
'useful_block_horizontal_overlap' # 需保留的block水平覆盖
)
COMPLICATED_LAYOUT = 'complicated_layout' # 复杂的布局,暂时不支持
TOO_MANY_LAYOUT_COLUMNS = 'too_many_layout_columns' # 目前不支持分栏超过2列的
COLOR_BACKGROUND_TEXT_BOX = 'color_background_text_box' # 含有带色块的PDF,色块会改变阅读顺序,目前不支持带底色文字块的PDF。
HIGH_COMPUTATIONAL_lOAD_BY_IMGS = (
'high_computational_load_by_imgs' # 含特殊图片,计算量太大,从而丢弃
)
HIGH_COMPUTATIONAL_lOAD_BY_SVGS = (
'high_computational_load_by_svgs' # 特殊的SVG图,计算量太大,从而丢弃
)
HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES = 'high_computational_load_by_total_pages' # 计算量超过负荷,当前方法下计算量消耗过大
MISS_DOC_LAYOUT_RESULT = 'missing doc_layout_result' # 版面分析失败
Exception = '_exception' # 解析中发生异常
ENCRYPTED = 'encrypted' # PDF是加密的
EMPTY_PDF = 'total_page=0' # PDF页面总数为0
NOT_IS_TEXT_PDF = 'not_is_text_pdf' # 不是文字版PDF,无法直接解析
DENSE_SINGLE_LINE_BLOCK = 'dense_single_line_block' # 无法清晰的分段
TITLE_DETECTION_FAILED = 'title_detection_failed' # 探测标题失败
TITLE_LEVEL_FAILED = (
'title_level_failed' # 分析标题级别失败(例如一级、二级、三级标题)
)
PARA_SPLIT_FAILED = 'para_split_failed' # 识别段落失败
PARA_MERGE_FAILED = 'para_merge_failed' # 段落合并失败
NOT_ALLOW_LANGUAGE = 'not_allow_language' # 不支持的语种
SPECIAL_PDF = 'special_pdf'
PSEUDO_SINGLE_COLUMN = 'pseudo_single_column' # 无法精确判断文字分栏
CAN_NOT_DETECT_PAGE_LAYOUT = 'can_not_detect_page_layout' # 无法分析页面的版面
NEGATIVE_BBOX_AREA = 'negative_bbox_area' # 缩放导致 bbox 面积为负
OVERLAP_BLOCKS_CAN_NOT_SEPARATION = (
'overlap_blocks_can_t_separation' # 无法分离重叠的block
)
COLOR_BG_HEADER_TXT_BLOCK = 'color_background_header_txt_block'
PAGE_NO = 'page-no' # 页码
CONTENT_IN_FOOT_OR_HEADER = 'in-foot-header-area' # 页眉页脚内的文本
VERTICAL_TEXT = 'vertical-text' # 垂直文本
ROTATE_TEXT = 'rotate-text' # 旋转文本
EMPTY_SIDE_BLOCK = 'empty-side-block' # 边缘上的空白没有任何内容的block
ON_IMAGE_TEXT = 'on-image-text' # 文本在图片上
ON_TABLE_TEXT = 'on-table-text' # 文本在表格上
class DropTag:
PAGE_NUMBER = 'page_no'
HEADER = 'header'
FOOTER = 'footer'
FOOTNOTE = 'footnote'
NOT_IN_LAYOUT = 'not_in_layout'
SPAN_OVERLAP = 'span_overlap'
BLOCK_OVERLAP = 'block_overlap'
class MakeMode:
MM_MD = 'mm_markdown'
NLP_MD = 'nlp_markdown'
STANDARD_FORMAT = 'standard_format'
class DropMode:
WHOLE_PDF = 'whole_pdf'
SINGLE_PAGE = 'single_page'
NONE = 'none'
NONE_WITH_REASON = 'none_with_reason'
from enum import Enum from enum import Enum
class ModelBlockTypeEnum(Enum): class ModelBlockTypeEnum(Enum):
TITLE = 0 TITLE = 0
PLAIN_TEXT = 1 PLAIN_TEXT = 1
......
import math import math
from loguru import logger from loguru import logger
from magic_pdf.libs.boxbase import find_bottom_nearest_text_bbox, find_top_nearest_text_bbox from magic_pdf.config.ocr_content_type import ContentType
from magic_pdf.libs.boxbase import (find_bottom_nearest_text_bbox,
find_top_nearest_text_bbox)
from magic_pdf.libs.commons import join_path from magic_pdf.libs.commons import join_path
from magic_pdf.libs.ocr_content_type import ContentType
TYPE_INLINE_EQUATION = ContentType.InlineEquation TYPE_INLINE_EQUATION = ContentType.InlineEquation
TYPE_INTERLINE_EQUATION = ContentType.InterlineEquation TYPE_INTERLINE_EQUATION = ContentType.InterlineEquation
...@@ -12,33 +14,30 @@ UNI_FORMAT_TEXT_TYPE = ['text', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6'] ...@@ -12,33 +14,30 @@ UNI_FORMAT_TEXT_TYPE = ['text', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6']
@DeprecationWarning @DeprecationWarning
def mk_nlp_markdown_1(para_dict: dict): def mk_nlp_markdown_1(para_dict: dict):
""" """对排序后的bboxes拼接内容."""
对排序后的bboxes拼接内容
"""
content_lst = [] content_lst = []
for _, page_info in para_dict.items(): for _, page_info in para_dict.items():
para_blocks = page_info.get("para_blocks") para_blocks = page_info.get('para_blocks')
if not para_blocks: if not para_blocks:
continue continue
for block in para_blocks: for block in para_blocks:
item = block["paras"] item = block['paras']
for _, p in item.items(): for _, p in item.items():
para_text = p["para_text"] para_text = p['para_text']
is_title = p["is_para_title"] is_title = p['is_para_title']
title_level = p['para_title_level'] title_level = p['para_title_level']
md_title_prefix = "#"*title_level md_title_prefix = '#' * title_level
if is_title: if is_title:
content_lst.append(f"{md_title_prefix} {para_text}") content_lst.append(f'{md_title_prefix} {para_text}')
else: else:
content_lst.append(para_text) content_lst.append(para_text)
content_text = "\n\n".join(content_lst) content_text = '\n\n'.join(content_lst)
return content_text return content_text
# 找到目标字符串在段落中的索引 # 找到目标字符串在段落中的索引
def __find_index(paragraph, target): def __find_index(paragraph, target):
index = paragraph.find(target) index = paragraph.find(target)
...@@ -48,69 +47,76 @@ def __find_index(paragraph, target): ...@@ -48,69 +47,76 @@ def __find_index(paragraph, target):
return None return None
def __insert_string(paragraph, target, postion): def __insert_string(paragraph, target, position):
new_paragraph = paragraph[:postion] + target + paragraph[postion:] new_paragraph = paragraph[:position] + target + paragraph[position:]
return new_paragraph return new_paragraph
def __insert_after(content, image_content, target): def __insert_after(content, image_content, target):
""" """在content中找到target,将image_content插入到target后面."""
在content中找到target,将image_content插入到target后面
"""
index = content.find(target) index = content.find(target)
if index != -1: if index != -1:
content = content[:index+len(target)] + "\n\n" + image_content + "\n\n" + content[index+len(target):] content = (
content[: index + len(target)]
+ '\n\n'
+ image_content
+ '\n\n'
+ content[index + len(target) :]
)
else: else:
logger.error(f"Can't find the location of image {image_content} in the markdown file, search target is {target}") logger.error(
f"Can't find the location of image {image_content} in the markdown file, search target is {target}"
)
return content return content
def __insert_before(content, image_content, target): def __insert_before(content, image_content, target):
""" """在content中找到target,将image_content插入到target前面."""
在content中找到target,将image_content插入到target前面
"""
index = content.find(target) index = content.find(target)
if index != -1: if index != -1:
content = content[:index] + "\n\n" + image_content + "\n\n" + content[index:] content = content[:index] + '\n\n' + image_content + '\n\n' + content[index:]
else: else:
logger.error(f"Can't find the location of image {image_content} in the markdown file, search target is {target}") logger.error(
f"Can't find the location of image {image_content} in the markdown file, search target is {target}"
)
return content return content
@DeprecationWarning @DeprecationWarning
def mk_mm_markdown_1(para_dict: dict): def mk_mm_markdown_1(para_dict: dict):
"""拼装多模态markdown""" """拼装多模态markdown."""
content_lst = [] content_lst = []
for _, page_info in para_dict.items(): for _, page_info in para_dict.items():
page_lst = [] # 一个page内的段落列表 page_lst = [] # 一个page内的段落列表
para_blocks = page_info.get("para_blocks") para_blocks = page_info.get('para_blocks')
pymu_raw_blocks = page_info.get("preproc_blocks") pymu_raw_blocks = page_info.get('preproc_blocks')
all_page_images = [] all_page_images = []
all_page_images.extend(page_info.get("images",[])) all_page_images.extend(page_info.get('images', []))
all_page_images.extend(page_info.get("image_backup", []) ) all_page_images.extend(page_info.get('image_backup', []))
all_page_images.extend(page_info.get("tables",[])) all_page_images.extend(page_info.get('tables', []))
all_page_images.extend(page_info.get("table_backup",[]) ) all_page_images.extend(page_info.get('table_backup', []))
if not para_blocks or not pymu_raw_blocks: # 只有图片的拼接的场景 if not para_blocks or not pymu_raw_blocks: # 只有图片的拼接的场景
for img in all_page_images: for img in all_page_images:
page_lst.append(f"![]({img['image_path']})") # TODO 图片顺序 page_lst.append(f"![]({img['image_path']})") # TODO 图片顺序
page_md = "\n\n".join(page_lst) page_md = '\n\n'.join(page_lst)
else: else:
for block in para_blocks: for block in para_blocks:
item = block["paras"] item = block['paras']
for _, p in item.items(): for _, p in item.items():
para_text = p["para_text"] para_text = p['para_text']
is_title = p["is_para_title"] is_title = p['is_para_title']
title_level = p['para_title_level'] title_level = p['para_title_level']
md_title_prefix = "#"*title_level md_title_prefix = '#' * title_level
if is_title: if is_title:
page_lst.append(f"{md_title_prefix} {para_text}") page_lst.append(f'{md_title_prefix} {para_text}')
else: else:
page_lst.append(para_text) page_lst.append(para_text)
"""拼装成一个页面的文本""" """拼装成一个页面的文本"""
page_md = "\n\n".join(page_lst) page_md = '\n\n'.join(page_lst)
"""插入图片""" """插入图片"""
for img in all_page_images: for img in all_page_images:
imgbox = img['bbox'] imgbox = img['bbox']
...@@ -118,192 +124,215 @@ def mk_mm_markdown_1(para_dict: dict): ...@@ -118,192 +124,215 @@ def mk_mm_markdown_1(para_dict: dict):
# 先看在哪个block内 # 先看在哪个block内
for block in pymu_raw_blocks: for block in pymu_raw_blocks:
bbox = block['bbox'] bbox = block['bbox']
if bbox[0]-1 <= imgbox[0] < bbox[2]+1 and bbox[1]-1 <= imgbox[1] < bbox[3]+1:# 确定在block内 if (
for l in block['lines']: bbox[0] - 1 <= imgbox[0] < bbox[2] + 1
and bbox[1] - 1 <= imgbox[1] < bbox[3] + 1
): # 确定在block内
for l in block['lines']: # noqa: E741
line_box = l['bbox'] line_box = l['bbox']
if line_box[0]-1 <= imgbox[0] < line_box[2]+1 and line_box[1]-1 <= imgbox[1] < line_box[3]+1: # 在line内的,插入line前面 if (
line_txt = "".join([s['text'] for s in l['spans']]) line_box[0] - 1 <= imgbox[0] < line_box[2] + 1
page_md = __insert_before(page_md, img_content, line_txt) and line_box[1] - 1 <= imgbox[1] < line_box[3] + 1
): # 在line内的,插入line前面
line_txt = ''.join([s['text'] for s in l['spans']])
page_md = __insert_before(
page_md, img_content, line_txt
)
break break
break break
else:# 在行与行之间 else: # 在行与行之间
# 找到图片x0,y0与line的x0,y0最近的line # 找到图片x0,y0与line的x0,y0最近的line
min_distance = 100000 min_distance = 100000
min_line = None min_line = None
for l in block['lines']: for l in block['lines']: # noqa: E741
line_box = l['bbox'] line_box = l['bbox']
distance = math.sqrt((line_box[0] - imgbox[0])**2 + (line_box[1] - imgbox[1])**2) distance = math.sqrt(
(line_box[0] - imgbox[0]) ** 2
+ (line_box[1] - imgbox[1]) ** 2
)
if distance < min_distance: if distance < min_distance:
min_distance = distance min_distance = distance
min_line = l min_line = l
if min_line: if min_line:
line_txt = "".join([s['text'] for s in min_line['spans']]) line_txt = ''.join(
[s['text'] for s in min_line['spans']]
)
img_h = imgbox[3] - imgbox[1] img_h = imgbox[3] - imgbox[1]
if min_distance<img_h: # 文字在图片前面 if min_distance < img_h: # 文字在图片前面
page_md = __insert_after(page_md, img_content, line_txt) page_md = __insert_after(
page_md, img_content, line_txt
)
else: else:
page_md = __insert_before(page_md, img_content, line_txt) page_md = __insert_before(
page_md, img_content, line_txt
)
else: else:
logger.error(f"Can't find the location of image {img['image_path']} in the markdown file #1") logger.error(
else:# 应当在两个block之间 f"Can't find the location of image {img['image_path']} in the markdown file #1"
)
else: # 应当在两个block之间
# 找到上方最近的block,如果上方没有就找大下方最近的block # 找到上方最近的block,如果上方没有就找大下方最近的block
top_txt_block = find_top_nearest_text_bbox(pymu_raw_blocks, imgbox) top_txt_block = find_top_nearest_text_bbox(pymu_raw_blocks, imgbox)
if top_txt_block: if top_txt_block:
line_txt = "".join([s['text'] for s in top_txt_block['lines'][-1]['spans']]) line_txt = ''.join(
[s['text'] for s in top_txt_block['lines'][-1]['spans']]
)
page_md = __insert_after(page_md, img_content, line_txt) page_md = __insert_after(page_md, img_content, line_txt)
else: else:
bottom_txt_block = find_bottom_nearest_text_bbox(pymu_raw_blocks, imgbox) bottom_txt_block = find_bottom_nearest_text_bbox(
pymu_raw_blocks, imgbox
)
if bottom_txt_block: if bottom_txt_block:
line_txt = "".join([s['text'] for s in bottom_txt_block['lines'][0]['spans']]) line_txt = ''.join(
[
s['text']
for s in bottom_txt_block['lines'][0]['spans']
]
)
page_md = __insert_before(page_md, img_content, line_txt) page_md = __insert_before(page_md, img_content, line_txt)
else: else:
logger.error(f"Can't find the location of image {img['image_path']} in the markdown file #2") logger.error(
f"Can't find the location of image {img['image_path']} in the markdown file #2"
)
content_lst.append(page_md) content_lst.append(page_md)
"""拼装成全部页面的文本""" """拼装成全部页面的文本"""
content_text = "\n\n".join(content_lst) content_text = '\n\n'.join(content_lst)
return content_text return content_text
def __insert_after_para(text, type, element, content_list): def __insert_after_para(text, type, element, content_list):
""" """在content_list中找到text,将image_path作为一个新的node插入到text后面."""
在content_list中找到text,将image_path作为一个新的node插入到text后面
"""
for i, c in enumerate(content_list): for i, c in enumerate(content_list):
content_type = c.get("type") content_type = c.get('type')
if content_type in UNI_FORMAT_TEXT_TYPE and text in c.get("text", ''): if content_type in UNI_FORMAT_TEXT_TYPE and text in c.get('text', ''):
if type == "image": if type == 'image':
content_node = { content_node = {
"type": "image", 'type': 'image',
"img_path": element.get("image_path"), 'img_path': element.get('image_path'),
"img_alt": "", 'img_alt': '',
"img_title": "", 'img_title': '',
"img_caption": "", 'img_caption': '',
} }
elif type == "table": elif type == 'table':
content_node = { content_node = {
"type": "table", 'type': 'table',
"img_path": element.get("image_path"), 'img_path': element.get('image_path'),
"table_latex": element.get("text"), 'table_latex': element.get('text'),
"table_title": "", 'table_title': '',
"table_caption": "", 'table_caption': '',
"table_quality": element.get("quality"), 'table_quality': element.get('quality'),
} }
content_list.insert(i+1, content_node) content_list.insert(i + 1, content_node)
break break
else: else:
logger.error(f"Can't find the location of image {element.get('image_path')} in the markdown file, search target is {text}") logger.error(
f"Can't find the location of image {element.get('image_path')} in the markdown file, search target is {text}"
)
def __insert_before_para(text, type, element, content_list): def __insert_before_para(text, type, element, content_list):
""" """在content_list中找到text,将image_path作为一个新的node插入到text前面."""
在content_list中找到text,将image_path作为一个新的node插入到text前面
"""
for i, c in enumerate(content_list): for i, c in enumerate(content_list):
content_type = c.get("type") content_type = c.get('type')
if content_type in UNI_FORMAT_TEXT_TYPE and text in c.get("text", ''): if content_type in UNI_FORMAT_TEXT_TYPE and text in c.get('text', ''):
if type == "image": if type == 'image':
content_node = { content_node = {
"type": "image", 'type': 'image',
"img_path": element.get("image_path"), 'img_path': element.get('image_path'),
"img_alt": "", 'img_alt': '',
"img_title": "", 'img_title': '',
"img_caption": "", 'img_caption': '',
} }
elif type == "table": elif type == 'table':
content_node = { content_node = {
"type": "table", 'type': 'table',
"img_path": element.get("image_path"), 'img_path': element.get('image_path'),
"table_latex": element.get("text"), 'table_latex': element.get('text'),
"table_title": "", 'table_title': '',
"table_caption": "", 'table_caption': '',
"table_quality": element.get("quality"), 'table_quality': element.get('quality'),
} }
content_list.insert(i, content_node) content_list.insert(i, content_node)
break break
else: else:
logger.error(f"Can't find the location of image {element.get('image_path')} in the markdown file, search target is {text}") logger.error(
f"Can't find the location of image {element.get('image_path')} in the markdown file, search target is {text}"
)
def mk_universal_format(pdf_info_list: list, img_buket_path): def mk_universal_format(pdf_info_list: list, img_buket_path):
""" """构造统一格式 https://aicarrier.feishu.cn/wiki/FqmMwcH69iIdCWkkyjvcDwNUnTY."""
构造统一格式 https://aicarrier.feishu.cn/wiki/FqmMwcH69iIdCWkkyjvcDwNUnTY
"""
content_lst = [] content_lst = []
for page_info in pdf_info_list: for page_info in pdf_info_list:
page_lst = [] # 一个page内的段落列表 page_lst = [] # 一个page内的段落列表
para_blocks = page_info.get("para_blocks") para_blocks = page_info.get('para_blocks')
pymu_raw_blocks = page_info.get("preproc_blocks") pymu_raw_blocks = page_info.get('preproc_blocks')
all_page_images = [] all_page_images = []
all_page_images.extend(page_info.get("images",[])) all_page_images.extend(page_info.get('images', []))
all_page_images.extend(page_info.get("image_backup", []) ) all_page_images.extend(page_info.get('image_backup', []))
# all_page_images.extend(page_info.get("tables",[])) # all_page_images.extend(page_info.get("tables",[]))
# all_page_images.extend(page_info.get("table_backup",[]) ) # all_page_images.extend(page_info.get("table_backup",[]) )
all_page_tables = [] all_page_tables = []
all_page_tables.extend(page_info.get("tables", [])) all_page_tables.extend(page_info.get('tables', []))
if not para_blocks or not pymu_raw_blocks: # 只有图片的拼接的场景 if not para_blocks or not pymu_raw_blocks: # 只有图片的拼接的场景
for img in all_page_images: for img in all_page_images:
content_node = { content_node = {
"type": "image", 'type': 'image',
"img_path": join_path(img_buket_path, img['image_path']), 'img_path': join_path(img_buket_path, img['image_path']),
"img_alt":"", 'img_alt': '',
"img_title":"", 'img_title': '',
"img_caption":"" 'img_caption': '',
} }
page_lst.append(content_node) # TODO 图片顺序 page_lst.append(content_node) # TODO 图片顺序
for table in all_page_tables: for table in all_page_tables:
content_node = { content_node = {
"type": "table", 'type': 'table',
"img_path": join_path(img_buket_path, table['image_path']), 'img_path': join_path(img_buket_path, table['image_path']),
"table_latex": table.get("text"), 'table_latex': table.get('text'),
"table_title": "", 'table_title': '',
"table_caption": "", 'table_caption': '',
"table_quality": table.get("quality"), 'table_quality': table.get('quality'),
} }
page_lst.append(content_node) # TODO 图片顺序 page_lst.append(content_node) # TODO 图片顺序
else: else:
for block in para_blocks: for block in para_blocks:
item = block["paras"] item = block['paras']
for _, p in item.items(): for _, p in item.items():
font_type = p['para_font_type']# 对于文本来说,要么是普通文本,要么是个行间公式 font_type = p[
'para_font_type'
] # 对于文本来说,要么是普通文本,要么是个行间公式
if font_type == TYPE_INTERLINE_EQUATION: if font_type == TYPE_INTERLINE_EQUATION:
content_node = { content_node = {'type': 'equation', 'latex': p['para_text']}
"type": "equation",
"latex": p["para_text"]
}
page_lst.append(content_node) page_lst.append(content_node)
else: else:
para_text = p["para_text"] para_text = p['para_text']
is_title = p["is_para_title"] is_title = p['is_para_title']
title_level = p['para_title_level'] title_level = p['para_title_level']
if is_title: if is_title:
content_node = { content_node = {
"type": f"h{title_level}", 'type': f'h{title_level}',
"text": para_text 'text': para_text,
} }
page_lst.append(content_node) page_lst.append(content_node)
else: else:
content_node = { content_node = {'type': 'text', 'text': para_text}
"type": "text",
"text": para_text
}
page_lst.append(content_node) page_lst.append(content_node)
content_lst.extend(page_lst) content_lst.extend(page_lst)
"""插入图片""" """插入图片"""
for img in all_page_images: for img in all_page_images:
insert_img_or_table("image", img, pymu_raw_blocks, content_lst) insert_img_or_table('image', img, pymu_raw_blocks, content_lst)
"""插入表格""" """插入表格"""
for table in all_page_tables: for table in all_page_tables:
insert_img_or_table("table", table, pymu_raw_blocks, content_lst) insert_img_or_table('table', table, pymu_raw_blocks, content_lst)
# end for # end for
return content_lst return content_lst
...@@ -313,13 +342,17 @@ def insert_img_or_table(type, element, pymu_raw_blocks, content_lst): ...@@ -313,13 +342,17 @@ def insert_img_or_table(type, element, pymu_raw_blocks, content_lst):
# 先看在哪个block内 # 先看在哪个block内
for block in pymu_raw_blocks: for block in pymu_raw_blocks:
bbox = block['bbox'] bbox = block['bbox']
if bbox[0] - 1 <= element_bbox[0] < bbox[2] + 1 and bbox[1] - 1 <= element_bbox[1] < bbox[ if (
3] + 1: # 确定在这个大的block内,然后进入逐行比较距离 bbox[0] - 1 <= element_bbox[0] < bbox[2] + 1
for l in block['lines']: and bbox[1] - 1 <= element_bbox[1] < bbox[3] + 1
): # 确定在这个大的block内,然后进入逐行比较距离
for l in block['lines']: # noqa: E741
line_box = l['bbox'] line_box = l['bbox']
if line_box[0] - 1 <= element_bbox[0] < line_box[2] + 1 and line_box[1] - 1 <= element_bbox[1] < line_box[ if (
3] + 1: # 在line内的,插入line前面 line_box[0] - 1 <= element_bbox[0] < line_box[2] + 1
line_txt = "".join([s['text'] for s in l['spans']]) and line_box[1] - 1 <= element_bbox[1] < line_box[3] + 1
): # 在line内的,插入line前面
line_txt = ''.join([s['text'] for s in l['spans']])
__insert_before_para(line_txt, type, element, content_lst) __insert_before_para(line_txt, type, element, content_lst)
break break
break break
...@@ -327,14 +360,17 @@ def insert_img_or_table(type, element, pymu_raw_blocks, content_lst): ...@@ -327,14 +360,17 @@ def insert_img_or_table(type, element, pymu_raw_blocks, content_lst):
# 找到图片x0,y0与line的x0,y0最近的line # 找到图片x0,y0与line的x0,y0最近的line
min_distance = 100000 min_distance = 100000
min_line = None min_line = None
for l in block['lines']: for l in block['lines']: # noqa: E741
line_box = l['bbox'] line_box = l['bbox']
distance = math.sqrt((line_box[0] - element_bbox[0]) ** 2 + (line_box[1] - element_bbox[1]) ** 2) distance = math.sqrt(
(line_box[0] - element_bbox[0]) ** 2
+ (line_box[1] - element_bbox[1]) ** 2
)
if distance < min_distance: if distance < min_distance:
min_distance = distance min_distance = distance
min_line = l min_line = l
if min_line: if min_line:
line_txt = "".join([s['text'] for s in min_line['spans']]) line_txt = ''.join([s['text'] for s in min_line['spans']])
img_h = element_bbox[3] - element_bbox[1] img_h = element_bbox[3] - element_bbox[1]
if min_distance < img_h: # 文字在图片前面 if min_distance < img_h: # 文字在图片前面
__insert_after_para(line_txt, type, element, content_lst) __insert_after_para(line_txt, type, element, content_lst)
...@@ -342,56 +378,61 @@ def insert_img_or_table(type, element, pymu_raw_blocks, content_lst): ...@@ -342,56 +378,61 @@ def insert_img_or_table(type, element, pymu_raw_blocks, content_lst):
__insert_before_para(line_txt, type, element, content_lst) __insert_before_para(line_txt, type, element, content_lst)
break break
else: else:
logger.error(f"Can't find the location of image {element.get('image_path')} in the markdown file #1") logger.error(
f"Can't find the location of image {element.get('image_path')} in the markdown file #1"
)
else: # 应当在两个block之间 else: # 应当在两个block之间
# 找到上方最近的block,如果上方没有就找大下方最近的block # 找到上方最近的block,如果上方没有就找大下方最近的block
top_txt_block = find_top_nearest_text_bbox(pymu_raw_blocks, element_bbox) top_txt_block = find_top_nearest_text_bbox(pymu_raw_blocks, element_bbox)
if top_txt_block: if top_txt_block:
line_txt = "".join([s['text'] for s in top_txt_block['lines'][-1]['spans']]) line_txt = ''.join([s['text'] for s in top_txt_block['lines'][-1]['spans']])
__insert_after_para(line_txt, type, element, content_lst) __insert_after_para(line_txt, type, element, content_lst)
else: else:
bottom_txt_block = find_bottom_nearest_text_bbox(pymu_raw_blocks, element_bbox) bottom_txt_block = find_bottom_nearest_text_bbox(
pymu_raw_blocks, element_bbox
)
if bottom_txt_block: if bottom_txt_block:
line_txt = "".join([s['text'] for s in bottom_txt_block['lines'][0]['spans']]) line_txt = ''.join(
[s['text'] for s in bottom_txt_block['lines'][0]['spans']]
)
__insert_before_para(line_txt, type, element, content_lst) __insert_before_para(line_txt, type, element, content_lst)
else: # TODO ,图片可能独占一列,这种情况上下是没有图片的 else: # TODO ,图片可能独占一列,这种情况上下是没有图片的
logger.error(f"Can't find the location of image {element.get('image_path')} in the markdown file #2") logger.error(
f"Can't find the location of image {element.get('image_path')} in the markdown file #2"
)
def mk_mm_markdown(content_list): def mk_mm_markdown(content_list):
""" """基于同一格式的内容列表,构造markdown,含图片."""
基于同一格式的内容列表,构造markdown,含图片
"""
content_md = [] content_md = []
for c in content_list: for c in content_list:
content_type = c.get("type") content_type = c.get('type')
if content_type == "text": if content_type == 'text':
content_md.append(c.get("text")) content_md.append(c.get('text'))
elif content_type == "equation": elif content_type == 'equation':
content = c.get("latex") content = c.get('latex')
if content.startswith("$$") and content.endswith("$$"): if content.startswith('$$') and content.endswith('$$'):
content_md.append(content) content_md.append(content)
else: else:
content_md.append(f"\n$$\n{c.get('latex')}\n$$\n") content_md.append(f"\n$$\n{c.get('latex')}\n$$\n")
elif content_type in UNI_FORMAT_TEXT_TYPE: elif content_type in UNI_FORMAT_TEXT_TYPE:
content_md.append(f"{'#'*int(content_type[1])} {c.get('text')}") content_md.append(f"{'#'*int(content_type[1])} {c.get('text')}")
elif content_type == "image": elif content_type == 'image':
content_md.append(f"![]({c.get('img_path')})") content_md.append(f"![]({c.get('img_path')})")
return "\n\n".join(content_md) return '\n\n'.join(content_md)
def mk_nlp_markdown(content_list): def mk_nlp_markdown(content_list):
""" """基于同一格式的内容列表,构造markdown,不含图片."""
基于同一格式的内容列表,构造markdown,不含图片
"""
content_md = [] content_md = []
for c in content_list: for c in content_list:
content_type = c.get("type") content_type = c.get('type')
if content_type == "text": if content_type == 'text':
content_md.append(c.get("text")) content_md.append(c.get('text'))
elif content_type == "equation": elif content_type == 'equation':
content_md.append(f"$$\n{c.get('latex')}\n$$") content_md.append(f"$$\n{c.get('latex')}\n$$")
elif content_type == "table": elif content_type == 'table':
content_md.append(f"$$$\n{c.get('table_latex')}\n$$$") content_md.append(f"$$$\n{c.get('table_latex')}\n$$$")
elif content_type in UNI_FORMAT_TEXT_TYPE: elif content_type in UNI_FORMAT_TEXT_TYPE:
content_md.append(f"{'#'*int(content_type[1])} {c.get('text')}") content_md.append(f"{'#'*int(content_type[1])} {c.get('text')}")
return "\n\n".join(content_md) return '\n\n'.join(content_md)
\ No newline at end of file
...@@ -2,17 +2,16 @@ import re ...@@ -2,17 +2,16 @@ import re
from loguru import logger from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.libs.commons import join_path from magic_pdf.libs.commons import join_path
from magic_pdf.libs.language import detect_lang from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
from magic_pdf.libs.ocr_content_type import BlockType, ContentType
from magic_pdf.para.para_split_v3 import ListLineTag from magic_pdf.para.para_split_v3 import ListLineTag
def __is_hyphen_at_line_end(line): def __is_hyphen_at_line_end(line):
""" """Check if a line ends with one or more letters followed by a hyphen.
Check if a line ends with one or more letters followed by a hyphen.
Args: Args:
line (str): The line of text to check. line (str): The line of text to check.
...@@ -162,7 +161,7 @@ def merge_para_with_text(para_block): ...@@ -162,7 +161,7 @@ def merge_para_with_text(para_block):
if span_type in [ContentType.Text, ContentType.InterlineEquation]: if span_type in [ContentType.Text, ContentType.InterlineEquation]:
para_text += content # 中文/日语/韩文语境下,content间不需要空格分隔 para_text += content # 中文/日语/韩文语境下,content间不需要空格分隔
elif span_type == ContentType.InlineEquation: elif span_type == ContentType.InlineEquation:
para_text += f" {content} " para_text += f' {content} '
else: else:
if span_type in [ContentType.Text, ContentType.InlineEquation]: if span_type in [ContentType.Text, ContentType.InlineEquation]:
# 如果是前一行带有-连字符,那么末尾不应该加空格 # 如果是前一行带有-连字符,那么末尾不应该加空格
...@@ -171,7 +170,7 @@ def merge_para_with_text(para_block): ...@@ -171,7 +170,7 @@ def merge_para_with_text(para_block):
elif len(content) == 1 and content not in ['A', 'I', 'a', 'i'] and not content.isdigit(): elif len(content) == 1 and content not in ['A', 'I', 'a', 'i'] and not content.isdigit():
para_text += content para_text += content
else: # 西方文本语境下 content间需要空格分隔 else: # 西方文本语境下 content间需要空格分隔
para_text += f"{content} " para_text += f'{content} '
elif span_type == ContentType.InterlineEquation: elif span_type == ContentType.InterlineEquation:
para_text += content para_text += content
else: else:
......
""" """输入: s3路径,每行一个 输出: pdf文件元信息,包括每一页上的所有图片的长宽高,bbox位置."""
输入: s3路径,每行一个
输出: pdf文件元信息,包括每一页上的所有图片的长宽高,bbox位置
"""
import sys import sys
import click from collections import Counter
from magic_pdf.libs.commons import read_file, mymax, get_top_percent_list import click
from magic_pdf.libs.commons import fitz
from loguru import logger from loguru import logger
from collections import Counter
from magic_pdf.libs.drop_reason import DropReason from magic_pdf.config.drop_reason import DropReason
from magic_pdf.libs.commons import fitz, get_top_percent_list, mymax, read_file
from magic_pdf.libs.language import detect_lang from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.pdf_check import detect_invalid_chars from magic_pdf.libs.pdf_check import detect_invalid_chars
...@@ -19,8 +16,10 @@ junk_limit_min = 10 ...@@ -19,8 +16,10 @@ junk_limit_min = 10
def calculate_max_image_area_per_page(result: list, page_width_pts, page_height_pts): def calculate_max_image_area_per_page(result: list, page_width_pts, page_height_pts):
max_image_area_per_page = [mymax([(x1 - x0) * (y1 - y0) for x0, y0, x1, y1, _ in page_img_sz]) for page_img_sz in max_image_area_per_page = [
result] mymax([(x1 - x0) * (y1 - y0) for x0, y0, x1, y1, _ in page_img_sz])
for page_img_sz in result
]
page_area = int(page_width_pts) * int(page_height_pts) page_area = int(page_width_pts) * int(page_height_pts)
max_image_area_per_page = [area / page_area for area in max_image_area_per_page] max_image_area_per_page = [area / page_area for area in max_image_area_per_page]
max_image_area_per_page = [area for area in max_image_area_per_page if area > 0.6] max_image_area_per_page = [area for area in max_image_area_per_page if area > 0.6]
...@@ -33,7 +32,9 @@ def process_image(page, junk_img_bojids=[]): ...@@ -33,7 +32,9 @@ def process_image(page, junk_img_bojids=[]):
dedup = set() dedup = set()
for img in items: for img in items:
# 这里返回的是图片在page上的实际展示的大小。返回一个数组,每个元素第一部分是 # 这里返回的是图片在page上的实际展示的大小。返回一个数组,每个元素第一部分是
img_bojid = img[0] # 在pdf文件中是全局唯一的,如果这个图反复出现在pdf里那么就可能是垃圾信息,例如水印、页眉页脚等 img_bojid = img[
0
] # 在pdf文件中是全局唯一的,如果这个图反复出现在pdf里那么就可能是垃圾信息,例如水印、页眉页脚等
if img_bojid in junk_img_bojids: # 如果是垃圾图像,就跳过 if img_bojid in junk_img_bojids: # 如果是垃圾图像,就跳过
continue continue
recs = page.get_image_rects(img, transform=True) recs = page.get_image_rects(img, transform=True)
...@@ -42,9 +43,17 @@ def process_image(page, junk_img_bojids=[]): ...@@ -42,9 +43,17 @@ def process_image(page, junk_img_bojids=[]):
x0, y0, x1, y1 = map(int, rec) x0, y0, x1, y1 = map(int, rec)
width = x1 - x0 width = x1 - x0
height = y1 - y0 height = y1 - y0
if (x0, y0, x1, y1, img_bojid) in dedup: # 这里面会出现一些重复的bbox,无需重复出现,需要去掉 if (
x0,
y0,
x1,
y1,
img_bojid,
) in dedup: # 这里面会出现一些重复的bbox,无需重复出现,需要去掉
continue continue
if not all([width, height]): # 长和宽任何一个都不能是0,否则这个图片不可见,没有实际意义 if not all(
[width, height]
): # 长和宽任何一个都不能是0,否则这个图片不可见,没有实际意义
continue continue
dedup.add((x0, y0, x1, y1, img_bojid)) dedup.add((x0, y0, x1, y1, img_bojid))
page_result.append([x0, y0, x1, y1, img_bojid]) page_result.append([x0, y0, x1, y1, img_bojid])
...@@ -52,8 +61,8 @@ def process_image(page, junk_img_bojids=[]): ...@@ -52,8 +61,8 @@ def process_image(page, junk_img_bojids=[]):
def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list: def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
""" """返回每个页面里的图片的四元组,每个页面多个图片。
返回每个页面里的图片的四元组,每个页面多个图片。
:param doc: :param doc:
:return: :return:
""" """
...@@ -63,13 +72,17 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list: ...@@ -63,13 +72,17 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
junk_limit = max(len(doc) * 0.5, junk_limit_min) # 对一些页数比较少的进行豁免 junk_limit = max(len(doc) * 0.5, junk_limit_min) # 对一些页数比较少的进行豁免
junk_img_bojids = [img_bojid for img_bojid, count in img_bojid_counter.items() if count >= junk_limit] junk_img_bojids = [
img_bojid
#todo 加个判断,用前十页就行,这些垃圾图片需要满足两个条件,不止出现的次数要足够多,而且图片占书页面积的比例要足够大,且图与图大小都差不多 for img_bojid, count in img_bojid_counter.items()
#有两种扫描版,一种文字版,这里可能会有误判 if count >= junk_limit
#扫描版1:每页都有所有扫描页图片,特点是图占比大,每页展示1张 ]
#扫描版2,每页存储的扫描页图片数量递增,特点是图占比大,每页展示1张,需要清空junklist跑前50页图片信息用于分类判断
#文字版1.每页存储所有图片,特点是图片占页面比例不大,每页展示可能为0也可能不止1张 这种pdf需要拿前10页抽样检测img大小和个数,如果符合需要清空junklist # todo 加个判断,用前十页就行,这些垃圾图片需要满足两个条件,不止出现的次数要足够多,而且图片占书页面积的比例要足够大,且图与图大小都差不多
# 有两种扫描版,一种文字版,这里可能会有误判
# 扫描版1:每页都有所有扫描页图片,特点是图占比大,每页展示1张
# 扫描版2,每页存储的扫描页图片数量递增,特点是图占比大,每页展示1张,需要清空junklist跑前50页图片信息用于分类判断
# 文 字版1.每页存储所有图片,特点是图片占页面比例不大,每页展示可能为0也可能不止1张 这种pdf需要拿前10页抽样检测img大小和个数,如果符合需要清空junklist
imgs_len_list = [len(page.get_images()) for page in doc] imgs_len_list = [len(page.get_images()) for page in doc]
special_limit_pages = 10 special_limit_pages = 10
...@@ -82,12 +95,18 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list: ...@@ -82,12 +95,18 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
break break
if i >= special_limit_pages: if i >= special_limit_pages:
break break
page_result = process_image(page) # 这里不传junk_img_bojids,拿前十页所有图片信息用于后续分析 page_result = process_image(
page
) # 这里不传junk_img_bojids,拿前十页所有图片信息用于后续分析
result.append(page_result) result.append(page_result)
for item in result: for item in result:
if not any(item): # 如果任何一页没有图片,说明是个文字版,需要判断是否为特殊文字版 if not any(
if max(imgs_len_list) == min(imgs_len_list) and max( item
imgs_len_list) >= junk_limit_min: # 如果是特殊文字版,就把junklist置空并break ): # 如果任何一页没有图片,说明是个文字版,需要判断是否为特殊文字版
if (
max(imgs_len_list) == min(imgs_len_list)
and max(imgs_len_list) >= junk_limit_min
): # 如果是特殊文字版,就把junklist置空并break
junk_img_bojids = [] junk_img_bojids = []
else: # 不是特殊文字版,是个普通文字版,但是存在垃圾图片,不置空junklist else: # 不是特殊文字版,是个普通文字版,但是存在垃圾图片,不置空junklist
pass pass
...@@ -98,20 +117,23 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list: ...@@ -98,20 +117,23 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
top_eighty_percent = get_top_percent_list(imgs_len_list, 0.8) top_eighty_percent = get_top_percent_list(imgs_len_list, 0.8)
# 检查前80%的元素是否都相等 # 检查前80%的元素是否都相等
if len(set(top_eighty_percent)) == 1 and max(imgs_len_list) >= junk_limit_min: if len(set(top_eighty_percent)) == 1 and max(imgs_len_list) >= junk_limit_min:
# # 如果前10页跑完都有图,根据每页图片数量是否相等判断是否需要清除junklist # # 如果前10页跑完都有图,根据每页图片数量是否相等判断是否需要清除junklist
# if max(imgs_len_list) == min(imgs_len_list) and max(imgs_len_list) >= junk_limit_min: # if max(imgs_len_list) == min(imgs_len_list) and max(imgs_len_list) >= junk_limit_min:
#前10页都有图,且每页数量一致,需要检测图片大小占页面的比例判断是否需要清除junklist # 前10页都有图,且每页数量一致,需要检测图片大小占页面的比例判断是否需要清除junklist
max_image_area_per_page = calculate_max_image_area_per_page(result, page_width_pts, page_height_pts) max_image_area_per_page = calculate_max_image_area_per_page(
if len(max_image_area_per_page) < 0.8 * special_limit_pages: # 前10页不全是大图,说明可能是个文字版pdf,把垃圾图片list置空 result, page_width_pts, page_height_pts
)
if (
len(max_image_area_per_page) < 0.8 * special_limit_pages
): # 前10页不全是大图,说明可能是个文字版pdf,把垃圾图片list置空
junk_img_bojids = [] junk_img_bojids = []
else: # 前10页都有图,而且80%都是大图,且每页图片数量一致并都很多,说明是扫描版1,不需要清空junklist else: # 前10页都有图,而且80%都是大图,且每页图片数量一致并都很多,说明是扫描版1,不需要清空junklist
pass pass
else: # 每页图片数量不一致,需要清掉junklist全量跑前50页图片 else: # 每页图片数量不一致,需要清掉junklist全量跑前50页图片
junk_img_bojids = [] junk_img_bojids = []
#正式进入取前50页图片的信息流程 # 正式进入取前50页图片的信息流程
result = [] result = []
for i, page in enumerate(doc): for i, page in enumerate(doc):
if i >= scan_max_page: if i >= scan_max_page:
...@@ -126,7 +148,7 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list: ...@@ -126,7 +148,7 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
def get_pdf_page_size_pts(doc: fitz.Document): def get_pdf_page_size_pts(doc: fitz.Document):
page_cnt = len(doc) page_cnt = len(doc)
l: int = min(page_cnt, 50) l: int = min(page_cnt, 50)
#把所有宽度和高度塞到两个list 分别取中位数(中间遇到了个在纵页里塞横页的pdf,导致宽高互换了) # 把所有宽度和高度塞到两个list 分别取中位数(中间遇到了个在纵页里塞横页的pdf,导致宽高互换了)
page_width_list = [] page_width_list = []
page_height_list = [] page_height_list = []
for i in range(l): for i in range(l):
...@@ -152,8 +174,8 @@ def get_pdf_textlen_per_page(doc: fitz.Document): ...@@ -152,8 +174,8 @@ def get_pdf_textlen_per_page(doc: fitz.Document):
# 拿所有text的blocks # 拿所有text的blocks
# text_block = page.get_text("words") # text_block = page.get_text("words")
# text_block_len = sum([len(t[4]) for t in text_block]) # text_block_len = sum([len(t[4]) for t in text_block])
#拿所有text的str # 拿所有text的str
text_block = page.get_text("text") text_block = page.get_text('text')
text_block_len = len(text_block) text_block_len = len(text_block)
# logger.info(f"page {page.number} text_block_len: {text_block_len}") # logger.info(f"page {page.number} text_block_len: {text_block_len}")
text_len_lst.append(text_block_len) text_len_lst.append(text_block_len)
...@@ -162,15 +184,13 @@ def get_pdf_textlen_per_page(doc: fitz.Document): ...@@ -162,15 +184,13 @@ def get_pdf_textlen_per_page(doc: fitz.Document):
def get_pdf_text_layout_per_page(doc: fitz.Document): def get_pdf_text_layout_per_page(doc: fitz.Document):
""" """根据PDF文档的每一页文本布局,判断该页的文本布局是横向、纵向还是未知。
根据PDF文档的每一页文本布局,判断该页的文本布局是横向、纵向还是未知。
Args: Args:
doc (fitz.Document): PDF文档对象。 doc (fitz.Document): PDF文档对象。
Returns: Returns:
List[str]: 每一页的文本布局(横向、纵向、未知)。 List[str]: 每一页的文本布局(横向、纵向、未知)。
""" """
text_layout_list = [] text_layout_list = []
...@@ -180,11 +200,11 @@ def get_pdf_text_layout_per_page(doc: fitz.Document): ...@@ -180,11 +200,11 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
# 创建每一页的纵向和横向的文本行数计数器 # 创建每一页的纵向和横向的文本行数计数器
vertical_count = 0 vertical_count = 0
horizontal_count = 0 horizontal_count = 0
text_dict = page.get_text("dict") text_dict = page.get_text('dict')
if "blocks" in text_dict: if 'blocks' in text_dict:
for block in text_dict["blocks"]: for block in text_dict['blocks']:
if 'lines' in block: if 'lines' in block:
for line in block["lines"]: for line in block['lines']:
# 获取line的bbox顶点坐标 # 获取line的bbox顶点坐标
x0, y0, x1, y1 = line['bbox'] x0, y0, x1, y1 = line['bbox']
# 计算bbox的宽高 # 计算bbox的宽高
...@@ -199,8 +219,12 @@ def get_pdf_text_layout_per_page(doc: fitz.Document): ...@@ -199,8 +219,12 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
if len(font_sizes) > 0: if len(font_sizes) > 0:
average_font_size = sum(font_sizes) / len(font_sizes) average_font_size = sum(font_sizes) / len(font_sizes)
else: else:
average_font_size = 10 # 有的line拿不到font_size,先定一个阈值100 average_font_size = (
if area <= average_font_size ** 2: # 判断bbox的面积是否小于平均字体大小的平方,单字无法计算是横向还是纵向 10 # 有的line拿不到font_size,先定一个阈值100
)
if (
area <= average_font_size**2
): # 判断bbox的面积是否小于平均字体大小的平方,单字无法计算是横向还是纵向
continue continue
else: else:
if 'wmode' in line: # 通过wmode判断文本方向 if 'wmode' in line: # 通过wmode判断文本方向
...@@ -228,22 +252,22 @@ def get_pdf_text_layout_per_page(doc: fitz.Document): ...@@ -228,22 +252,22 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
# print(f"page_id: {page_id}, vertical_count: {vertical_count}, horizontal_count: {horizontal_count}") # print(f"page_id: {page_id}, vertical_count: {vertical_count}, horizontal_count: {horizontal_count}")
# 判断每一页的文本布局 # 判断每一页的文本布局
if vertical_count == 0 and horizontal_count == 0: # 该页没有文本,无法判断 if vertical_count == 0 and horizontal_count == 0: # 该页没有文本,无法判断
text_layout_list.append("unknow") text_layout_list.append('unknow')
continue continue
else: else:
if vertical_count > horizontal_count: # 该页的文本纵向行数大于横向的 if vertical_count > horizontal_count: # 该页的文本纵向行数大于横向的
text_layout_list.append("vertical") text_layout_list.append('vertical')
else: # 该页的文本横向行数大于纵向的 else: # 该页的文本横向行数大于纵向的
text_layout_list.append("horizontal") text_layout_list.append('horizontal')
# logger.info(f"page_id: {page_id}, vertical_count: {vertical_count}, horizontal_count: {horizontal_count}") # logger.info(f"page_id: {page_id}, vertical_count: {vertical_count}, horizontal_count: {horizontal_count}")
return text_layout_list return text_layout_list
'''定义一个自定义异常用来抛出单页svg太多的pdf''' """定义一个自定义异常用来抛出单页svg太多的pdf"""
class PageSvgsTooManyError(Exception): class PageSvgsTooManyError(Exception):
def __init__(self, message="Page SVGs are too many"): def __init__(self, message='Page SVGs are too many'):
self.message = message self.message = message
super().__init__(self.message) super().__init__(self.message)
...@@ -285,7 +309,7 @@ def get_language(doc: fitz.Document): ...@@ -285,7 +309,7 @@ def get_language(doc: fitz.Document):
if page_id >= scan_max_page: if page_id >= scan_max_page:
break break
# 拿所有text的str # 拿所有text的str
text_block = page.get_text("text") text_block = page.get_text('text')
page_language = detect_lang(text_block) page_language = detect_lang(text_block)
language_lst.append(page_language) language_lst.append(page_language)
...@@ -299,9 +323,7 @@ def get_language(doc: fitz.Document): ...@@ -299,9 +323,7 @@ def get_language(doc: fitz.Document):
def check_invalid_chars(pdf_bytes): def check_invalid_chars(pdf_bytes):
""" """乱码检测."""
乱码检测
"""
return detect_invalid_chars(pdf_bytes) return detect_invalid_chars(pdf_bytes)
...@@ -311,13 +333,13 @@ def pdf_meta_scan(pdf_bytes: bytes): ...@@ -311,13 +333,13 @@ def pdf_meta_scan(pdf_bytes: bytes):
:param pdf_bytes: pdf文件的二进制数据 :param pdf_bytes: pdf文件的二进制数据
几个维度来评价:是否加密,是否需要密码,纸张大小,总页数,是否文字可提取 几个维度来评价:是否加密,是否需要密码,纸张大小,总页数,是否文字可提取
""" """
doc = fitz.open("pdf", pdf_bytes) doc = fitz.open('pdf', pdf_bytes)
is_needs_password = doc.needs_pass is_needs_password = doc.needs_pass
is_encrypted = doc.is_encrypted is_encrypted = doc.is_encrypted
total_page = len(doc) total_page = len(doc)
if total_page == 0: if total_page == 0:
logger.warning(f"drop this pdf, drop_reason: {DropReason.EMPTY_PDF}") logger.warning(f'drop this pdf, drop_reason: {DropReason.EMPTY_PDF}')
result = {"_need_drop": True, "_drop_reason": DropReason.EMPTY_PDF} result = {'_need_drop': True, '_drop_reason': DropReason.EMPTY_PDF}
return result return result
else: else:
page_width_pts, page_height_pts = get_pdf_page_size_pts(doc) page_width_pts, page_height_pts = get_pdf_page_size_pts(doc)
...@@ -328,7 +350,9 @@ def pdf_meta_scan(pdf_bytes: bytes): ...@@ -328,7 +350,9 @@ def pdf_meta_scan(pdf_bytes: bytes):
imgs_per_page = get_imgs_per_page(doc) imgs_per_page = get_imgs_per_page(doc)
# logger.info(f"imgs_per_page: {imgs_per_page}") # logger.info(f"imgs_per_page: {imgs_per_page}")
image_info_per_page, junk_img_bojids = get_image_info(doc, page_width_pts, page_height_pts) image_info_per_page, junk_img_bojids = get_image_info(
doc, page_width_pts, page_height_pts
)
# logger.info(f"image_info_per_page: {image_info_per_page}, junk_img_bojids: {junk_img_bojids}") # logger.info(f"image_info_per_page: {image_info_per_page}, junk_img_bojids: {junk_img_bojids}")
text_len_per_page = get_pdf_textlen_per_page(doc) text_len_per_page = get_pdf_textlen_per_page(doc)
# logger.info(f"text_len_per_page: {text_len_per_page}") # logger.info(f"text_len_per_page: {text_len_per_page}")
...@@ -341,20 +365,20 @@ def pdf_meta_scan(pdf_bytes: bytes): ...@@ -341,20 +365,20 @@ def pdf_meta_scan(pdf_bytes: bytes):
# 最后输出一条json # 最后输出一条json
res = { res = {
"is_needs_password": is_needs_password, 'is_needs_password': is_needs_password,
"is_encrypted": is_encrypted, 'is_encrypted': is_encrypted,
"total_page": total_page, 'total_page': total_page,
"page_width_pts": int(page_width_pts), 'page_width_pts': int(page_width_pts),
"page_height_pts": int(page_height_pts), 'page_height_pts': int(page_height_pts),
"image_info_per_page": image_info_per_page, 'image_info_per_page': image_info_per_page,
"text_len_per_page": text_len_per_page, 'text_len_per_page': text_len_per_page,
"text_layout_per_page": text_layout_per_page, 'text_layout_per_page': text_layout_per_page,
"text_language": text_language, 'text_language': text_language,
# "svgs_per_page": svgs_per_page, # "svgs_per_page": svgs_per_page,
"imgs_per_page": imgs_per_page, # 增加每页img数量list 'imgs_per_page': imgs_per_page, # 增加每页img数量list
"junk_img_bojids": junk_img_bojids, # 增加垃圾图片的bojid list 'junk_img_bojids': junk_img_bojids, # 增加垃圾图片的bojid list
"invalid_chars": invalid_chars, 'invalid_chars': invalid_chars,
"metadata": doc.metadata 'metadata': doc.metadata,
} }
# logger.info(json.dumps(res, ensure_ascii=False)) # logger.info(json.dumps(res, ensure_ascii=False))
return res return res
...@@ -364,14 +388,12 @@ def pdf_meta_scan(pdf_bytes: bytes): ...@@ -364,14 +388,12 @@ def pdf_meta_scan(pdf_bytes: bytes):
@click.option('--s3-pdf-path', help='s3上pdf文件的路径') @click.option('--s3-pdf-path', help='s3上pdf文件的路径')
@click.option('--s3-profile', help='s3上的profile') @click.option('--s3-profile', help='s3上的profile')
def main(s3_pdf_path: str, s3_profile: str): def main(s3_pdf_path: str, s3_profile: str):
""" """"""
"""
try: try:
file_content = read_file(s3_pdf_path, s3_profile) file_content = read_file(s3_pdf_path, s3_profile)
pdf_meta_scan(file_content) pdf_meta_scan(file_content)
except Exception as e: except Exception as e:
print(f"ERROR: {s3_pdf_path}, {e}", file=sys.stderr) print(f'ERROR: {s3_pdf_path}, {e}', file=sys.stderr)
logger.exception(e) logger.exception(e)
...@@ -381,7 +403,7 @@ if __name__ == '__main__': ...@@ -381,7 +403,7 @@ if __name__ == '__main__':
# "D:\project/20231108code-clean\pdf_cost_time\竖排例子\三国演义_繁体竖排版.pdf" # "D:\project/20231108code-clean\pdf_cost_time\竖排例子\三国演义_繁体竖排版.pdf"
# "D:\project/20231108code-clean\pdf_cost_time\scihub\scihub_86800000\libgen.scimag86880000-86880999.zip_10.1021/acsami.1c03109.s002.pdf" # "D:\project/20231108code-clean\pdf_cost_time\scihub\scihub_86800000\libgen.scimag86880000-86880999.zip_10.1021/acsami.1c03109.s002.pdf"
# "D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_18600000/libgen.scimag18645000-18645999.zip_10.1021/om3006239.pdf" # "D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_18600000/libgen.scimag18645000-18645999.zip_10.1021/om3006239.pdf"
# file_content = read_file("D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_31000000/libgen.scimag31098000-31098999.zip_10.1109/isit.2006.261791.pdf","") # file_content = read_file("D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_31000000/libgen.scimag31098000-31098999.zip_10.1109/isit.2006.261791.pdf","") # noqa: E501
# file_content = read_file("D:\project/20231108code-clean\pdf_cost_time\竖排例子\净空法师_大乘无量寿.pdf","") # file_content = read_file("D:\project/20231108code-clean\pdf_cost_time\竖排例子\净空法师_大乘无量寿.pdf","")
# doc = fitz.open("pdf", file_content) # doc = fitz.open("pdf", file_content)
# text_layout_lst = get_pdf_text_layout_per_page(doc) # text_layout_lst = get_pdf_text_layout_per_page(doc)
......
...@@ -5,13 +5,13 @@ from pathlib import Path ...@@ -5,13 +5,13 @@ from pathlib import Path
from loguru import logger from loguru import logger
import magic_pdf.model as model_config import magic_pdf.model as model_config
from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.data.data_reader_writer import FileBasedDataReader from magic_pdf.data.data_reader_writer import FileBasedDataReader
from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
from magic_pdf.integrations.rag.type import (CategoryType, ContentObject, from magic_pdf.integrations.rag.type import (CategoryType, ContentObject,
ElementRelation, ElementRelType, ElementRelation, ElementRelType,
LayoutElements, LayoutElements,
LayoutElementsExtra, PageInfo) LayoutElementsExtra, PageInfo)
from magic_pdf.libs.ocr_content_type import BlockType, ContentType
from magic_pdf.tools.common import do_parse, prepare_env from magic_pdf.tools.common import do_parse, prepare_env
......
class MakeMode:
MM_MD = "mm_markdown"
NLP_MD = "nlp_markdown"
STANDARD_FORMAT = "standard_format"
class DropMode:
WHOLE_PDF = "whole_pdf"
SINGLE_PAGE = "single_page"
NONE = "none"
NONE_WITH_REASON = "none_with_reason"
...@@ -5,7 +5,7 @@ import os ...@@ -5,7 +5,7 @@ import os
from loguru import logger from loguru import logger
from magic_pdf.libs.Constants import MODEL_NAME from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.libs.commons import parse_bucket_key from magic_pdf.libs.commons import parse_bucket_key
# 定义配置文件名常量 # 定义配置文件名常量
...@@ -99,7 +99,7 @@ def get_table_recog_config(): ...@@ -99,7 +99,7 @@ def get_table_recog_config():
def get_layout_config(): def get_layout_config():
config = read_config() config = read_config()
layout_config = config.get("layout-config") layout_config = config.get('layout-config')
if layout_config is None: if layout_config is None:
logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default") logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
return json.loads(f'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}') return json.loads(f'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}')
...@@ -109,7 +109,7 @@ def get_layout_config(): ...@@ -109,7 +109,7 @@ def get_layout_config():
def get_formula_config(): def get_formula_config():
config = read_config() config = read_config()
formula_config = config.get("formula-config") formula_config = config.get('formula-config')
if formula_config is None: if formula_config is None:
logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default") logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
return json.loads(f'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}') return json.loads(f'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}')
...@@ -117,5 +117,5 @@ def get_formula_config(): ...@@ -117,5 +117,5 @@ def get_formula_config():
return formula_config return formula_config
if __name__ == "__main__": if __name__ == '__main__':
ak, sk, endpoint = get_s3_config("llm-raw") ak, sk, endpoint = get_s3_config('llm-raw')
from magic_pdf.config.constants import CROSS_PAGE
from magic_pdf.config.ocr_content_type import (BlockType, CategoryId,
ContentType)
from magic_pdf.data.dataset import PymuDocDataset from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.libs.commons import fitz # PyMuPDF from magic_pdf.libs.commons import fitz # PyMuPDF
from magic_pdf.libs.Constants import CROSS_PAGE
from magic_pdf.libs.ocr_content_type import BlockType, CategoryId, ContentType
from magic_pdf.model.magic_model import MagicModel from magic_pdf.model.magic_model import MagicModel
......
class DropReason:
TEXT_BLCOK_HOR_OVERLAP = "text_block_horizontal_overlap" # 文字块有水平互相覆盖,导致无法准确定位文字顺序
USEFUL_BLOCK_HOR_OVERLAP = "useful_block_horizontal_overlap" # 需保留的block水平覆盖
COMPLICATED_LAYOUT = "complicated_layout" # 复杂的布局,暂时不支持
TOO_MANY_LAYOUT_COLUMNS = "too_many_layout_columns" # 目前不支持分栏超过2列的
COLOR_BACKGROUND_TEXT_BOX = "color_background_text_box" # 含有带色块的PDF,色块会改变阅读顺序,目前不支持带底色文字块的PDF。
HIGH_COMPUTATIONAL_lOAD_BY_IMGS = "high_computational_load_by_imgs" # 含特殊图片,计算量太大,从而丢弃
HIGH_COMPUTATIONAL_lOAD_BY_SVGS = "high_computational_load_by_svgs" # 特殊的SVG图,计算量太大,从而丢弃
HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES = "high_computational_load_by_total_pages" # 计算量超过负荷,当前方法下计算量消耗过大
MISS_DOC_LAYOUT_RESULT = "missing doc_layout_result" # 版面分析失败
Exception = "_exception" # 解析中发生异常
ENCRYPTED = "encrypted" # PDF是加密的
EMPTY_PDF = "total_page=0" # PDF页面总数为0
NOT_IS_TEXT_PDF = "not_is_text_pdf" # 不是文字版PDF,无法直接解析
DENSE_SINGLE_LINE_BLOCK = "dense_single_line_block" # 无法清晰的分段
TITLE_DETECTION_FAILED = "title_detection_failed" # 探测标题失败
TITLE_LEVEL_FAILED = "title_level_failed" # 分析标题级别失败(例如一级、二级、三级标题)
PARA_SPLIT_FAILED = "para_split_failed" # 识别段落失败
PARA_MERGE_FAILED = "para_merge_failed" # 段落合并失败
NOT_ALLOW_LANGUAGE = "not_allow_language" # 不支持的语种
SPECIAL_PDF = "special_pdf"
PSEUDO_SINGLE_COLUMN = "pseudo_single_column" # 无法精确判断文字分栏
CAN_NOT_DETECT_PAGE_LAYOUT="can_not_detect_page_layout" # 无法分析页面的版面
NEGATIVE_BBOX_AREA = "negative_bbox_area" # 缩放导致 bbox 面积为负
OVERLAP_BLOCKS_CAN_NOT_SEPARATION = "overlap_blocks_can_t_separation" # 无法分离重叠的block
\ No newline at end of file
COLOR_BG_HEADER_TXT_BLOCK = "color_background_header_txt_block"
PAGE_NO = "page-no" # 页码
CONTENT_IN_FOOT_OR_HEADER = 'in-foot-header-area' # 页眉页脚内的文本
VERTICAL_TEXT = 'vertical-text' # 垂直文本
ROTATE_TEXT = 'rotate-text' # 旋转文本
EMPTY_SIDE_BLOCK = 'empty-side-block' # 边缘上的空白没有任何内容的block
ON_IMAGE_TEXT = 'on-image-text' # 文本在图片上
ON_TABLE_TEXT = 'on-table-text' # 文本在表格上
class DropTag:
PAGE_NUMBER = "page_no"
HEADER = "header"
FOOTER = "footer"
FOOTNOTE = "footnote"
NOT_IN_LAYOUT = "not_in_layout"
SPAN_OVERLAP = "span_overlap"
BLOCK_OVERLAP = "block_overlap"
import enum import enum
import json import json
from magic_pdf.config.model_block_type import ModelBlockTypeEnum
from magic_pdf.config.ocr_content_type import CategoryId, ContentType
from magic_pdf.data.data_reader_writer import (FileBasedDataReader, from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
FileBasedDataWriter) FileBasedDataWriter)
from magic_pdf.data.dataset import Dataset from magic_pdf.data.dataset import Dataset
...@@ -11,8 +13,6 @@ from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance, ...@@ -11,8 +13,6 @@ from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
from magic_pdf.libs.commons import fitz, join_path from magic_pdf.libs.commons import fitz, join_path
from magic_pdf.libs.coordinate_transform import get_scale_ratio from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.local_math import float_gt from magic_pdf.libs.local_math import float_gt
from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
CAPATION_OVERLAP_AREA_RATIO = 0.6 CAPATION_OVERLAP_AREA_RATIO = 0.6
......
import numpy as np # flake8: noqa
import torch
from loguru import logger
import os import os
import time import time
import cv2 import cv2
import numpy as np
import torch
import yaml import yaml
from loguru import logger
from PIL import Image from PIL import Image
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
...@@ -13,20 +15,21 @@ os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger ...@@ -13,20 +15,21 @@ os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
try: try:
import torchtext import torchtext
if torchtext.__version__ >= "0.18.0": if torchtext.__version__ >= '0.18.0':
torchtext.disable_torchtext_deprecation_warning() torchtext.disable_torchtext_deprecation_warning()
except ImportError: except ImportError:
pass pass
from magic_pdf.libs.Constants import * from magic_pdf.config.constants import *
from magic_pdf.model.model_list import AtomicModel from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram from magic_pdf.model.sub_modules.model_utils import (
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list)
class CustomPEKModel: class CustomPEKModel:
def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs): def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
""" """
======== model init ======== ======== model init ========
...@@ -41,42 +44,54 @@ class CustomPEKModel: ...@@ -41,42 +44,54 @@ class CustomPEKModel:
model_config_dir = os.path.join(root_dir, 'resources', 'model_config') model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
# 构建 model_configs.yaml 文件的完整路径 # 构建 model_configs.yaml 文件的完整路径
config_path = os.path.join(model_config_dir, 'model_configs.yaml') config_path = os.path.join(model_config_dir, 'model_configs.yaml')
with open(config_path, "r", encoding='utf-8') as f: with open(config_path, 'r', encoding='utf-8') as f:
self.configs = yaml.load(f, Loader=yaml.FullLoader) self.configs = yaml.load(f, Loader=yaml.FullLoader)
# 初始化解析配置 # 初始化解析配置
# layout config # layout config
self.layout_config = kwargs.get("layout_config") self.layout_config = kwargs.get('layout_config')
self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO) self.layout_model_name = self.layout_config.get(
'model', MODEL_NAME.DocLayout_YOLO
)
# formula config # formula config
self.formula_config = kwargs.get("formula_config") self.formula_config = kwargs.get('formula_config')
self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD) self.mfd_model_name = self.formula_config.get(
self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small) 'mfd_model', MODEL_NAME.YOLO_V8_MFD
self.apply_formula = self.formula_config.get("enable", True) )
self.mfr_model_name = self.formula_config.get(
'mfr_model', MODEL_NAME.UniMerNet_v2_Small
)
self.apply_formula = self.formula_config.get('enable', True)
# table config # table config
self.table_config = kwargs.get("table_config") self.table_config = kwargs.get('table_config')
self.apply_table = self.table_config.get("enable", False) self.apply_table = self.table_config.get('enable', False)
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE) self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
self.table_model_name = self.table_config.get("model", MODEL_NAME.RAPID_TABLE) self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
# ocr config # ocr config
self.apply_ocr = ocr self.apply_ocr = ocr
self.lang = kwargs.get("lang", None) self.lang = kwargs.get('lang', None)
logger.info( logger.info(
"DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, " 'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
"apply_table: {}, table_model: {}, lang: {}".format( 'apply_table: {}, table_model: {}, lang: {}'.format(
self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.layout_model_name,
self.lang self.apply_formula,
self.apply_ocr,
self.apply_table,
self.table_model_name,
self.lang,
) )
) )
# 初始化解析方案 # 初始化解析方案
self.device = kwargs.get("device", "cpu") self.device = kwargs.get('device', 'cpu')
logger.info("using device: {}".format(self.device)) logger.info('using device: {}'.format(self.device))
models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models")) models_dir = kwargs.get(
logger.info("using models_dir: {}".format(models_dir)) 'models_dir', os.path.join(root_dir, 'resources', 'models')
)
logger.info('using models_dir: {}'.format(models_dir))
atom_model_manager = AtomModelSingleton() atom_model_manager = AtomModelSingleton()
...@@ -85,18 +100,24 @@ class CustomPEKModel: ...@@ -85,18 +100,24 @@ class CustomPEKModel:
# 初始化公式检测模型 # 初始化公式检测模型
self.mfd_model = atom_model_manager.get_atom_model( self.mfd_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFD, atom_model_name=AtomicModel.MFD,
mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])), mfd_weights=str(
device=self.device os.path.join(
models_dir, self.configs['weights'][self.mfd_model_name]
)
),
device=self.device,
) )
# 初始化公式解析模型 # 初始化公式解析模型
mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name])) mfr_weight_dir = str(
mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml")) os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
)
mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
self.mfr_model = atom_model_manager.get_atom_model( self.mfr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFR, atom_model_name=AtomicModel.MFR,
mfr_weight_dir=mfr_weight_dir, mfr_weight_dir=mfr_weight_dir,
mfr_cfg_path=mfr_cfg_path, mfr_cfg_path=mfr_cfg_path,
device=self.device device=self.device,
) )
# 初始化layout模型 # 初始化layout模型
...@@ -104,16 +125,28 @@ class CustomPEKModel: ...@@ -104,16 +125,28 @@ class CustomPEKModel:
self.layout_model = atom_model_manager.get_atom_model( self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout, atom_model_name=AtomicModel.Layout,
layout_model_name=MODEL_NAME.LAYOUTLMv3, layout_model_name=MODEL_NAME.LAYOUTLMv3,
layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])), layout_weights=str(
layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")), os.path.join(
device=self.device models_dir, self.configs['weights'][self.layout_model_name]
)
),
layout_config_file=str(
os.path.join(
model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
)
),
device=self.device,
) )
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO: elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
self.layout_model = atom_model_manager.get_atom_model( self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout, atom_model_name=AtomicModel.Layout,
layout_model_name=MODEL_NAME.DocLayout_YOLO, layout_model_name=MODEL_NAME.DocLayout_YOLO,
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])), doclayout_yolo_weights=str(
device=self.device os.path.join(
models_dir, self.configs['weights'][self.layout_model_name]
)
),
device=self.device,
) )
# 初始化ocr # 初始化ocr
if self.apply_ocr: if self.apply_ocr:
...@@ -121,23 +154,22 @@ class CustomPEKModel: ...@@ -121,23 +154,22 @@ class CustomPEKModel:
atom_model_name=AtomicModel.OCR, atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log, ocr_show_log=show_log,
det_db_box_thresh=0.3, det_db_box_thresh=0.3,
lang=self.lang lang=self.lang,
) )
# init table model # init table model
if self.apply_table: if self.apply_table:
table_model_dir = self.configs["weights"][self.table_model_name] table_model_dir = self.configs['weights'][self.table_model_name]
self.table_model = atom_model_manager.get_atom_model( self.table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Table, atom_model_name=AtomicModel.Table,
table_model_name=self.table_model_name, table_model_name=self.table_model_name,
table_model_path=str(os.path.join(models_dir, table_model_dir)), table_model_path=str(os.path.join(models_dir, table_model_dir)),
table_max_time=self.table_max_time, table_max_time=self.table_max_time,
device=self.device device=self.device,
) )
logger.info('DocAnalysis init done!') logger.info('DocAnalysis init done!')
def __call__(self, image): def __call__(self, image):
page_start = time.time() page_start = time.time()
# layout检测 # layout检测
...@@ -150,7 +182,7 @@ class CustomPEKModel: ...@@ -150,7 +182,7 @@ class CustomPEKModel:
# doclayout_yolo # doclayout_yolo
layout_res = self.layout_model.predict(image) layout_res = self.layout_model.predict(image)
layout_cost = round(time.time() - layout_start, 2) layout_cost = round(time.time() - layout_start, 2)
logger.info(f"layout detection time: {layout_cost}") logger.info(f'layout detection time: {layout_cost}')
pil_img = Image.fromarray(image) pil_img = Image.fromarray(image)
...@@ -158,32 +190,40 @@ class CustomPEKModel: ...@@ -158,32 +190,40 @@ class CustomPEKModel:
# 公式检测 # 公式检测
mfd_start = time.time() mfd_start = time.time()
mfd_res = self.mfd_model.predict(image) mfd_res = self.mfd_model.predict(image)
logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}") logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')
# 公式识别 # 公式识别
mfr_start = time.time() mfr_start = time.time()
formula_list = self.mfr_model.predict(mfd_res, image) formula_list = self.mfr_model.predict(mfd_res, image)
layout_res.extend(formula_list) layout_res.extend(formula_list)
mfr_cost = round(time.time() - mfr_start, 2) mfr_cost = round(time.time() - mfr_start, 2)
logger.info(f"formula nums: {len(formula_list)}, mfr time: {mfr_cost}") logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
# 清理显存 # 清理显存
clean_vram(self.device, vram_threshold=8) clean_vram(self.device, vram_threshold=8)
# 从layout_res中获取ocr区域、表格区域、公式区域 # 从layout_res中获取ocr区域、表格区域、公式区域
ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res) ocr_res_list, table_res_list, single_page_mfdetrec_res = (
get_res_list_from_layout_res(layout_res)
)
# ocr识别 # ocr识别
if self.apply_ocr: if self.apply_ocr:
ocr_start = time.time() ocr_start = time.time()
# Process each area that requires OCR processing # Process each area that requires OCR processing
for res in ocr_res_list: for res in ocr_res_list:
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50) new_image, useful_list = crop_img(
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list) res, pil_img, crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
single_page_mfdetrec_res, useful_list
)
# OCR recognition # OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR) new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0] ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[
0
]
# Integration results # Integration results
if ocr_res: if ocr_res:
...@@ -191,7 +231,7 @@ class CustomPEKModel: ...@@ -191,7 +231,7 @@ class CustomPEKModel:
layout_res.extend(ocr_result_list) layout_res.extend(ocr_result_list)
ocr_cost = round(time.time() - ocr_start, 2) ocr_cost = round(time.time() - ocr_start, 2)
logger.info(f"ocr time: {ocr_cost}") logger.info(f'ocr time: {ocr_cost}')
# 表格识别 table recognition # 表格识别 table recognition
if self.apply_table: if self.apply_table:
...@@ -202,27 +242,37 @@ class CustomPEKModel: ...@@ -202,27 +242,37 @@ class CustomPEKModel:
html_code = None html_code = None
if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE: if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
with torch.no_grad(): with torch.no_grad():
table_result = self.table_model.predict(new_image, "html") table_result = self.table_model.predict(new_image, 'html')
if len(table_result) > 0: if len(table_result) > 0:
html_code = table_result[0] html_code = table_result[0]
elif self.table_model_name == MODEL_NAME.TABLE_MASTER: elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.table_model.img2html(new_image) html_code = self.table_model.img2html(new_image)
elif self.table_model_name == MODEL_NAME.RAPID_TABLE: elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
html_code, table_cell_bboxes, elapse = self.table_model.predict(new_image) html_code, table_cell_bboxes, elapse = self.table_model.predict(
new_image
)
run_time = time.time() - single_table_start_time run_time = time.time() - single_table_start_time
if run_time > self.table_max_time: if run_time > self.table_max_time:
logger.warning(f"table recognition processing exceeds max time {self.table_max_time}s") logger.warning(
f'table recognition processing exceeds max time {self.table_max_time}s'
)
# 判断是否返回正常 # 判断是否返回正常
if html_code: if html_code:
expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>') expected_ending = html_code.strip().endswith(
'</html>'
) or html_code.strip().endswith('</table>')
if expected_ending: if expected_ending:
res["html"] = html_code res['html'] = html_code
else: else:
logger.warning(f"table recognition processing fails, not found expected HTML table end") logger.warning(
'table recognition processing fails, not found expected HTML table end'
)
else: else:
logger.warning(f"table recognition processing fails, not get html return") logger.warning(
logger.info(f"table time: {round(time.time() - table_start, 2)}") 'table recognition processing fails, not get html return'
)
logger.info(f'table time: {round(time.time() - table_start, 2)}')
logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----") logger.info(f'-----page total time: {round(time.time() - page_start, 2)}-----')
return layout_res return layout_res
from loguru import logger from loguru import logger
from magic_pdf.libs.Constants import MODEL_NAME from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
Layoutlmv3_Predictor
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
RapidTableModel
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel StructTableModel
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
TableMasterPaddleModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'): def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
...@@ -19,14 +24,14 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'): ...@@ -19,14 +24,14 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time) table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
elif table_model_type == MODEL_NAME.TABLE_MASTER: elif table_model_type == MODEL_NAME.TABLE_MASTER:
config = { config = {
"model_dir": model_path, 'model_dir': model_path,
"device": _device_ 'device': _device_
} }
table_model = TableMasterPaddleModel(config) table_model = TableMasterPaddleModel(config)
elif table_model_type == MODEL_NAME.RAPID_TABLE: elif table_model_type == MODEL_NAME.RAPID_TABLE:
table_model = RapidTableModel() table_model = RapidTableModel()
else: else:
logger.error("table model type not allow") logger.error('table model type not allow')
exit(1) exit(1)
return table_model return table_model
...@@ -87,8 +92,8 @@ class AtomModelSingleton: ...@@ -87,8 +92,8 @@ class AtomModelSingleton:
return cls._instance return cls._instance
def get_atom_model(self, atom_model_name: str, **kwargs): def get_atom_model(self, atom_model_name: str, **kwargs):
lang = kwargs.get("lang", None) lang = kwargs.get('lang', None)
layout_model_name = kwargs.get("layout_model_name", None) layout_model_name = kwargs.get('layout_model_name', None)
key = (atom_model_name, layout_model_name, lang) key = (atom_model_name, layout_model_name, lang)
if key not in self._models: if key not in self._models:
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs) self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
...@@ -98,47 +103,47 @@ class AtomModelSingleton: ...@@ -98,47 +103,47 @@ class AtomModelSingleton:
def atom_model_init(model_name: str, **kwargs): def atom_model_init(model_name: str, **kwargs):
atom_model = None atom_model = None
if model_name == AtomicModel.Layout: if model_name == AtomicModel.Layout:
if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3: if kwargs.get('layout_model_name') == MODEL_NAME.LAYOUTLMv3:
atom_model = layout_model_init( atom_model = layout_model_init(
kwargs.get("layout_weights"), kwargs.get('layout_weights'),
kwargs.get("layout_config_file"), kwargs.get('layout_config_file'),
kwargs.get("device") kwargs.get('device')
) )
elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO: elif kwargs.get('layout_model_name') == MODEL_NAME.DocLayout_YOLO:
atom_model = doclayout_yolo_model_init( atom_model = doclayout_yolo_model_init(
kwargs.get("doclayout_yolo_weights"), kwargs.get('doclayout_yolo_weights'),
kwargs.get("device") kwargs.get('device')
) )
elif model_name == AtomicModel.MFD: elif model_name == AtomicModel.MFD:
atom_model = mfd_model_init( atom_model = mfd_model_init(
kwargs.get("mfd_weights"), kwargs.get('mfd_weights'),
kwargs.get("device") kwargs.get('device')
) )
elif model_name == AtomicModel.MFR: elif model_name == AtomicModel.MFR:
atom_model = mfr_model_init( atom_model = mfr_model_init(
kwargs.get("mfr_weight_dir"), kwargs.get('mfr_weight_dir'),
kwargs.get("mfr_cfg_path"), kwargs.get('mfr_cfg_path'),
kwargs.get("device") kwargs.get('device')
) )
elif model_name == AtomicModel.OCR: elif model_name == AtomicModel.OCR:
atom_model = ocr_model_init( atom_model = ocr_model_init(
kwargs.get("ocr_show_log"), kwargs.get('ocr_show_log'),
kwargs.get("det_db_box_thresh"), kwargs.get('det_db_box_thresh'),
kwargs.get("lang") kwargs.get('lang')
) )
elif model_name == AtomicModel.Table: elif model_name == AtomicModel.Table:
atom_model = table_model_init( atom_model = table_model_init(
kwargs.get("table_model_name"), kwargs.get('table_model_name'),
kwargs.get("table_model_path"), kwargs.get('table_model_path'),
kwargs.get("table_max_time"), kwargs.get('table_max_time'),
kwargs.get("device") kwargs.get('device')
) )
else: else:
logger.error("model name not allow") logger.error('model name not allow')
exit(1) exit(1)
if atom_model is None: if atom_model is None:
logger.error("model init failed") logger.error('model init failed')
exit(1) exit(1)
else: else:
return atom_model return atom_model
import os
import cv2 import cv2
import numpy as np
from paddleocr.ppstructure.table.predict_table import TableSystem from paddleocr.ppstructure.table.predict_table import TableSystem
from paddleocr.ppstructure.utility import init_args from paddleocr.ppstructure.utility import init_args
from magic_pdf.libs.Constants import *
import os
from PIL import Image from PIL import Image
import numpy as np
from magic_pdf.config.constants import * # noqa: F403
class TableMasterPaddleModel(object): class TableMasterPaddleModel(object):
""" """This class is responsible for converting image of table into HTML format
This class is responsible for converting image of table into HTML format using a pre-trained model. using a pre-trained model.
Attributes: Attributes:
- table_sys: An instance of TableSystem initialized with parsed arguments. - table_sys: An instance of TableSystem initialized with parsed arguments.
...@@ -40,30 +42,30 @@ class TableMasterPaddleModel(object): ...@@ -40,30 +42,30 @@ class TableMasterPaddleModel(object):
image = np.asarray(image) image = np.asarray(image)
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
pred_res, _ = self.table_sys(image) pred_res, _ = self.table_sys(image)
pred_html = pred_res["html"] pred_html = pred_res['html']
# res = '<td><table border="1">' + pred_html.replace("<html><body><table>", "").replace( # res = '<td><table border="1">' + pred_html.replace("<html><body><table>", "").replace(
# "</table></body></html>","") + "</table></td>\n" # "</table></body></html>","") + "</table></td>\n"
return pred_html return pred_html
def parse_args(self, **kwargs): def parse_args(self, **kwargs):
parser = init_args() parser = init_args()
model_dir = kwargs.get("model_dir") model_dir = kwargs.get('model_dir')
table_model_dir = os.path.join(model_dir, TABLE_MASTER_DIR) table_model_dir = os.path.join(model_dir, TABLE_MASTER_DIR) # noqa: F405
table_char_dict_path = os.path.join(model_dir, TABLE_MASTER_DICT) table_char_dict_path = os.path.join(model_dir, TABLE_MASTER_DICT) # noqa: F405
det_model_dir = os.path.join(model_dir, DETECT_MODEL_DIR) det_model_dir = os.path.join(model_dir, DETECT_MODEL_DIR) # noqa: F405
rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR) rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR) # noqa: F405
rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT) rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT) # noqa: F405
device = kwargs.get("device", "cpu") device = kwargs.get('device', 'cpu')
use_gpu = True if device.startswith("cuda") else False use_gpu = True if device.startswith('cuda') else False
config = { config = {
"use_gpu": use_gpu, 'use_gpu': use_gpu,
"table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN), 'table_max_len': kwargs.get('table_max_len', TABLE_MAX_LEN), # noqa: F405
"table_algorithm": "TableMaster", 'table_algorithm': 'TableMaster',
"table_model_dir": table_model_dir, 'table_model_dir': table_model_dir,
"table_char_dict_path": table_char_dict_path, 'table_char_dict_path': table_char_dict_path,
"det_model_dir": det_model_dir, 'det_model_dir': det_model_dir,
"rec_model_dir": rec_model_dir, 'rec_model_dir': rec_model_dir,
"rec_char_dict_path": rec_char_dict_path, 'rec_char_dict_path': rec_char_dict_path,
} }
parser.set_defaults(**config) parser.set_defaults(**config)
return parser.parse_args([]) return parser.parse_args([])
from sklearn.cluster import DBSCAN
import numpy as np import numpy as np
from loguru import logger from loguru import logger
from sklearn.cluster import DBSCAN
from magic_pdf.libs.boxbase import _is_in_or_part_overlap_with_area_ratio as is_in_layout from magic_pdf.config.ocr_content_type import ContentType
from magic_pdf.libs.ocr_content_type import ContentType from magic_pdf.libs.boxbase import \
_is_in_or_part_overlap_with_area_ratio as is_in_layout
LINE_STOP_FLAG = ['.', '!', '?', '。', '!', '?',":", ":", ")", ")", ";"] LINE_STOP_FLAG = ['.', '!', '?', '。', '!', '?', ':', ':', ')', ')', ';']
INLINE_EQUATION = ContentType.InlineEquation INLINE_EQUATION = ContentType.InlineEquation
INTERLINE_EQUATION = ContentType.InterlineEquation INTERLINE_EQUATION = ContentType.InterlineEquation
TEXT = ContentType.Text TEXT = ContentType.Text
...@@ -14,30 +14,36 @@ TEXT = ContentType.Text ...@@ -14,30 +14,36 @@ TEXT = ContentType.Text
def __get_span_text(span): def __get_span_text(span):
c = span.get('content', '') c = span.get('content', '')
if len(c)==0: if len(c) == 0:
c = span.get('image_path', '') c = span.get('image_path', '')
return c return c
def __detect_list_lines(lines, new_layout_bboxes, lang): def __detect_list_lines(lines, new_layout_bboxes, lang):
""" """探测是否包含了列表,并且把列表的行分开.
探测是否包含了列表,并且把列表的行分开.
这样的段落特点是,顶格字母大写/数字,紧跟着几行缩进的。缩进的行首字母含小写的。 这样的段落特点是,顶格字母大写/数字,紧跟着几行缩进的。缩进的行首字母含小写的。
""" """
def find_repeating_patterns(lst): def find_repeating_patterns(lst):
indices = [] indices = []
ones_indices = [] ones_indices = []
i = 0 i = 0
while i < len(lst) - 1: # 确保余下元素至少有2个 while i < len(lst) - 1: # 确保余下元素至少有2个
if lst[i] == 1 and lst[i+1] in [2, 3]: # 额外检查以防止连续出现的1 if lst[i] == 1 and lst[i + 1] in [2, 3]: # 额外检查以防止连续出现的1
start = i start = i
ones_in_this_interval = [i] ones_in_this_interval = [i]
i += 1 i += 1
while i < len(lst) and lst[i] in [2, 3]: while i < len(lst) and lst[i] in [2, 3]:
i += 1 i += 1
# 验证下一个序列是否符合条件 # 验证下一个序列是否符合条件
if i < len(lst) - 1 and lst[i] == 1 and lst[i+1] in [2, 3] and lst[i-1] in [2, 3]: if (
i < len(lst) - 1
and lst[i] == 1
and lst[i + 1] in [2, 3]
and lst[i - 1] in [2, 3]
):
while i < len(lst) and lst[i] in [1, 2, 3]: while i < len(lst) and lst[i] in [1, 2, 3]:
if lst[i] == 1: if lst[i] == 1:
ones_in_this_interval.append(i) ones_in_this_interval.append(i)
...@@ -49,7 +55,9 @@ def __detect_list_lines(lines, new_layout_bboxes, lang): ...@@ -49,7 +55,9 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
else: else:
i += 1 i += 1
return indices, ones_indices return indices, ones_indices
"""====================""" """===================="""
def split_indices(slen, index_array): def split_indices(slen, index_array):
result = [] result = []
last_end = 0 last_end = 0
...@@ -67,9 +75,10 @@ def __detect_list_lines(lines, new_layout_bboxes, lang): ...@@ -67,9 +75,10 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
result.append(('text', last_end, slen - 1)) result.append(('text', last_end, slen - 1))
return result return result
"""====================""" """===================="""
if lang!='en': if lang != 'en':
return lines, None return lines, None
else: else:
total_lines = len(lines) total_lines = len(lines)
...@@ -81,7 +90,7 @@ def __detect_list_lines(lines, new_layout_bboxes, lang): ...@@ -81,7 +90,7 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
3. 如果非顶格,首字符大写,编码为2 3. 如果非顶格,首字符大写,编码为2
4. 如果非顶格,首字符非大写编码为3 4. 如果非顶格,首字符非大写编码为3
""" """
for l in lines: for l in lines: # noqa: E741
first_char = __get_span_text(l['spans'][0])[0] first_char = __get_span_text(l['spans'][0])[0]
layout_left = __find_layout_bbox_by_line(l['bbox'], new_layout_bboxes)[0] layout_left = __find_layout_bbox_by_line(l['bbox'], new_layout_bboxes)[0]
if l['bbox'][0] == layout_left: if l['bbox'][0] == layout_left:
...@@ -98,42 +107,53 @@ def __detect_list_lines(lines, new_layout_bboxes, lang): ...@@ -98,42 +107,53 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
# 然后根据编码进行分段, 选出来 1,2,3连续出现至少2次的行,认为是列表。 # 然后根据编码进行分段, 选出来 1,2,3连续出现至少2次的行,认为是列表。
list_indice, list_start_idx = find_repeating_patterns(line_fea_encode) list_indice, list_start_idx = find_repeating_patterns(line_fea_encode)
if len(list_indice)>0: if len(list_indice) > 0:
logger.info(f"发现了列表,列表行数:{list_indice}{list_start_idx}") logger.info(f'发现了列表,列表行数:{list_indice}{list_start_idx}')
# TODO check一下这个特列表里缩进的行左侧是不是对齐的。 # TODO check一下这个特列表里缩进的行左侧是不是对齐的。
segments = []
for start, end in list_indice: for start, end in list_indice:
for i in range(start, end+1): for i in range(start, end + 1):
if i>0: if i > 0:
if line_fea_encode[i] == 4: if line_fea_encode[i] == 4:
logger.info(f"列表行的第{i}行不是顶格的") logger.info(f'列表行的第{i}行不是顶格的')
break break
else: else:
logger.info(f"列表行的第{start}到第{end}行是列表") logger.info(f'列表行的第{start}到第{end}行是列表')
return split_indices(total_lines, list_indice), list_start_idx return split_indices(total_lines, list_indice), list_start_idx
def __valign_lines(blocks, layout_bboxes): def __valign_lines(blocks, layout_bboxes):
""" """在一个layoutbox内对齐行的左侧和右侧。 扫描行的左侧和右侧,如果x0,
在一个layoutbox内对齐行的左侧和右侧。 x1差距不超过一个阈值,就强行对齐到所处layout的左右两侧(和layout有一段距离)。
扫描行的左侧和右侧,如果x0, x1差距不超过一个阈值,就强行对齐到所处layout的左右两侧(和layout有一段距离)。 3是个经验值,TODO,计算得来,可以设置为1.5个正文字符。"""
3是个经验值,TODO,计算得来,可以设置为1.5个正文字符。
"""
min_distance = 3 min_distance = 3
min_sample = 2 min_sample = 2
new_layout_bboxes = [] new_layout_bboxes = []
for layout_box in layout_bboxes: for layout_box in layout_bboxes:
blocks_in_layoutbox = [b for b in blocks if is_in_layout(b['bbox'], layout_box['layout_bbox'])] blocks_in_layoutbox = [
if len(blocks_in_layoutbox)==0: b for b in blocks if is_in_layout(b['bbox'], layout_box['layout_bbox'])
]
if len(blocks_in_layoutbox) == 0:
continue continue
x0_lst = np.array([[line['bbox'][0], 0] for block in blocks_in_layoutbox for line in block['lines']]) x0_lst = np.array(
x1_lst = np.array([[line['bbox'][2], 0] for block in blocks_in_layoutbox for line in block['lines']]) [
[line['bbox'][0], 0]
for block in blocks_in_layoutbox
for line in block['lines']
]
)
x1_lst = np.array(
[
[line['bbox'][2], 0]
for block in blocks_in_layoutbox
for line in block['lines']
]
)
x0_clusters = DBSCAN(eps=min_distance, min_samples=min_sample).fit(x0_lst) x0_clusters = DBSCAN(eps=min_distance, min_samples=min_sample).fit(x0_lst)
x1_clusters = DBSCAN(eps=min_distance, min_samples=min_sample).fit(x1_lst) x1_clusters = DBSCAN(eps=min_distance, min_samples=min_sample).fit(x1_lst)
x0_uniq_label = np.unique(x0_clusters.labels_) x0_uniq_label = np.unique(x0_clusters.labels_)
...@@ -142,18 +162,18 @@ def __valign_lines(blocks, layout_bboxes): ...@@ -142,18 +162,18 @@ def __valign_lines(blocks, layout_bboxes):
x0_2_new_val = {} # 存储旧值对应的新值映射 x0_2_new_val = {} # 存储旧值对应的新值映射
x1_2_new_val = {} x1_2_new_val = {}
for label in x0_uniq_label: for label in x0_uniq_label:
if label==-1: if label == -1:
continue continue
x0_index_of_label = np.where(x0_clusters.labels_==label) x0_index_of_label = np.where(x0_clusters.labels_ == label)
x0_raw_val = x0_lst[x0_index_of_label][:,0] x0_raw_val = x0_lst[x0_index_of_label][:, 0]
x0_new_val = np.min(x0_lst[x0_index_of_label][:,0]) x0_new_val = np.min(x0_lst[x0_index_of_label][:, 0])
x0_2_new_val.update({idx: x0_new_val for idx in x0_raw_val}) x0_2_new_val.update({idx: x0_new_val for idx in x0_raw_val})
for label in x1_uniq_label: for label in x1_uniq_label:
if label==-1: if label == -1:
continue continue
x1_index_of_label = np.where(x1_clusters.labels_==label) x1_index_of_label = np.where(x1_clusters.labels_ == label)
x1_raw_val = x1_lst[x1_index_of_label][:,0] x1_raw_val = x1_lst[x1_index_of_label][:, 0]
x1_new_val = np.max(x1_lst[x1_index_of_label][:,0]) x1_new_val = np.max(x1_lst[x1_index_of_label][:, 0])
x1_2_new_val.update({idx: x1_new_val for idx in x1_raw_val}) x1_2_new_val.update({idx: x1_new_val for idx in x1_raw_val})
for block in blocks_in_layoutbox: for block in blocks_in_layoutbox:
...@@ -168,10 +188,12 @@ def __valign_lines(blocks, layout_bboxes): ...@@ -168,10 +188,12 @@ def __valign_lines(blocks, layout_bboxes):
# 由于修改了block里的line长度,现在需要重新计算block的bbox # 由于修改了block里的line长度,现在需要重新计算block的bbox
for block in blocks_in_layoutbox: for block in blocks_in_layoutbox:
block['bbox'] = [min([line['bbox'][0] for line in block['lines']]), block['bbox'] = [
min([line['bbox'][0] for line in block['lines']]),
min([line['bbox'][1] for line in block['lines']]), min([line['bbox'][1] for line in block['lines']]),
max([line['bbox'][2] for line in block['lines']]), max([line['bbox'][2] for line in block['lines']]),
max([line['bbox'][3] for line in block['lines']])] max([line['bbox'][3] for line in block['lines']]),
]
"""新计算layout的bbox,因为block的bbox变了。""" """新计算layout的bbox,因为block的bbox变了。"""
layout_x0 = min([block['bbox'][0] for block in blocks_in_layoutbox]) layout_x0 = min([block['bbox'][0] for block in blocks_in_layoutbox])
...@@ -184,13 +206,11 @@ def __valign_lines(blocks, layout_bboxes): ...@@ -184,13 +206,11 @@ def __valign_lines(blocks, layout_bboxes):
def __align_text_in_layout(blocks, layout_bboxes): def __align_text_in_layout(blocks, layout_bboxes):
""" """由于ocr出来的line,有时候会在前后有一段空白,这个时候需要对文本进行对齐,超出的部分被layout左右侧截断。"""
由于ocr出来的line,有时候会在前后有一段空白,这个时候需要对文本进行对齐,超出的部分被layout左右侧截断。
"""
for layout in layout_bboxes: for layout in layout_bboxes:
lb = layout['layout_bbox'] lb = layout['layout_bbox']
blocks_in_layoutbox = [b for b in blocks if is_in_layout(b['bbox'], lb)] blocks_in_layoutbox = [b for b in blocks if is_in_layout(b['bbox'], lb)]
if len(blocks_in_layoutbox)==0: if len(blocks_in_layoutbox) == 0:
continue continue
for block in blocks_in_layoutbox: for block in blocks_in_layoutbox:
...@@ -203,44 +223,42 @@ def __align_text_in_layout(blocks, layout_bboxes): ...@@ -203,44 +223,42 @@ def __align_text_in_layout(blocks, layout_bboxes):
def __common_pre_proc(blocks, layout_bboxes): def __common_pre_proc(blocks, layout_bboxes):
""" """不分语言的,对文本进行预处理."""
不分语言的,对文本进行预处理 # __add_line_period(blocks, layout_bboxes)
"""
#__add_line_period(blocks, layout_bboxes)
__align_text_in_layout(blocks, layout_bboxes) __align_text_in_layout(blocks, layout_bboxes)
aligned_layout_bboxes = __valign_lines(blocks, layout_bboxes) aligned_layout_bboxes = __valign_lines(blocks, layout_bboxes)
return aligned_layout_bboxes return aligned_layout_bboxes
def __pre_proc_zh_blocks(blocks, layout_bboxes): def __pre_proc_zh_blocks(blocks, layout_bboxes):
""" """对中文文本进行分段预处理."""
对中文文本进行分段预处理
"""
pass pass
def __pre_proc_en_blocks(blocks, layout_bboxes): def __pre_proc_en_blocks(blocks, layout_bboxes):
""" """对英文文本进行分段预处理."""
对英文文本进行分段预处理
"""
pass pass
def __group_line_by_layout(blocks, layout_bboxes, lang="en"): def __group_line_by_layout(blocks, layout_bboxes, lang='en'):
""" """每个layout内的行进行聚合."""
每个layout内的行进行聚合
"""
# 因为只是一个block一行目前, 一个block就是一个段落 # 因为只是一个block一行目前, 一个block就是一个段落
lines_group = [] lines_group = []
for lyout in layout_bboxes: for lyout in layout_bboxes:
lines = [line for block in blocks if is_in_layout(block['bbox'], lyout['layout_bbox']) for line in block['lines']] lines = [
line
for block in blocks
if is_in_layout(block['bbox'], lyout['layout_bbox'])
for line in block['lines']
]
lines_group.append(lines) lines_group.append(lines)
return lines_group return lines_group
def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang="en", char_avg_len=10): def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang='en', char_avg_len=10):
""" """
lines_group 进行行分段——layout内部进行分段。lines_group内每个元素是一个Layoutbox内的所有行。 lines_group 进行行分段——layout内部进行分段。lines_group内每个元素是一个Layoutbox内的所有行。
1. 先计算每个group的左右边界。 1. 先计算每个group的左右边界。
...@@ -256,9 +274,9 @@ def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang="en", char_avg_ ...@@ -256,9 +274,9 @@ def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang="en", char_avg_
for lines in lines_group: for lines in lines_group:
paras = [] paras = []
total_lines = len(lines) total_lines = len(lines)
if total_lines==0: if total_lines == 0:
continue # 0行无需处理 continue # 0行无需处理
if total_lines==1: # 1行无法分段。 if total_lines == 1: # 1行无法分段。
layout_paras.append([lines]) layout_paras.append([lines])
list_info.append([False, False]) list_info.append([False, False])
continue continue
...@@ -272,7 +290,9 @@ def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang="en", char_avg_ ...@@ -272,7 +290,9 @@ def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang="en", char_avg_
这样的文本块,顶格的为一个段落开头,紧随其后非顶格的行属于这个段落。 这样的文本块,顶格的为一个段落开头,紧随其后非顶格的行属于这个段落。
""" """
text_segments, list_start_line = __detect_list_lines(lines, new_layout_bbox, lang) text_segments, list_start_line = __detect_list_lines(
lines, new_layout_bbox, lang
)
"""根据list_range,把lines分成几个部分 """根据list_range,把lines分成几个部分
""" """
...@@ -280,50 +300,59 @@ def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang="en", char_avg_ ...@@ -280,50 +300,59 @@ def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang="en", char_avg_
layout_right = __find_layout_bbox_by_line(lines[0]['bbox'], new_layout_bbox)[2] layout_right = __find_layout_bbox_by_line(lines[0]['bbox'], new_layout_bbox)[2]
layout_left = __find_layout_bbox_by_line(lines[0]['bbox'], new_layout_bbox)[0] layout_left = __find_layout_bbox_by_line(lines[0]['bbox'], new_layout_bbox)[0]
para = [] # 元素是line para = [] # 元素是line
layout_list_info = [False, False] # 这个layout最后是不是列表,记录每一个layout里是不是列表开头,列表结尾 layout_list_info = [
False,
False,
] # 这个layout最后是不是列表,记录每一个layout里是不是列表开头,列表结尾
for content_type, start, end in text_segments: for content_type, start, end in text_segments:
if content_type == 'list': if content_type == 'list':
for i, line in enumerate(lines[start:end+1]): for i, line in enumerate(lines[start : end + 1]):
line_x0 = line['bbox'][0] line_x0 = line['bbox'][0]
if line_x0 == layout_left: # 列表开头 if line_x0 == layout_left: # 列表开头
if len(para)>0: if len(para) > 0:
paras.append(para) paras.append(para)
para = [] para = []
para.append(line) para.append(line)
else: else:
para.append(line) para.append(line)
if len(para)>0: if len(para) > 0:
paras.append(para) paras.append(para)
para = [] para = []
if start==0: if start == 0:
layout_list_info[0] = True layout_list_info[0] = True
if end==total_lines-1: if end == total_lines - 1:
layout_list_info[1] = True layout_list_info[1] = True
else: # 是普通文本 else: # 是普通文本
for i, line in enumerate(lines[start:end+1]): for i, line in enumerate(lines[start : end + 1]):
# 如果i有下一行,那么就要根据下一行位置综合判断是否要分段。如果i之后没有行,那么只需要判断i行自己的结尾特征。 # 如果i有下一行,那么就要根据下一行位置综合判断是否要分段。如果i之后没有行,那么只需要判断i行自己的结尾特征。
cur_line_type = line['spans'][-1]['type'] cur_line_type = line['spans'][-1]['type']
next_line = lines[i+1] if i<total_lines-1 else None next_line = lines[i + 1] if i < total_lines - 1 else None
if cur_line_type in [TEXT, INLINE_EQUATION]: if cur_line_type in [TEXT, INLINE_EQUATION]:
if line['bbox'][2] < layout_right - right_tail_distance: if line['bbox'][2] < layout_right - right_tail_distance:
para.append(line) para.append(line)
paras.append(para) paras.append(para)
para = [] para = []
elif line['bbox'][2] >= layout_right - right_tail_distance and next_line and next_line['bbox'][0] == layout_left: # 现在这行到了行尾沾满,下一行存在且顶格。 elif (
line['bbox'][2] >= layout_right - right_tail_distance
and next_line
and next_line['bbox'][0] == layout_left
): # 现在这行到了行尾沾满,下一行存在且顶格。
para.append(line) para.append(line)
else: else:
para.append(line) para.append(line)
paras.append(para) paras.append(para)
para = [] para = []
else: # 其他,图片、表格、行间公式,各自占一段 else: # 其他,图片、表格、行间公式,各自占一段
if len(para)>0: # 先把之前的段落加入到结果中 if len(para) > 0: # 先把之前的段落加入到结果中
paras.append(para) paras.append(para)
para = [] para = []
paras.append([line]) # 再把当前行加入到结果中。当前行为行间公式、图、表等。 paras.append(
[line]
) # 再把当前行加入到结果中。当前行为行间公式、图、表等。
para = [] para = []
if len(para)>0: if len(para) > 0:
paras.append(para) paras.append(para)
para = [] para = []
...@@ -331,79 +360,112 @@ def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang="en", char_avg_ ...@@ -331,79 +360,112 @@ def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang="en", char_avg_
layout_paras.append(paras) layout_paras.append(paras)
paras = [] paras = []
return layout_paras, list_info return layout_paras, list_info
def __connect_list_inter_layout(layout_paras, new_layout_bbox, layout_list_info, page_num, lang):
""" def __connect_list_inter_layout(
如果上个layout的最后一个段落是列表,下一个layout的第一个段落也是列表,那么将他们连接起来。 TODO 因为没有区分列表和段落,所以这个方法暂时不实现。 layout_paras, new_layout_bbox, layout_list_info, page_num, lang
根据layout_list_info判断是不是列表。,下个layout的第一个段如果不是列表,那么看他们是否有几行都有相同的缩进。 ):
""" """如果上个layout的最后一个段落是列表,下一个layout的第一个段落也是列表,那么将他们连接起来。 TODO
if len(layout_paras)==0 or len(layout_list_info)==0: # 0的时候最后的return 会出错 因为没有区分列表和段落,所以这个方法暂时不实现。
根据layout_list_info判断是不是列表。,下个layout的第一个段如果不是列表,那么看他们是否有几行都有相同的缩进。"""
if (
len(layout_paras) == 0 or len(layout_list_info) == 0
): # 0的时候最后的return 会出错
return layout_paras, [False, False] return layout_paras, [False, False]
for i in range(1, len(layout_paras)): for i in range(1, len(layout_paras)):
pre_layout_list_info = layout_list_info[i-1] pre_layout_list_info = layout_list_info[i - 1]
next_layout_list_info = layout_list_info[i] next_layout_list_info = layout_list_info[i]
pre_last_para = layout_paras[i-1][-1] pre_last_para = layout_paras[i - 1][-1]
next_paras = layout_paras[i] next_paras = layout_paras[i]
next_first_para = next_paras[0]
if pre_layout_list_info[1] and not next_layout_list_info[0]: # 前一个是列表结尾,后一个是非列表开头,此时检测是否有相同的缩进 if (
logger.info(f"连接page {page_num} 内的list") pre_layout_list_info[1] and not next_layout_list_info[0]
): # 前一个是列表结尾,后一个是非列表开头,此时检测是否有相同的缩进
logger.info(f'连接page {page_num} 内的list')
# 向layout_paras[i] 寻找开头具有相同缩进的连续的行 # 向layout_paras[i] 寻找开头具有相同缩进的连续的行
may_list_lines = [] may_list_lines = []
for j in range(len(next_paras)): for j in range(len(next_paras)):
line = next_paras[j] line = next_paras[j]
if len(line)==1: # 只可能是一行,多行情况再需要分析了 if len(line) == 1: # 只可能是一行,多行情况再需要分析了
if line[0]['bbox'][0] > __find_layout_bbox_by_line(line[0]['bbox'], new_layout_bbox)[0]: if (
line[0]['bbox'][0]
> __find_layout_bbox_by_line(line[0]['bbox'], new_layout_bbox)[
0
]
):
may_list_lines.append(line[0]) may_list_lines.append(line[0])
else: else:
break break
else: else:
break break
# 如果这些行的缩进是相等的,那么连到上一个layout的最后一个段落上。 # 如果这些行的缩进是相等的,那么连到上一个layout的最后一个段落上。
if len(may_list_lines)>0 and len(set([x['bbox'][0] for x in may_list_lines]))==1: if (
len(may_list_lines) > 0
and len(set([x['bbox'][0] for x in may_list_lines])) == 1
):
pre_last_para.extend(may_list_lines) pre_last_para.extend(may_list_lines)
layout_paras[i] = layout_paras[i][len(may_list_lines):] layout_paras[i] = layout_paras[i][len(may_list_lines) :]
return layout_paras, [layout_list_info[0][0], layout_list_info[-1][1]] # 同时还返回了这个页面级别的开头、结尾是不是列表的信息 return layout_paras, [
layout_list_info[0][0],
layout_list_info[-1][1],
def __connect_list_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox, pre_page_list_info, next_page_list_info, page_num, lang): ] # 同时还返回了这个页面级别的开头、结尾是不是列表的信息
"""
如果上个layout的最后一个段落是列表,下一个layout的第一个段落也是列表,那么将他们连接起来。 TODO 因为没有区分列表和段落,所以这个方法暂时不实现。
根据layout_list_info判断是不是列表。,下个layout的第一个段如果不是列表,那么看他们是否有几行都有相同的缩进。 def __connect_list_inter_page(
""" pre_page_paras,
if len(pre_page_paras)==0 or len(next_page_paras)==0: # 0的时候最后的return 会出错 next_page_paras,
pre_page_layout_bbox,
next_page_layout_bbox,
pre_page_list_info,
next_page_list_info,
page_num,
lang,
):
"""如果上个layout的最后一个段落是列表,下一个layout的第一个段落也是列表,那么将他们连接起来。 TODO
因为没有区分列表和段落,所以这个方法暂时不实现。
根据layout_list_info判断是不是列表。,下个layout的第一个段如果不是列表,那么看他们是否有几行都有相同的缩进。"""
if (
len(pre_page_paras) == 0 or len(next_page_paras) == 0
): # 0的时候最后的return 会出错
return False return False
if pre_page_list_info[1] and not next_page_list_info[0]: # 前一个是列表结尾,后一个是非列表开头,此时检测是否有相同的缩进 if (
logger.info(f"连接page {page_num} 内的list") pre_page_list_info[1] and not next_page_list_info[0]
): # 前一个是列表结尾,后一个是非列表开头,此时检测是否有相同的缩进
logger.info(f'连接page {page_num} 内的list')
# 向layout_paras[i] 寻找开头具有相同缩进的连续的行 # 向layout_paras[i] 寻找开头具有相同缩进的连续的行
may_list_lines = [] may_list_lines = []
for j in range(len(next_page_paras[0])): for j in range(len(next_page_paras[0])):
line = next_page_paras[0][j] line = next_page_paras[0][j]
if len(line)==1: # 只可能是一行,多行情况再需要分析了 if len(line) == 1: # 只可能是一行,多行情况再需要分析了
if line[0]['bbox'][0] > __find_layout_bbox_by_line(line[0]['bbox'], next_page_layout_bbox)[0]: if (
line[0]['bbox'][0]
> __find_layout_bbox_by_line(
line[0]['bbox'], next_page_layout_bbox
)[0]
):
may_list_lines.append(line[0]) may_list_lines.append(line[0])
else: else:
break break
else: else:
break break
# 如果这些行的缩进是相等的,那么连到上一个layout的最后一个段落上。 # 如果这些行的缩进是相等的,那么连到上一个layout的最后一个段落上。
if len(may_list_lines)>0 and len(set([x['bbox'][0] for x in may_list_lines]))==1: if (
len(may_list_lines) > 0
and len(set([x['bbox'][0] for x in may_list_lines])) == 1
):
pre_page_paras[-1].append(may_list_lines) pre_page_paras[-1].append(may_list_lines)
next_page_paras[0] = next_page_paras[0][len(may_list_lines):] next_page_paras[0] = next_page_paras[0][len(may_list_lines) :]
return True return True
return False return False
def __find_layout_bbox_by_line(line_bbox, layout_bboxes): def __find_layout_bbox_by_line(line_bbox, layout_bboxes):
""" """根据line找到所在的layout."""
根据line找到所在的layout
"""
for layout in layout_bboxes: for layout in layout_bboxes:
if is_in_layout(line_bbox, layout): if is_in_layout(line_bbox, layout):
return layout return layout
...@@ -420,37 +482,56 @@ def __connect_para_inter_layoutbox(layout_paras, new_layout_bbox, lang): ...@@ -420,37 +482,56 @@ def __connect_para_inter_layoutbox(layout_paras, new_layout_bbox, lang):
""" """
connected_layout_paras = [] connected_layout_paras = []
if len(layout_paras)==0: if len(layout_paras) == 0:
return connected_layout_paras return connected_layout_paras
connected_layout_paras.append(layout_paras[0]) connected_layout_paras.append(layout_paras[0])
for i in range(1, len(layout_paras)): for i in range(1, len(layout_paras)):
try: try:
if len(layout_paras[i])==0 or len(layout_paras[i-1])==0: # TODO 考虑连接问题, if (
len(layout_paras[i]) == 0 or len(layout_paras[i - 1]) == 0
): # TODO 考虑连接问题,
continue continue
pre_last_line = layout_paras[i-1][-1][-1] pre_last_line = layout_paras[i - 1][-1][-1]
next_first_line = layout_paras[i][0][0] next_first_line = layout_paras[i][0][0]
except Exception as e: except Exception:
logger.error(f"page layout {i} has no line") logger.error(f'page layout {i} has no line')
continue continue
pre_last_line_text = ''.join([__get_span_text(span) for span in pre_last_line['spans']]) pre_last_line_text = ''.join(
[__get_span_text(span) for span in pre_last_line['spans']]
)
pre_last_line_type = pre_last_line['spans'][-1]['type'] pre_last_line_type = pre_last_line['spans'][-1]['type']
next_first_line_text = ''.join([__get_span_text(span) for span in next_first_line['spans']]) next_first_line_text = ''.join(
[__get_span_text(span) for span in next_first_line['spans']]
)
next_first_line_type = next_first_line['spans'][0]['type'] next_first_line_type = next_first_line['spans'][0]['type']
if pre_last_line_type not in [TEXT, INLINE_EQUATION] or next_first_line_type not in [TEXT, INLINE_EQUATION]: if pre_last_line_type not in [
TEXT,
INLINE_EQUATION,
] or next_first_line_type not in [TEXT, INLINE_EQUATION]:
connected_layout_paras.append(layout_paras[i]) connected_layout_paras.append(layout_paras[i])
continue continue
pre_x2_max = __find_layout_bbox_by_line(pre_last_line['bbox'], new_layout_bbox)[2] pre_x2_max = __find_layout_bbox_by_line(pre_last_line['bbox'], new_layout_bbox)[
next_x0_min = __find_layout_bbox_by_line(next_first_line['bbox'], new_layout_bbox)[0] 2
]
next_x0_min = __find_layout_bbox_by_line(
next_first_line['bbox'], new_layout_bbox
)[0]
pre_last_line_text = pre_last_line_text.strip() pre_last_line_text = pre_last_line_text.strip()
next_first_line_text = next_first_line_text.strip() next_first_line_text = next_first_line_text.strip()
if pre_last_line['bbox'][2] == pre_x2_max and pre_last_line_text[-1] not in LINE_STOP_FLAG and next_first_line['bbox'][0]==next_x0_min: # 前面一行沾满了整个行,并且没有结尾符号.下一行没有空白开头。 if (
pre_last_line['bbox'][2] == pre_x2_max
and pre_last_line_text[-1] not in LINE_STOP_FLAG
and next_first_line['bbox'][0] == next_x0_min
): # 前面一行沾满了整个行,并且没有结尾符号.下一行没有空白开头。
"""连接段落条件成立,将前一个layout的段落和后一个layout的段落连接。""" """连接段落条件成立,将前一个layout的段落和后一个layout的段落连接。"""
connected_layout_paras[-1][-1].extend(layout_paras[i][0]) connected_layout_paras[-1][-1].extend(layout_paras[i][0])
layout_paras[i].pop(0) # 删除后一个layout的第一个段落, 因为他已经被合并到前一个layout的最后一个段落了。 layout_paras[i].pop(
if len(layout_paras[i])==0: 0
) # 删除后一个layout的第一个段落, 因为他已经被合并到前一个layout的最后一个段落了。
if len(layout_paras[i]) == 0:
layout_paras.pop(i) layout_paras.pop(i)
else: else:
connected_layout_paras.append(layout_paras[i]) connected_layout_paras.append(layout_paras[i])
...@@ -461,7 +542,14 @@ def __connect_para_inter_layoutbox(layout_paras, new_layout_bbox, lang): ...@@ -461,7 +542,14 @@ def __connect_para_inter_layoutbox(layout_paras, new_layout_bbox, lang):
return connected_layout_paras return connected_layout_paras
def __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox, page_num, lang): def __connect_para_inter_page(
pre_page_paras,
next_page_paras,
pre_page_layout_bbox,
next_page_layout_bbox,
page_num,
lang,
):
""" """
连接起来相邻两个页面的段落——前一个页面最后一个段落和后一个页面的第一个段落。 连接起来相邻两个页面的段落——前一个页面最后一个段落和后一个页面的第一个段落。
是否可以连接的条件: 是否可以连接的条件:
...@@ -469,34 +557,60 @@ def __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_b ...@@ -469,34 +557,60 @@ def __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_b
2. 后一个页面的第一个段落第一行没有空白开头。 2. 后一个页面的第一个段落第一行没有空白开头。
""" """
# 有的页面可能压根没有文字 # 有的页面可能压根没有文字
if len(pre_page_paras)==0 or len(next_page_paras)==0 or len(pre_page_paras[0])==0 or len(next_page_paras[0])==0: # TODO [[]]为什么出现在pre_page_paras里? if (
len(pre_page_paras) == 0
or len(next_page_paras) == 0
or len(pre_page_paras[0]) == 0
or len(next_page_paras[0]) == 0
): # TODO [[]]为什么出现在pre_page_paras里?
return False return False
pre_last_para = pre_page_paras[-1][-1] pre_last_para = pre_page_paras[-1][-1]
next_first_para = next_page_paras[0][0] next_first_para = next_page_paras[0][0]
pre_last_line = pre_last_para[-1] pre_last_line = pre_last_para[-1]
next_first_line = next_first_para[0] next_first_line = next_first_para[0]
pre_last_line_text = ''.join([__get_span_text(span) for span in pre_last_line['spans']]) pre_last_line_text = ''.join(
[__get_span_text(span) for span in pre_last_line['spans']]
)
pre_last_line_type = pre_last_line['spans'][-1]['type'] pre_last_line_type = pre_last_line['spans'][-1]['type']
next_first_line_text = ''.join([__get_span_text(span) for span in next_first_line['spans']]) next_first_line_text = ''.join(
[__get_span_text(span) for span in next_first_line['spans']]
)
next_first_line_type = next_first_line['spans'][0]['type'] next_first_line_type = next_first_line['spans'][0]['type']
if pre_last_line_type not in [TEXT, INLINE_EQUATION] or next_first_line_type not in [TEXT, INLINE_EQUATION]: # TODO,真的要做好,要考虑跨table, image, 行间的情况 if pre_last_line_type not in [
TEXT,
INLINE_EQUATION,
] or next_first_line_type not in [
TEXT,
INLINE_EQUATION,
]: # TODO,真的要做好,要考虑跨table, image, 行间的情况
# 不是文本,不连接 # 不是文本,不连接
return False return False
pre_x2_max = __find_layout_bbox_by_line(pre_last_line['bbox'], pre_page_layout_bbox)[2] pre_x2_max = __find_layout_bbox_by_line(
next_x0_min = __find_layout_bbox_by_line(next_first_line['bbox'], next_page_layout_bbox)[0] pre_last_line['bbox'], pre_page_layout_bbox
)[2]
next_x0_min = __find_layout_bbox_by_line(
next_first_line['bbox'], next_page_layout_bbox
)[0]
pre_last_line_text = pre_last_line_text.strip() pre_last_line_text = pre_last_line_text.strip()
next_first_line_text = next_first_line_text.strip() next_first_line_text = next_first_line_text.strip()
if pre_last_line['bbox'][2] == pre_x2_max and pre_last_line_text[-1] not in LINE_STOP_FLAG and next_first_line['bbox'][0]==next_x0_min: # 前面一行沾满了整个行,并且没有结尾符号.下一行没有空白开头。 if (
pre_last_line['bbox'][2] == pre_x2_max
and pre_last_line_text[-1] not in LINE_STOP_FLAG
and next_first_line['bbox'][0] == next_x0_min
): # 前面一行沾满了整个行,并且没有结尾符号.下一行没有空白开头。
"""连接段落条件成立,将前一个layout的段落和后一个layout的段落连接。""" """连接段落条件成立,将前一个layout的段落和后一个layout的段落连接。"""
pre_last_para.extend(next_first_para) pre_last_para.extend(next_first_para)
next_page_paras[0].pop(0) # 删除后一个页面的第一个段落, 因为他已经被合并到前一个页面的最后一个段落了。 next_page_paras[0].pop(
0
) # 删除后一个页面的第一个段落, 因为他已经被合并到前一个页面的最后一个段落了。
return True return True
else: else:
return False return False
def find_consecutive_true_regions(input_array): def find_consecutive_true_regions(input_array):
start_index = None # 连续True区域的起始索引 start_index = None # 连续True区域的起始索引
regions = [] # 用于保存所有连续True区域的起始和结束索引 regions = [] # 用于保存所有连续True区域的起始和结束索引
...@@ -510,17 +624,19 @@ def find_consecutive_true_regions(input_array): ...@@ -510,17 +624,19 @@ def find_consecutive_true_regions(input_array):
elif not input_array[i] and start_index is not None: elif not input_array[i] and start_index is not None:
# 如果连续True区域长度大于1,那么将其添加到结果列表中 # 如果连续True区域长度大于1,那么将其添加到结果列表中
if i - start_index > 1: if i - start_index > 1:
regions.append((start_index, i-1)) regions.append((start_index, i - 1))
start_index = None # 重置起始索引 start_index = None # 重置起始索引
# 如果最后一个元素是True,那么需要将最后一个连续True区域加入到结果列表中 # 如果最后一个元素是True,那么需要将最后一个连续True区域加入到结果列表中
if start_index is not None and len(input_array) - start_index > 1: if start_index is not None and len(input_array) - start_index > 1:
regions.append((start_index, len(input_array)-1)) regions.append((start_index, len(input_array) - 1))
return regions return regions
def __connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang, debug_mode): def __connect_middle_align_text(
page_paras, new_layout_bbox, page_num, lang, debug_mode
):
""" """
找出来中间对齐的连续单行文本,如果连续行高度相同,那么合并为一个段落。 找出来中间对齐的连续单行文本,如果连续行高度相同,那么合并为一个段落。
一个line居中的条件是: 一个line居中的条件是:
...@@ -532,54 +648,78 @@ def __connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang, deb ...@@ -532,54 +648,78 @@ def __connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang, deb
layout_box = new_layout_bbox[layout_i] layout_box = new_layout_bbox[layout_i]
single_line_paras_tag = [] single_line_paras_tag = []
for i in range(len(layout_para)): for i in range(len(layout_para)):
single_line_paras_tag.append(len(layout_para[i])==1 and layout_para[i][0]['spans'][0]['type']==TEXT) single_line_paras_tag.append(
len(layout_para[i]) == 1
and layout_para[i][0]['spans'][0]['type'] == TEXT
)
"""找出来连续的单行文本,如果连续行高度相同,那么合并为一个段落。""" """找出来连续的单行文本,如果连续行高度相同,那么合并为一个段落。"""
consecutive_single_line_indices = find_consecutive_true_regions(single_line_paras_tag) consecutive_single_line_indices = find_consecutive_true_regions(
if len(consecutive_single_line_indices)>0: single_line_paras_tag
)
if len(consecutive_single_line_indices) > 0:
index_offset = 0 index_offset = 0
"""检查这些行是否是高度相同的,居中的""" """检查这些行是否是高度相同的,居中的"""
for start, end in consecutive_single_line_indices: for start, end in consecutive_single_line_indices:
start += index_offset start += index_offset
end += index_offset end += index_offset
line_hi = np.array([line[0]['bbox'][3]-line[0]['bbox'][1] for line in layout_para[start:end+1]]) line_hi = np.array(
first_line_text = ''.join([__get_span_text(span) for span in layout_para[start][0]['spans']]) [
if "Table" in first_line_text or "Figure" in first_line_text: line[0]['bbox'][3] - line[0]['bbox'][1]
for line in layout_para[start : end + 1]
]
)
first_line_text = ''.join(
[__get_span_text(span) for span in layout_para[start][0]['spans']]
)
if 'Table' in first_line_text or 'Figure' in first_line_text:
pass pass
if debug_mode: if debug_mode:
logger.debug(line_hi.std()) logger.debug(line_hi.std())
if line_hi.std()<2: if line_hi.std() < 2:
"""行高度相同,那么判断是否居中""" """行高度相同,那么判断是否居中."""
all_left_x0 = [line[0]['bbox'][0] for line in layout_para[start:end+1]] all_left_x0 = [
all_right_x1 = [line[0]['bbox'][2] for line in layout_para[start:end+1]] line[0]['bbox'][0] for line in layout_para[start : end + 1]
]
all_right_x1 = [
line[0]['bbox'][2] for line in layout_para[start : end + 1]
]
layout_center = (layout_box[0] + layout_box[2]) / 2 layout_center = (layout_box[0] + layout_box[2]) / 2
if all([x0 < layout_center < x1 for x0, x1 in zip(all_left_x0, all_right_x1)]) \ if (
and not all([x0==layout_box[0] for x0 in all_left_x0]) \ all(
and not all([x1==layout_box[2] for x1 in all_right_x1]): [
merge_para = [l[0] for l in layout_para[start:end+1]] x0 < layout_center < x1
para_text = ''.join([__get_span_text(span) for line in merge_para for span in line['spans']]) for x0, x1 in zip(all_left_x0, all_right_x1)
]
)
and not all([x0 == layout_box[0] for x0 in all_left_x0])
and not all([x1 == layout_box[2] for x1 in all_right_x1])
):
merge_para = [l[0] for l in layout_para[start : end + 1]] # noqa: E741
para_text = ''.join(
[
__get_span_text(span)
for line in merge_para
for span in line['spans']
]
)
if debug_mode: if debug_mode:
logger.debug(para_text) logger.debug(para_text)
layout_para[start:end+1] = [merge_para] layout_para[start : end + 1] = [merge_para]
index_offset -= end-start index_offset -= end - start
return return
def __merge_signle_list_text(page_paras, new_layout_bbox, page_num, lang): def __merge_signle_list_text(page_paras, new_layout_bbox, page_num, lang):
""" """找出来连续的单行文本,如果首行顶格,接下来的几个单行段落缩进对齐,那么合并为一个段落。"""
找出来连续的单行文本,如果首行顶格,接下来的几个单行段落缩进对齐,那么合并为一个段落。
"""
pass pass
def __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang): def __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang):
""" """根据line和layout情况进行分段 先实现一个根据行末尾特征分段的简单方法。"""
根据line和layout情况进行分段
先实现一个根据行末尾特征分段的简单方法。
"""
""" """
算法思路: 算法思路:
1. 扫描layout里每一行,找出来行尾距离layout有边界有一定距离的行。 1. 扫描layout里每一行,找出来行尾距离layout有边界有一定距离的行。
...@@ -587,21 +727,24 @@ def __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang): ...@@ -587,21 +727,24 @@ def __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang):
3. 参照上述行尾特征进行分段。 3. 参照上述行尾特征进行分段。
4. 图、表,目前独占一行,不考虑分段。 4. 图、表,目前独占一行,不考虑分段。
""" """
if page_num==343: if page_num == 343:
pass pass
lines_group = __group_line_by_layout(blocks, layout_bboxes, lang) # block内分段 lines_group = __group_line_by_layout(blocks, layout_bboxes, lang) # block内分段
layout_paras, layout_list_info = __split_para_in_layoutbox(lines_group, new_layout_bbox, lang) # layout内分段 layout_paras, layout_list_info = __split_para_in_layoutbox(
layout_paras2, page_list_info = __connect_list_inter_layout(layout_paras, new_layout_bbox, layout_list_info, page_num, lang) # layout之间连接列表段落 lines_group, new_layout_bbox, lang
connected_layout_paras = __connect_para_inter_layoutbox(layout_paras2, new_layout_bbox, lang) # layout间链接段落 ) # layout内分段
layout_paras2, page_list_info = __connect_list_inter_layout(
layout_paras, new_layout_bbox, layout_list_info, page_num, lang
) # layout之间连接列表段落
connected_layout_paras = __connect_para_inter_layoutbox(
layout_paras2, new_layout_bbox, lang
) # layout间链接段落
return connected_layout_paras, page_list_info return connected_layout_paras, page_list_info
def para_split(pdf_info_dict, debug_mode, lang="en"): def para_split(pdf_info_dict, debug_mode, lang='en'):
""" """根据line和layout情况进行分段."""
根据line和layout情况进行分段
"""
new_layout_of_pages = [] # 数组的数组,每个元素是一个页面的layoutS new_layout_of_pages = [] # 数组的数组,每个元素是一个页面的layoutS
all_page_list_info = [] # 保存每个页面开头和结尾是否是列表 all_page_list_info = [] # 保存每个页面开头和结尾是否是列表
for page_num, page in pdf_info_dict.items(): for page_num, page in pdf_info_dict.items():
...@@ -609,29 +752,47 @@ def para_split(pdf_info_dict, debug_mode, lang="en"): ...@@ -609,29 +752,47 @@ def para_split(pdf_info_dict, debug_mode, lang="en"):
layout_bboxes = page['layout_bboxes'] layout_bboxes = page['layout_bboxes']
new_layout_bbox = __common_pre_proc(blocks, layout_bboxes) new_layout_bbox = __common_pre_proc(blocks, layout_bboxes)
new_layout_of_pages.append(new_layout_bbox) new_layout_of_pages.append(new_layout_bbox)
splited_blocks, page_list_info = __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang) splited_blocks, page_list_info = __do_split_page(
blocks, layout_bboxes, new_layout_bbox, page_num, lang
)
all_page_list_info.append(page_list_info) all_page_list_info.append(page_list_info)
page['para_blocks'] = splited_blocks page['para_blocks'] = splited_blocks
"""连接页面与页面之间的可能合并的段落""" """连接页面与页面之间的可能合并的段落"""
pdf_infos = list(pdf_info_dict.values()) pdf_infos = list(pdf_info_dict.values())
for page_num, page in enumerate(pdf_info_dict.values()): for page_num, page in enumerate(pdf_info_dict.values()):
if page_num==0: if page_num == 0:
continue continue
pre_page_paras = pdf_infos[page_num-1]['para_blocks'] pre_page_paras = pdf_infos[page_num - 1]['para_blocks']
next_page_paras = pdf_infos[page_num]['para_blocks'] next_page_paras = pdf_infos[page_num]['para_blocks']
pre_page_layout_bbox = new_layout_of_pages[page_num-1] pre_page_layout_bbox = new_layout_of_pages[page_num - 1]
next_page_layout_bbox = new_layout_of_pages[page_num] next_page_layout_bbox = new_layout_of_pages[page_num]
is_conn = __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox, page_num, lang) is_conn = __connect_para_inter_page(
pre_page_paras,
next_page_paras,
pre_page_layout_bbox,
next_page_layout_bbox,
page_num,
lang,
)
if debug_mode: if debug_mode:
if is_conn: if is_conn:
logger.info(f"连接了第{page_num-1}页和第{page_num}页的段落") logger.info(f'连接了第{page_num-1}页和第{page_num}页的段落')
is_list_conn = __connect_list_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox, all_page_list_info[page_num-1], all_page_list_info[page_num], page_num, lang) is_list_conn = __connect_list_inter_page(
pre_page_paras,
next_page_paras,
pre_page_layout_bbox,
next_page_layout_bbox,
all_page_list_info[page_num - 1],
all_page_list_info[page_num],
page_num,
lang,
)
if debug_mode: if debug_mode:
if is_list_conn: if is_list_conn:
logger.info(f"连接了第{page_num-1}页和第{page_num}页的列表段落") logger.info(f'连接了第{page_num-1}页和第{page_num}页的列表段落')
"""接下来可能会漏掉一些特别的一些可以合并的内容,对他们进行段落连接 """接下来可能会漏掉一些特别的一些可以合并的内容,对他们进行段落连接
1. 正文中有时出现一个行顶格,接下来几行缩进的情况。 1. 正文中有时出现一个行顶格,接下来几行缩进的情况。
...@@ -640,5 +801,7 @@ def para_split(pdf_info_dict, debug_mode, lang="en"): ...@@ -640,5 +801,7 @@ def para_split(pdf_info_dict, debug_mode, lang="en"):
for page_num, page in enumerate(pdf_info_dict.values()): for page_num, page in enumerate(pdf_info_dict.values()):
page_paras = page['para_blocks'] page_paras = page['para_blocks']
new_layout_bbox = new_layout_of_pages[page_num] new_layout_bbox = new_layout_of_pages[page_num]
__connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang, debug_mode=debug_mode) __connect_middle_align_text(
page_paras, new_layout_bbox, page_num, lang, debug_mode=debug_mode
)
__merge_signle_list_text(page_paras, new_layout_bbox, page_num, lang) __merge_signle_list_text(page_paras, new_layout_bbox, page_num, lang)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment