Unverified Commit ea2f8ea0 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge branch 'dev' into dev

parents e4810cee 23c8436e
"""
span维度自定义字段
"""
"""span维度自定义字段."""
# span是否是跨页合并的
CROSS_PAGE = "cross_page"
CROSS_PAGE = 'cross_page'
"""
block维度自定义字段
"""
# block中lines是否被删除
LINES_DELETED = "lines_deleted"
LINES_DELETED = 'lines_deleted'
# table recognition max time default value
TABLE_MAX_TIME_VALUE = 400
......@@ -17,39 +15,39 @@ TABLE_MAX_TIME_VALUE = 400
TABLE_MAX_LEN = 480
# table master structure dict
TABLE_MASTER_DICT = "table_master_structure_dict.txt"
TABLE_MASTER_DICT = 'table_master_structure_dict.txt'
# table master dir
TABLE_MASTER_DIR = "table_structure_tablemaster_infer/"
TABLE_MASTER_DIR = 'table_structure_tablemaster_infer/'
# pp detect model dir
DETECT_MODEL_DIR = "ch_PP-OCRv4_det_infer"
DETECT_MODEL_DIR = 'ch_PP-OCRv4_det_infer'
# pp rec model dir
REC_MODEL_DIR = "ch_PP-OCRv4_rec_infer"
REC_MODEL_DIR = 'ch_PP-OCRv4_rec_infer'
# pp rec char dict path
REC_CHAR_DICT = "ppocr_keys_v1.txt"
REC_CHAR_DICT = 'ppocr_keys_v1.txt'
# pp rec copy rec directory
PP_REC_DIRECTORY = ".paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer"
PP_REC_DIRECTORY = '.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer'
# pp rec copy det directory
PP_DET_DIRECTORY = ".paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer"
PP_DET_DIRECTORY = '.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer'
class MODEL_NAME:
# pp table structure algorithm
TABLE_MASTER = "tablemaster"
TABLE_MASTER = 'tablemaster'
# struct eqtable
STRUCT_EQTABLE = "struct_eqtable"
STRUCT_EQTABLE = 'struct_eqtable'
DocLayout_YOLO = "doclayout_yolo"
DocLayout_YOLO = 'doclayout_yolo'
LAYOUTLMv3 = "layoutlmv3"
LAYOUTLMv3 = 'layoutlmv3'
YOLO_V8_MFD = "yolo_v8_mfd"
YOLO_V8_MFD = 'yolo_v8_mfd'
UniMerNet_v2_Small = "unimernet_small"
UniMerNet_v2_Small = 'unimernet_small'
RAPID_TABLE = "rapid_table"
\ No newline at end of file
RAPID_TABLE = 'rapid_table'
class DropReason:
TEXT_BLCOK_HOR_OVERLAP = 'text_block_horizontal_overlap' # 文字块有水平互相覆盖,导致无法准确定位文字顺序
USEFUL_BLOCK_HOR_OVERLAP = (
'useful_block_horizontal_overlap' # 需保留的block水平覆盖
)
COMPLICATED_LAYOUT = 'complicated_layout' # 复杂的布局,暂时不支持
TOO_MANY_LAYOUT_COLUMNS = 'too_many_layout_columns' # 目前不支持分栏超过2列的
COLOR_BACKGROUND_TEXT_BOX = 'color_background_text_box' # 含有带色块的PDF,色块会改变阅读顺序,目前不支持带底色文字块的PDF。
HIGH_COMPUTATIONAL_lOAD_BY_IMGS = (
'high_computational_load_by_imgs' # 含特殊图片,计算量太大,从而丢弃
)
HIGH_COMPUTATIONAL_lOAD_BY_SVGS = (
'high_computational_load_by_svgs' # 特殊的SVG图,计算量太大,从而丢弃
)
HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES = 'high_computational_load_by_total_pages' # 计算量超过负荷,当前方法下计算量消耗过大
MISS_DOC_LAYOUT_RESULT = 'missing doc_layout_result' # 版面分析失败
Exception = '_exception' # 解析中发生异常
ENCRYPTED = 'encrypted' # PDF是加密的
EMPTY_PDF = 'total_page=0' # PDF页面总数为0
NOT_IS_TEXT_PDF = 'not_is_text_pdf' # 不是文字版PDF,无法直接解析
DENSE_SINGLE_LINE_BLOCK = 'dense_single_line_block' # 无法清晰的分段
TITLE_DETECTION_FAILED = 'title_detection_failed' # 探测标题失败
TITLE_LEVEL_FAILED = (
'title_level_failed' # 分析标题级别失败(例如一级、二级、三级标题)
)
PARA_SPLIT_FAILED = 'para_split_failed' # 识别段落失败
PARA_MERGE_FAILED = 'para_merge_failed' # 段落合并失败
NOT_ALLOW_LANGUAGE = 'not_allow_language' # 不支持的语种
SPECIAL_PDF = 'special_pdf'
PSEUDO_SINGLE_COLUMN = 'pseudo_single_column' # 无法精确判断文字分栏
CAN_NOT_DETECT_PAGE_LAYOUT = 'can_not_detect_page_layout' # 无法分析页面的版面
NEGATIVE_BBOX_AREA = 'negative_bbox_area' # 缩放导致 bbox 面积为负
OVERLAP_BLOCKS_CAN_NOT_SEPARATION = (
'overlap_blocks_can_t_separation' # 无法分离重叠的block
)
COLOR_BG_HEADER_TXT_BLOCK = 'color_background_header_txt_block'
PAGE_NO = 'page-no' # 页码
CONTENT_IN_FOOT_OR_HEADER = 'in-foot-header-area' # 页眉页脚内的文本
VERTICAL_TEXT = 'vertical-text' # 垂直文本
ROTATE_TEXT = 'rotate-text' # 旋转文本
EMPTY_SIDE_BLOCK = 'empty-side-block' # 边缘上的空白没有任何内容的block
ON_IMAGE_TEXT = 'on-image-text' # 文本在图片上
ON_TABLE_TEXT = 'on-table-text' # 文本在表格上
class DropTag:
PAGE_NUMBER = 'page_no'
HEADER = 'header'
FOOTER = 'footer'
FOOTNOTE = 'footnote'
NOT_IN_LAYOUT = 'not_in_layout'
SPAN_OVERLAP = 'span_overlap'
BLOCK_OVERLAP = 'block_overlap'
class MakeMode:
MM_MD = 'mm_markdown'
NLP_MD = 'nlp_markdown'
STANDARD_FORMAT = 'standard_format'
class DropMode:
WHOLE_PDF = 'whole_pdf'
SINGLE_PAGE = 'single_page'
NONE = 'none'
NONE_WITH_REASON = 'none_with_reason'
from enum import Enum
class ModelBlockTypeEnum(Enum):
TITLE = 0
PLAIN_TEXT = 1
......
......@@ -35,7 +35,7 @@ def read_jsonl(
jsonl_d = [
json.loads(line) for line in jsonl_bits.decode().split('\n') if line.strip()
]
for d in jsonl_d[:5]:
for d in jsonl_d:
pdf_path = d.get('file_location', '') or d.get('path', '')
if len(pdf_path) == 0:
raise EmptyData('pdf file location is empty')
......
This diff is collapsed.
......@@ -2,17 +2,16 @@ import re
from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.libs.commons import join_path
from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
from magic_pdf.libs.ocr_content_type import BlockType, ContentType
from magic_pdf.para.para_split_v3 import ListLineTag
def __is_hyphen_at_line_end(line):
"""
Check if a line ends with one or more letters followed by a hyphen.
"""Check if a line ends with one or more letters followed by a hyphen.
Args:
line (str): The line of text to check.
......@@ -163,7 +162,7 @@ def merge_para_with_text(para_block):
if span_type in [ContentType.Text, ContentType.InterlineEquation]:
para_text += content # 中文/日语/韩文语境下,content间不需要空格分隔
elif span_type == ContentType.InlineEquation:
para_text += f" {content} "
para_text += f' {content} '
else:
if span_type in [ContentType.Text, ContentType.InlineEquation]:
# 如果span是line的最后一个且末尾带有-连字符,那么末尾不应该加空格,同时应该把-删除
......@@ -172,7 +171,7 @@ def merge_para_with_text(para_block):
elif len(content) == 1 and content not in ['A', 'I', 'a', 'i'] and not content.isdigit():
para_text += content
else: # 西方文本语境下 content间需要空格分隔
para_text += f"{content} "
para_text += f'{content} '
elif span_type == ContentType.InterlineEquation:
para_text += content
else:
......
"""
输入: s3路径,每行一个
输出: pdf文件元信息,包括每一页上的所有图片的长宽高,bbox位置
"""
"""输入: s3路径,每行一个 输出: pdf文件元信息,包括每一页上的所有图片的长宽高,bbox位置."""
import sys
import click
from collections import Counter
from magic_pdf.libs.commons import read_file, mymax, get_top_percent_list
from magic_pdf.libs.commons import fitz
import click
from loguru import logger
from collections import Counter
from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.config.drop_reason import DropReason
from magic_pdf.libs.commons import fitz, get_top_percent_list, mymax, read_file
from magic_pdf.libs.language import detect_lang
from magic_pdf.libs.pdf_check import detect_invalid_chars
......@@ -19,8 +16,10 @@ junk_limit_min = 10
def calculate_max_image_area_per_page(result: list, page_width_pts, page_height_pts):
max_image_area_per_page = [mymax([(x1 - x0) * (y1 - y0) for x0, y0, x1, y1, _ in page_img_sz]) for page_img_sz in
result]
max_image_area_per_page = [
mymax([(x1 - x0) * (y1 - y0) for x0, y0, x1, y1, _ in page_img_sz])
for page_img_sz in result
]
page_area = int(page_width_pts) * int(page_height_pts)
max_image_area_per_page = [area / page_area for area in max_image_area_per_page]
max_image_area_per_page = [area for area in max_image_area_per_page if area > 0.6]
......@@ -33,7 +32,9 @@ def process_image(page, junk_img_bojids=[]):
dedup = set()
for img in items:
# 这里返回的是图片在page上的实际展示的大小。返回一个数组,每个元素第一部分是
img_bojid = img[0] # 在pdf文件中是全局唯一的,如果这个图反复出现在pdf里那么就可能是垃圾信息,例如水印、页眉页脚等
img_bojid = img[
0
] # 在pdf文件中是全局唯一的,如果这个图反复出现在pdf里那么就可能是垃圾信息,例如水印、页眉页脚等
if img_bojid in junk_img_bojids: # 如果是垃圾图像,就跳过
continue
recs = page.get_image_rects(img, transform=True)
......@@ -42,9 +43,17 @@ def process_image(page, junk_img_bojids=[]):
x0, y0, x1, y1 = map(int, rec)
width = x1 - x0
height = y1 - y0
if (x0, y0, x1, y1, img_bojid) in dedup: # 这里面会出现一些重复的bbox,无需重复出现,需要去掉
if (
x0,
y0,
x1,
y1,
img_bojid,
) in dedup: # 这里面会出现一些重复的bbox,无需重复出现,需要去掉
continue
if not all([width, height]): # 长和宽任何一个都不能是0,否则这个图片不可见,没有实际意义
if not all(
[width, height]
): # 长和宽任何一个都不能是0,否则这个图片不可见,没有实际意义
continue
dedup.add((x0, y0, x1, y1, img_bojid))
page_result.append([x0, y0, x1, y1, img_bojid])
......@@ -52,8 +61,8 @@ def process_image(page, junk_img_bojids=[]):
def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
"""
返回每个页面里的图片的四元组,每个页面多个图片。
"""返回每个页面里的图片的四元组,每个页面多个图片。
:param doc:
:return:
"""
......@@ -63,13 +72,17 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
junk_limit = max(len(doc) * 0.5, junk_limit_min) # 对一些页数比较少的进行豁免
junk_img_bojids = [img_bojid for img_bojid, count in img_bojid_counter.items() if count >= junk_limit]
#todo 加个判断,用前十页就行,这些垃圾图片需要满足两个条件,不止出现的次数要足够多,而且图片占书页面积的比例要足够大,且图与图大小都差不多
#有两种扫描版,一种文字版,这里可能会有误判
#扫描版1:每页都有所有扫描页图片,特点是图占比大,每页展示1张
#扫描版2,每页存储的扫描页图片数量递增,特点是图占比大,每页展示1张,需要清空junklist跑前50页图片信息用于分类判断
#文字版1.每页存储所有图片,特点是图片占页面比例不大,每页展示可能为0也可能不止1张 这种pdf需要拿前10页抽样检测img大小和个数,如果符合需要清空junklist
junk_img_bojids = [
img_bojid
for img_bojid, count in img_bojid_counter.items()
if count >= junk_limit
]
# todo 加个判断,用前十页就行,这些垃圾图片需要满足两个条件,不止出现的次数要足够多,而且图片占书页面积的比例要足够大,且图与图大小都差不多
# 有两种扫描版,一种文字版,这里可能会有误判
# 扫描版1:每页都有所有扫描页图片,特点是图占比大,每页展示1张
# 扫描版2,每页存储的扫描页图片数量递增,特点是图占比大,每页展示1张,需要清空junklist跑前50页图片信息用于分类判断
# 文 字版1.每页存储所有图片,特点是图片占页面比例不大,每页展示可能为0也可能不止1张 这种pdf需要拿前10页抽样检测img大小和个数,如果符合需要清空junklist
imgs_len_list = [len(page.get_images()) for page in doc]
special_limit_pages = 10
......@@ -82,12 +95,18 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
break
if i >= special_limit_pages:
break
page_result = process_image(page) # 这里不传junk_img_bojids,拿前十页所有图片信息用于后续分析
page_result = process_image(
page
) # 这里不传junk_img_bojids,拿前十页所有图片信息用于后续分析
result.append(page_result)
for item in result:
if not any(item): # 如果任何一页没有图片,说明是个文字版,需要判断是否为特殊文字版
if max(imgs_len_list) == min(imgs_len_list) and max(
imgs_len_list) >= junk_limit_min: # 如果是特殊文字版,就把junklist置空并break
if not any(
item
): # 如果任何一页没有图片,说明是个文字版,需要判断是否为特殊文字版
if (
max(imgs_len_list) == min(imgs_len_list)
and max(imgs_len_list) >= junk_limit_min
): # 如果是特殊文字版,就把junklist置空并break
junk_img_bojids = []
else: # 不是特殊文字版,是个普通文字版,但是存在垃圾图片,不置空junklist
pass
......@@ -98,20 +117,23 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
top_eighty_percent = get_top_percent_list(imgs_len_list, 0.8)
# 检查前80%的元素是否都相等
if len(set(top_eighty_percent)) == 1 and max(imgs_len_list) >= junk_limit_min:
# # 如果前10页跑完都有图,根据每页图片数量是否相等判断是否需要清除junklist
# if max(imgs_len_list) == min(imgs_len_list) and max(imgs_len_list) >= junk_limit_min:
#前10页都有图,且每页数量一致,需要检测图片大小占页面的比例判断是否需要清除junklist
max_image_area_per_page = calculate_max_image_area_per_page(result, page_width_pts, page_height_pts)
if len(max_image_area_per_page) < 0.8 * special_limit_pages: # 前10页不全是大图,说明可能是个文字版pdf,把垃圾图片list置空
# 前10页都有图,且每页数量一致,需要检测图片大小占页面的比例判断是否需要清除junklist
max_image_area_per_page = calculate_max_image_area_per_page(
result, page_width_pts, page_height_pts
)
if (
len(max_image_area_per_page) < 0.8 * special_limit_pages
): # 前10页不全是大图,说明可能是个文字版pdf,把垃圾图片list置空
junk_img_bojids = []
else: # 前10页都有图,而且80%都是大图,且每页图片数量一致并都很多,说明是扫描版1,不需要清空junklist
pass
else: # 每页图片数量不一致,需要清掉junklist全量跑前50页图片
junk_img_bojids = []
#正式进入取前50页图片的信息流程
# 正式进入取前50页图片的信息流程
result = []
for i, page in enumerate(doc):
if i >= scan_max_page:
......@@ -126,7 +148,7 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
def get_pdf_page_size_pts(doc: fitz.Document):
page_cnt = len(doc)
l: int = min(page_cnt, 50)
#把所有宽度和高度塞到两个list 分别取中位数(中间遇到了个在纵页里塞横页的pdf,导致宽高互换了)
# 把所有宽度和高度塞到两个list 分别取中位数(中间遇到了个在纵页里塞横页的pdf,导致宽高互换了)
page_width_list = []
page_height_list = []
for i in range(l):
......@@ -152,8 +174,8 @@ def get_pdf_textlen_per_page(doc: fitz.Document):
# 拿所有text的blocks
# text_block = page.get_text("words")
# text_block_len = sum([len(t[4]) for t in text_block])
#拿所有text的str
text_block = page.get_text("text")
# 拿所有text的str
text_block = page.get_text('text')
text_block_len = len(text_block)
# logger.info(f"page {page.number} text_block_len: {text_block_len}")
text_len_lst.append(text_block_len)
......@@ -162,15 +184,13 @@ def get_pdf_textlen_per_page(doc: fitz.Document):
def get_pdf_text_layout_per_page(doc: fitz.Document):
"""
根据PDF文档的每一页文本布局,判断该页的文本布局是横向、纵向还是未知。
"""根据PDF文档的每一页文本布局,判断该页的文本布局是横向、纵向还是未知。
Args:
doc (fitz.Document): PDF文档对象。
Returns:
List[str]: 每一页的文本布局(横向、纵向、未知)。
"""
text_layout_list = []
......@@ -180,11 +200,11 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
# 创建每一页的纵向和横向的文本行数计数器
vertical_count = 0
horizontal_count = 0
text_dict = page.get_text("dict")
if "blocks" in text_dict:
for block in text_dict["blocks"]:
text_dict = page.get_text('dict')
if 'blocks' in text_dict:
for block in text_dict['blocks']:
if 'lines' in block:
for line in block["lines"]:
for line in block['lines']:
# 获取line的bbox顶点坐标
x0, y0, x1, y1 = line['bbox']
# 计算bbox的宽高
......@@ -199,8 +219,12 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
if len(font_sizes) > 0:
average_font_size = sum(font_sizes) / len(font_sizes)
else:
average_font_size = 10 # 有的line拿不到font_size,先定一个阈值100
if area <= average_font_size ** 2: # 判断bbox的面积是否小于平均字体大小的平方,单字无法计算是横向还是纵向
average_font_size = (
10 # 有的line拿不到font_size,先定一个阈值100
)
if (
area <= average_font_size**2
): # 判断bbox的面积是否小于平均字体大小的平方,单字无法计算是横向还是纵向
continue
else:
if 'wmode' in line: # 通过wmode判断文本方向
......@@ -228,22 +252,22 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
# print(f"page_id: {page_id}, vertical_count: {vertical_count}, horizontal_count: {horizontal_count}")
# 判断每一页的文本布局
if vertical_count == 0 and horizontal_count == 0: # 该页没有文本,无法判断
text_layout_list.append("unknow")
text_layout_list.append('unknow')
continue
else:
if vertical_count > horizontal_count: # 该页的文本纵向行数大于横向的
text_layout_list.append("vertical")
text_layout_list.append('vertical')
else: # 该页的文本横向行数大于纵向的
text_layout_list.append("horizontal")
text_layout_list.append('horizontal')
# logger.info(f"page_id: {page_id}, vertical_count: {vertical_count}, horizontal_count: {horizontal_count}")
return text_layout_list
'''定义一个自定义异常用来抛出单页svg太多的pdf'''
"""定义一个自定义异常用来抛出单页svg太多的pdf"""
class PageSvgsTooManyError(Exception):
def __init__(self, message="Page SVGs are too many"):
def __init__(self, message='Page SVGs are too many'):
self.message = message
super().__init__(self.message)
......@@ -285,7 +309,7 @@ def get_language(doc: fitz.Document):
if page_id >= scan_max_page:
break
# 拿所有text的str
text_block = page.get_text("text")
text_block = page.get_text('text')
page_language = detect_lang(text_block)
language_lst.append(page_language)
......@@ -299,9 +323,7 @@ def get_language(doc: fitz.Document):
def check_invalid_chars(pdf_bytes):
"""
乱码检测
"""
"""乱码检测."""
return detect_invalid_chars(pdf_bytes)
......@@ -311,13 +333,13 @@ def pdf_meta_scan(pdf_bytes: bytes):
:param pdf_bytes: pdf文件的二进制数据
几个维度来评价:是否加密,是否需要密码,纸张大小,总页数,是否文字可提取
"""
doc = fitz.open("pdf", pdf_bytes)
doc = fitz.open('pdf', pdf_bytes)
is_needs_password = doc.needs_pass
is_encrypted = doc.is_encrypted
total_page = len(doc)
if total_page == 0:
logger.warning(f"drop this pdf, drop_reason: {DropReason.EMPTY_PDF}")
result = {"_need_drop": True, "_drop_reason": DropReason.EMPTY_PDF}
logger.warning(f'drop this pdf, drop_reason: {DropReason.EMPTY_PDF}')
result = {'_need_drop': True, '_drop_reason': DropReason.EMPTY_PDF}
return result
else:
page_width_pts, page_height_pts = get_pdf_page_size_pts(doc)
......@@ -328,7 +350,9 @@ def pdf_meta_scan(pdf_bytes: bytes):
imgs_per_page = get_imgs_per_page(doc)
# logger.info(f"imgs_per_page: {imgs_per_page}")
image_info_per_page, junk_img_bojids = get_image_info(doc, page_width_pts, page_height_pts)
image_info_per_page, junk_img_bojids = get_image_info(
doc, page_width_pts, page_height_pts
)
# logger.info(f"image_info_per_page: {image_info_per_page}, junk_img_bojids: {junk_img_bojids}")
text_len_per_page = get_pdf_textlen_per_page(doc)
# logger.info(f"text_len_per_page: {text_len_per_page}")
......@@ -341,20 +365,20 @@ def pdf_meta_scan(pdf_bytes: bytes):
# 最后输出一条json
res = {
"is_needs_password": is_needs_password,
"is_encrypted": is_encrypted,
"total_page": total_page,
"page_width_pts": int(page_width_pts),
"page_height_pts": int(page_height_pts),
"image_info_per_page": image_info_per_page,
"text_len_per_page": text_len_per_page,
"text_layout_per_page": text_layout_per_page,
"text_language": text_language,
'is_needs_password': is_needs_password,
'is_encrypted': is_encrypted,
'total_page': total_page,
'page_width_pts': int(page_width_pts),
'page_height_pts': int(page_height_pts),
'image_info_per_page': image_info_per_page,
'text_len_per_page': text_len_per_page,
'text_layout_per_page': text_layout_per_page,
'text_language': text_language,
# "svgs_per_page": svgs_per_page,
"imgs_per_page": imgs_per_page, # 增加每页img数量list
"junk_img_bojids": junk_img_bojids, # 增加垃圾图片的bojid list
"invalid_chars": invalid_chars,
"metadata": doc.metadata
'imgs_per_page': imgs_per_page, # 增加每页img数量list
'junk_img_bojids': junk_img_bojids, # 增加垃圾图片的bojid list
'invalid_chars': invalid_chars,
'metadata': doc.metadata,
}
# logger.info(json.dumps(res, ensure_ascii=False))
return res
......@@ -364,14 +388,12 @@ def pdf_meta_scan(pdf_bytes: bytes):
@click.option('--s3-pdf-path', help='s3上pdf文件的路径')
@click.option('--s3-profile', help='s3上的profile')
def main(s3_pdf_path: str, s3_profile: str):
"""
"""
""""""
try:
file_content = read_file(s3_pdf_path, s3_profile)
pdf_meta_scan(file_content)
except Exception as e:
print(f"ERROR: {s3_pdf_path}, {e}", file=sys.stderr)
print(f'ERROR: {s3_pdf_path}, {e}', file=sys.stderr)
logger.exception(e)
......@@ -381,7 +403,7 @@ if __name__ == '__main__':
# "D:\project/20231108code-clean\pdf_cost_time\竖排例子\三国演义_繁体竖排版.pdf"
# "D:\project/20231108code-clean\pdf_cost_time\scihub\scihub_86800000\libgen.scimag86880000-86880999.zip_10.1021/acsami.1c03109.s002.pdf"
# "D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_18600000/libgen.scimag18645000-18645999.zip_10.1021/om3006239.pdf"
# file_content = read_file("D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_31000000/libgen.scimag31098000-31098999.zip_10.1109/isit.2006.261791.pdf","")
# file_content = read_file("D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_31000000/libgen.scimag31098000-31098999.zip_10.1109/isit.2006.261791.pdf","") # noqa: E501
# file_content = read_file("D:\project/20231108code-clean\pdf_cost_time\竖排例子\净空法师_大乘无量寿.pdf","")
# doc = fitz.open("pdf", file_content)
# text_layout_lst = get_pdf_text_layout_per_page(doc)
......
......@@ -5,14 +5,13 @@ from pathlib import Path
from loguru import logger
import magic_pdf.model as model_config
from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.data.data_reader_writer import FileBasedDataReader
from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
from magic_pdf.integrations.rag.type import (CategoryType, ContentObject,
ElementRelation, ElementRelType,
LayoutElements,
LayoutElementsExtra, PageInfo)
from magic_pdf.libs.ocr_content_type import BlockType, ContentType
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
from magic_pdf.tools.common import do_parse, prepare_env
......@@ -224,8 +223,8 @@ def inference(path, output_dir, method):
str(Path(path).stem), method)
def read_fn(path):
disk_rw = DiskReaderWriter(os.path.dirname(path))
return disk_rw.read(os.path.basename(path), AbsReaderWriter.MODE_BIN)
disk_rw = FileBasedDataReader(os.path.dirname(path))
return disk_rw.read(os.path.basename(path))
def parse_doc(doc_path: str):
try:
......
class MakeMode:
MM_MD = "mm_markdown"
NLP_MD = "nlp_markdown"
STANDARD_FORMAT = "standard_format"
class DropMode:
WHOLE_PDF = "whole_pdf"
SINGLE_PAGE = "single_page"
NONE = "none"
NONE_WITH_REASON = "none_with_reason"
......@@ -5,7 +5,7 @@ import os
from loguru import logger
from magic_pdf.libs.Constants import MODEL_NAME
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.libs.commons import parse_bucket_key
# 定义配置文件名常量
......@@ -99,7 +99,7 @@ def get_table_recog_config():
def get_layout_config():
config = read_config()
layout_config = config.get("layout-config")
layout_config = config.get('layout-config')
if layout_config is None:
logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
return json.loads(f'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}')
......@@ -109,7 +109,7 @@ def get_layout_config():
def get_formula_config():
config = read_config()
formula_config = config.get("formula-config")
formula_config = config.get('formula-config')
if formula_config is None:
logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
return json.loads(f'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}')
......@@ -117,5 +117,5 @@ def get_formula_config():
return formula_config
if __name__ == "__main__":
ak, sk, endpoint = get_s3_config("llm-raw")
if __name__ == '__main__':
ak, sk, endpoint = get_s3_config('llm-raw')
from magic_pdf.config.constants import CROSS_PAGE
from magic_pdf.config.ocr_content_type import (BlockType, CategoryId,
ContentType)
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.libs.commons import fitz # PyMuPDF
from magic_pdf.libs.Constants import CROSS_PAGE
from magic_pdf.libs.ocr_content_type import BlockType, CategoryId, ContentType
from magic_pdf.model.magic_model import MagicModel
......
class DropReason:
TEXT_BLCOK_HOR_OVERLAP = "text_block_horizontal_overlap" # 文字块有水平互相覆盖,导致无法准确定位文字顺序
USEFUL_BLOCK_HOR_OVERLAP = "useful_block_horizontal_overlap" # 需保留的block水平覆盖
COMPLICATED_LAYOUT = "complicated_layout" # 复杂的布局,暂时不支持
TOO_MANY_LAYOUT_COLUMNS = "too_many_layout_columns" # 目前不支持分栏超过2列的
COLOR_BACKGROUND_TEXT_BOX = "color_background_text_box" # 含有带色块的PDF,色块会改变阅读顺序,目前不支持带底色文字块的PDF。
HIGH_COMPUTATIONAL_lOAD_BY_IMGS = "high_computational_load_by_imgs" # 含特殊图片,计算量太大,从而丢弃
HIGH_COMPUTATIONAL_lOAD_BY_SVGS = "high_computational_load_by_svgs" # 特殊的SVG图,计算量太大,从而丢弃
HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES = "high_computational_load_by_total_pages" # 计算量超过负荷,当前方法下计算量消耗过大
MISS_DOC_LAYOUT_RESULT = "missing doc_layout_result" # 版面分析失败
Exception = "_exception" # 解析中发生异常
ENCRYPTED = "encrypted" # PDF是加密的
EMPTY_PDF = "total_page=0" # PDF页面总数为0
NOT_IS_TEXT_PDF = "not_is_text_pdf" # 不是文字版PDF,无法直接解析
DENSE_SINGLE_LINE_BLOCK = "dense_single_line_block" # 无法清晰的分段
TITLE_DETECTION_FAILED = "title_detection_failed" # 探测标题失败
TITLE_LEVEL_FAILED = "title_level_failed" # 分析标题级别失败(例如一级、二级、三级标题)
PARA_SPLIT_FAILED = "para_split_failed" # 识别段落失败
PARA_MERGE_FAILED = "para_merge_failed" # 段落合并失败
NOT_ALLOW_LANGUAGE = "not_allow_language" # 不支持的语种
SPECIAL_PDF = "special_pdf"
PSEUDO_SINGLE_COLUMN = "pseudo_single_column" # 无法精确判断文字分栏
CAN_NOT_DETECT_PAGE_LAYOUT="can_not_detect_page_layout" # 无法分析页面的版面
NEGATIVE_BBOX_AREA = "negative_bbox_area" # 缩放导致 bbox 面积为负
OVERLAP_BLOCKS_CAN_NOT_SEPARATION = "overlap_blocks_can_t_separation" # 无法分离重叠的block
\ No newline at end of file
COLOR_BG_HEADER_TXT_BLOCK = "color_background_header_txt_block"
PAGE_NO = "page-no" # 页码
CONTENT_IN_FOOT_OR_HEADER = 'in-foot-header-area' # 页眉页脚内的文本
VERTICAL_TEXT = 'vertical-text' # 垂直文本
ROTATE_TEXT = 'rotate-text' # 旋转文本
EMPTY_SIDE_BLOCK = 'empty-side-block' # 边缘上的空白没有任何内容的block
ON_IMAGE_TEXT = 'on-image-text' # 文本在图片上
ON_TABLE_TEXT = 'on-table-text' # 文本在表格上
class DropTag:
PAGE_NUMBER = "page_no"
HEADER = "header"
FOOTER = "footer"
FOOTNOTE = "footnote"
NOT_IN_LAYOUT = "not_in_layout"
SPAN_OVERLAP = "span_overlap"
BLOCK_OVERLAP = "block_overlap"
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.libs.commons import fitz
from magic_pdf.libs.commons import join_path
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.libs.commons import fitz, join_path
from magic_pdf.libs.hash_utils import compute_sha256
def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWriter: AbsReaderWriter):
"""
从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径
save_path:需要同时支持s3和本地, 图片存放在save_path下,文件名是: {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。
"""
def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWriter: DataWriter):
"""从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径 save_path:需要同时支持s3和本地,
图片存放在save_path下,文件名是:
{page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。"""
# 拼接文件名
filename = f"{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}"
filename = f'{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}'
# 老版本返回不带bucket的路径
img_path = join_path(return_path, filename) if return_path is not None else None
# 新版本生成平铺路径
img_hash256_path = f"{compute_sha256(img_path)}.jpg"
img_hash256_path = f'{compute_sha256(img_path)}.jpg'
# 将坐标转换为fitz.Rect对象
rect = fitz.Rect(*bbox)
......@@ -28,6 +26,6 @@ def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWri
byte_data = pix.tobytes(output='jpeg', jpg_quality=95)
imageWriter.write(byte_data, img_hash256_path, AbsReaderWriter.MODE_BIN)
imageWriter.write(img_hash256_path, byte_data)
return img_hash256_path
import enum
import json
from magic_pdf.config.model_block_type import ModelBlockTypeEnum
from magic_pdf.config.ocr_content_type import CategoryId, ContentType
from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
FileBasedDataWriter)
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
bbox_relative_pos, box_area, calculate_iou,
......@@ -9,11 +13,7 @@ from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
from magic_pdf.libs.commons import fitz, join_path
from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.local_math import float_gt
from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
CAPATION_OVERLAP_AREA_RATIO = 0.6
MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
......@@ -1050,27 +1050,27 @@ class MagicModel:
if __name__ == '__main__':
drw = DiskReaderWriter(r'D:/project/20231108code-clean')
drw = FileBasedDataReader(r'D:/project/20231108code-clean')
if 0:
pdf_file_path = r'linshixuqiu\19983-00.pdf'
model_file_path = r'linshixuqiu\19983-00_new.json'
pdf_bytes = drw.read(pdf_file_path, AbsReaderWriter.MODE_BIN)
model_json_txt = drw.read(model_file_path, AbsReaderWriter.MODE_TXT)
pdf_bytes = drw.read(pdf_file_path)
model_json_txt = drw.read(model_file_path).decode()
model_list = json.loads(model_json_txt)
write_path = r'D:\project\20231108code-clean\linshixuqiu\19983-00'
img_bucket_path = 'imgs'
img_writer = DiskReaderWriter(join_path(write_path, img_bucket_path))
img_writer = FileBasedDataWriter(join_path(write_path, img_bucket_path))
pdf_docs = fitz.open('pdf', pdf_bytes)
magic_model = MagicModel(model_list, pdf_docs)
if 1:
from magic_pdf.data.dataset import PymuDocDataset
model_list = json.loads(
drw.read('/opt/data/pdf/20240418/j.chroma.2009.03.042.json')
)
pdf_bytes = drw.read(
'/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf', AbsReaderWriter.MODE_BIN
)
pdf_docs = fitz.open('pdf', pdf_bytes)
magic_model = MagicModel(model_list, pdf_docs)
pdf_bytes = drw.read('/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf')
magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
for i in range(7):
print(magic_model.get_imgs(i))
import numpy as np
import torch
from loguru import logger
# flake8: noqa
import os
import time
import cv2
import numpy as np
import torch
import yaml
from loguru import logger
from PIL import Image
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
......@@ -13,20 +15,21 @@ os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
try:
import torchtext
if torchtext.__version__ >= "0.18.0":
if torchtext.__version__ >= '0.18.0':
torchtext.disable_torchtext_deprecation_warning()
except ImportError:
pass
from magic_pdf.libs.Constants import *
from magic_pdf.config.constants import *
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list
from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list)
class CustomPEKModel:
def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
"""
======== model init ========
......@@ -41,42 +44,54 @@ class CustomPEKModel:
model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
# 构建 model_configs.yaml 文件的完整路径
config_path = os.path.join(model_config_dir, 'model_configs.yaml')
with open(config_path, "r", encoding='utf-8') as f:
with open(config_path, 'r', encoding='utf-8') as f:
self.configs = yaml.load(f, Loader=yaml.FullLoader)
# 初始化解析配置
# layout config
self.layout_config = kwargs.get("layout_config")
self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO)
self.layout_config = kwargs.get('layout_config')
self.layout_model_name = self.layout_config.get(
'model', MODEL_NAME.DocLayout_YOLO
)
# formula config
self.formula_config = kwargs.get("formula_config")
self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD)
self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small)
self.apply_formula = self.formula_config.get("enable", True)
self.formula_config = kwargs.get('formula_config')
self.mfd_model_name = self.formula_config.get(
'mfd_model', MODEL_NAME.YOLO_V8_MFD
)
self.mfr_model_name = self.formula_config.get(
'mfr_model', MODEL_NAME.UniMerNet_v2_Small
)
self.apply_formula = self.formula_config.get('enable', True)
# table config
self.table_config = kwargs.get("table_config")
self.apply_table = self.table_config.get("enable", False)
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
self.table_model_name = self.table_config.get("model", MODEL_NAME.RAPID_TABLE)
self.table_config = kwargs.get('table_config')
self.apply_table = self.table_config.get('enable', False)
self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
# ocr config
self.apply_ocr = ocr
self.lang = kwargs.get("lang", None)
self.lang = kwargs.get('lang', None)
logger.info(
"DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
"apply_table: {}, table_model: {}, lang: {}".format(
self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name,
self.lang
'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
'apply_table: {}, table_model: {}, lang: {}'.format(
self.layout_model_name,
self.apply_formula,
self.apply_ocr,
self.apply_table,
self.table_model_name,
self.lang,
)
)
# 初始化解析方案
self.device = kwargs.get("device", "cpu")
logger.info("using device: {}".format(self.device))
models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
logger.info("using models_dir: {}".format(models_dir))
self.device = kwargs.get('device', 'cpu')
logger.info('using device: {}'.format(self.device))
models_dir = kwargs.get(
'models_dir', os.path.join(root_dir, 'resources', 'models')
)
logger.info('using models_dir: {}'.format(models_dir))
atom_model_manager = AtomModelSingleton()
......@@ -85,18 +100,24 @@ class CustomPEKModel:
# 初始化公式检测模型
self.mfd_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFD,
mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])),
device=self.device
mfd_weights=str(
os.path.join(
models_dir, self.configs['weights'][self.mfd_model_name]
)
),
device=self.device,
)
# 初始化公式解析模型
mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
mfr_weight_dir = str(
os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
)
mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
self.mfr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFR,
mfr_weight_dir=mfr_weight_dir,
mfr_cfg_path=mfr_cfg_path,
device=self.device
device=self.device,
)
# 初始化layout模型
......@@ -104,16 +125,28 @@ class CustomPEKModel:
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
layout_model_name=MODEL_NAME.LAYOUTLMv3,
layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
device=self.device
layout_weights=str(
os.path.join(
models_dir, self.configs['weights'][self.layout_model_name]
)
),
layout_config_file=str(
os.path.join(
model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
)
),
device=self.device,
)
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
layout_model_name=MODEL_NAME.DocLayout_YOLO,
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
device=self.device
doclayout_yolo_weights=str(
os.path.join(
models_dir, self.configs['weights'][self.layout_model_name]
)
),
device=self.device,
)
# 初始化ocr
if self.apply_ocr:
......@@ -121,23 +154,22 @@ class CustomPEKModel:
atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log,
det_db_box_thresh=0.3,
lang=self.lang
lang=self.lang,
)
# init table model
if self.apply_table:
table_model_dir = self.configs["weights"][self.table_model_name]
table_model_dir = self.configs['weights'][self.table_model_name]
self.table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Table,
table_model_name=self.table_model_name,
table_model_path=str(os.path.join(models_dir, table_model_dir)),
table_max_time=self.table_max_time,
device=self.device
device=self.device,
)
logger.info('DocAnalysis init done!')
def __call__(self, image):
page_start = time.time()
# layout检测
......@@ -150,7 +182,7 @@ class CustomPEKModel:
# doclayout_yolo
layout_res = self.layout_model.predict(image)
layout_cost = round(time.time() - layout_start, 2)
logger.info(f"layout detection time: {layout_cost}")
logger.info(f'layout detection time: {layout_cost}')
pil_img = Image.fromarray(image)
......@@ -158,32 +190,40 @@ class CustomPEKModel:
# 公式检测
mfd_start = time.time()
mfd_res = self.mfd_model.predict(image)
logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')
# 公式识别
mfr_start = time.time()
formula_list = self.mfr_model.predict(mfd_res, image)
layout_res.extend(formula_list)
mfr_cost = round(time.time() - mfr_start, 2)
logger.info(f"formula nums: {len(formula_list)}, mfr time: {mfr_cost}")
logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
# 清理显存
clean_vram(self.device, vram_threshold=8)
# 从layout_res中获取ocr区域、表格区域、公式区域
ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res)
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
get_res_list_from_layout_res(layout_res)
)
# ocr识别
if self.apply_ocr:
ocr_start = time.time()
# Process each area that requires OCR processing
for res in ocr_res_list:
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
new_image, useful_list = crop_img(
res, pil_img, crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
single_page_mfdetrec_res, useful_list
)
# OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[
0
]
# Integration results
if ocr_res:
......@@ -191,7 +231,7 @@ class CustomPEKModel:
layout_res.extend(ocr_result_list)
ocr_cost = round(time.time() - ocr_start, 2)
logger.info(f"ocr time: {ocr_cost}")
logger.info(f'ocr time: {ocr_cost}')
# 表格识别 table recognition
if self.apply_table:
......@@ -202,27 +242,37 @@ class CustomPEKModel:
html_code = None
if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
with torch.no_grad():
table_result = self.table_model.predict(new_image, "html")
table_result = self.table_model.predict(new_image, 'html')
if len(table_result) > 0:
html_code = table_result[0]
elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.table_model.img2html(new_image)
elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
html_code, table_cell_bboxes, elapse = self.table_model.predict(new_image)
html_code, table_cell_bboxes, elapse = self.table_model.predict(
new_image
)
run_time = time.time() - single_table_start_time
if run_time > self.table_max_time:
logger.warning(f"table recognition processing exceeds max time {self.table_max_time}s")
logger.warning(
f'table recognition processing exceeds max time {self.table_max_time}s'
)
# 判断是否返回正常
if html_code:
expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
expected_ending = html_code.strip().endswith(
'</html>'
) or html_code.strip().endswith('</table>')
if expected_ending:
res["html"] = html_code
res['html'] = html_code
else:
logger.warning(f"table recognition processing fails, not found expected HTML table end")
logger.warning(
'table recognition processing fails, not found expected HTML table end'
)
else:
logger.warning(f"table recognition processing fails, not get html return")
logger.info(f"table time: {round(time.time() - table_start, 2)}")
logger.warning(
'table recognition processing fails, not get html return'
)
logger.info(f'table time: {round(time.time() - table_start, 2)}')
logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
logger.info(f'-----page total time: {round(time.time() - page_start, 2)}-----')
return layout_res
from loguru import logger
from magic_pdf.libs.Constants import MODEL_NAME
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
Layoutlmv3_Predictor
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
RapidTableModel
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
StructTableModel
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
TableMasterPaddleModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
......@@ -19,14 +24,14 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
elif table_model_type == MODEL_NAME.TABLE_MASTER:
config = {
"model_dir": model_path,
"device": _device_
'model_dir': model_path,
'device': _device_
}
table_model = TableMasterPaddleModel(config)
elif table_model_type == MODEL_NAME.RAPID_TABLE:
table_model = RapidTableModel()
else:
logger.error("table model type not allow")
logger.error('table model type not allow')
exit(1)
return table_model
......@@ -87,8 +92,8 @@ class AtomModelSingleton:
return cls._instance
def get_atom_model(self, atom_model_name: str, **kwargs):
lang = kwargs.get("lang", None)
layout_model_name = kwargs.get("layout_model_name", None)
lang = kwargs.get('lang', None)
layout_model_name = kwargs.get('layout_model_name', None)
key = (atom_model_name, layout_model_name, lang)
if key not in self._models:
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
......@@ -98,47 +103,47 @@ class AtomModelSingleton:
def atom_model_init(model_name: str, **kwargs):
atom_model = None
if model_name == AtomicModel.Layout:
if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
if kwargs.get('layout_model_name') == MODEL_NAME.LAYOUTLMv3:
atom_model = layout_model_init(
kwargs.get("layout_weights"),
kwargs.get("layout_config_file"),
kwargs.get("device")
kwargs.get('layout_weights'),
kwargs.get('layout_config_file'),
kwargs.get('device')
)
elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
elif kwargs.get('layout_model_name') == MODEL_NAME.DocLayout_YOLO:
atom_model = doclayout_yolo_model_init(
kwargs.get("doclayout_yolo_weights"),
kwargs.get("device")
kwargs.get('doclayout_yolo_weights'),
kwargs.get('device')
)
elif model_name == AtomicModel.MFD:
atom_model = mfd_model_init(
kwargs.get("mfd_weights"),
kwargs.get("device")
kwargs.get('mfd_weights'),
kwargs.get('device')
)
elif model_name == AtomicModel.MFR:
atom_model = mfr_model_init(
kwargs.get("mfr_weight_dir"),
kwargs.get("mfr_cfg_path"),
kwargs.get("device")
kwargs.get('mfr_weight_dir'),
kwargs.get('mfr_cfg_path'),
kwargs.get('device')
)
elif model_name == AtomicModel.OCR:
atom_model = ocr_model_init(
kwargs.get("ocr_show_log"),
kwargs.get("det_db_box_thresh"),
kwargs.get("lang")
kwargs.get('ocr_show_log'),
kwargs.get('det_db_box_thresh'),
kwargs.get('lang')
)
elif model_name == AtomicModel.Table:
atom_model = table_model_init(
kwargs.get("table_model_name"),
kwargs.get("table_model_path"),
kwargs.get("table_max_time"),
kwargs.get("device")
kwargs.get('table_model_name'),
kwargs.get('table_model_path'),
kwargs.get('table_max_time'),
kwargs.get('device')
)
else:
logger.error("model name not allow")
logger.error('model name not allow')
exit(1)
if atom_model is None:
logger.error("model init failed")
logger.error('model init failed')
exit(1)
else:
return atom_model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment