Unverified commit 6d571e2e authored by Kaiwen Liu, committed by GitHub

Merge pull request #7 from opendatalab/dev

Dev
parents a3358878 37c335ae
@@ -20,6 +20,8 @@ class BlockType:
    InterlineEquation = 'interline_equation'
    Footnote = 'footnote'
    Discarded = 'discarded'
    List = 'list'
    Index = 'index'


class CategoryId:
...
@@ -4,7 +4,9 @@ import fitz
import numpy as np
from loguru import logger
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config, get_layout_config, \
    get_formula_config
from magic_pdf.model.model_list import MODEL
import magic_pdf.model as model_config
@@ -23,7 +25,7 @@ def remove_duplicates_dicts(lst):
    return unique_dicts


def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
    try:
        from PIL import Image
    except ImportError:
@@ -32,18 +34,28 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
    images = []
    with fitz.open("pdf", pdf_bytes) as doc:
        pdf_page_num = doc.page_count
        end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
        if end_page_id > pdf_page_num - 1:
            logger.warning("end_page_id is out of range, use images length")
            end_page_id = pdf_page_num - 1
        for index in range(0, doc.page_count):
            if start_page_id <= index <= end_page_id:
                page = doc[index]
                mat = fitz.Matrix(dpi / 72, dpi / 72)
                pm = page.get_pixmap(matrix=mat, alpha=False)
                # If the width or height exceeds 9000 after scaling, do not scale further.
                if pm.width > 9000 or pm.height > 9000:
                    pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
                img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
                img = np.array(img)
                img_dict = {"img": img, "width": pm.width, "height": pm.height}
            else:
                img_dict = {"img": [], "width": 0, "height": 0}
            images.append(img_dict)
    return images
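A minimal usage sketch of the new page-range behavior (not part of the diff): pages outside [start_page_id, end_page_id] still occupy a slot in the returned list, but as empty placeholders, so page indices stay aligned with the document. "sample.pdf" is an assumed local file.

# Hypothetical usage, assuming a local sample.pdf:
with open("sample.pdf", "rb") as f:
    pdf_bytes = f.read()

images = load_images_from_pdf(pdf_bytes, dpi=200, start_page_id=2, end_page_id=4)
for page_id, img_dict in enumerate(images):
    rendered = img_dict["width"] > 0  # placeholders keep {"img": [], "width": 0, "height": 0}
    print(page_id, rendered)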
@@ -57,14 +69,17 @@ class ModelSingleton:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(self, ocr: bool, show_log: bool, lang=None, layout_model=None, formula_enable=None, table_enable=None):
        key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
        if key not in self._models:
            self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang, layout_model=layout_model,
                                                  formula_enable=formula_enable, table_enable=table_enable)
        return self._models[key]
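The cache key now includes layout_model, formula_enable, and table_enable, so calls that differ in any of these options initialize and cache separate pipelines. A sketch of the implied behavior, assuming model initialization succeeds:

manager = ModelSingleton()
m1 = manager.get_model(ocr=True, show_log=False, lang='en',
                       layout_model=None, formula_enable=True, table_enable=False)
m2 = manager.get_model(ocr=True, show_log=False, lang='en',
                       layout_model=None, formula_enable=True, table_enable=False)
assert m1 is m2      # identical key -> same cached instance
m3 = manager.get_model(ocr=True, show_log=False, lang='en',
                       layout_model=None, formula_enable=True, table_enable=True)
assert m3 is not m1  # table_enable differs -> a second instance is created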
def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None,
                      layout_model=None, formula_enable=None, table_enable=None):
    model = None
    if model_config.__model_mode__ == "lite":
@@ -84,14 +99,30 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
        # Read model-dir and device from the config file
        local_models_dir = get_local_models_dir()
        device = get_device()

        layout_config = get_layout_config()
        if layout_model is not None:
            layout_config["model"] = layout_model

        formula_config = get_formula_config()
        if formula_enable is not None:
            formula_config["enable"] = formula_enable

        table_config = get_table_recog_config()
        if table_enable is not None:
            table_config["enable"] = table_enable

        model_input = {
            "ocr": ocr,
            "show_log": show_log,
            "models_dir": local_models_dir,
            "device": device,
            "table_config": table_config,
            "layout_config": layout_config,
            "formula_config": formula_config,
            "lang": lang,
        }
        custom_model = CustomPEKModel(**model_input)
    else:
        logger.error("Not allow model_name!")
@@ -106,19 +137,23 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):


def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
                start_page_id=0, end_page_id=None, lang=None,
                layout_model=None, formula_enable=None, table_enable=None):
    if lang == "":
        lang = None

    model_manager = ModelSingleton()
    custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)

    with fitz.open("pdf", pdf_bytes) as doc:
        pdf_page_num = doc.page_count
        end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
        if end_page_id > pdf_page_num - 1:
            logger.warning("end_page_id is out of range, use images length")
            end_page_id = pdf_page_num - 1

    images = load_images_from_pdf(pdf_bytes, start_page_id=start_page_id, end_page_id=end_page_id)

    model_json = []
    doc_analyze_start = time.time()
@@ -135,6 +170,11 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
        page_dict = {"layout_dets": result, "page_info": page_info}
        model_json.append(page_dict)

    gc_start = time.time()
    clean_memory()
    gc_time = round(time.time() - gc_start, 2)
    logger.info(f"gc time: {gc_time}")

    doc_analyze_time = round(time.time() - doc_analyze_start, 2)
    doc_analyze_speed = round((end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
    logger.info(f"doc analyze time: {round(time.time() - doc_analyze_start, 2)},"
...
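A hedged call sketch for the extended doc_analyze signature above ("demo.pdf" is an assumed file): an empty lang is normalized to None, an out-of-range end_page_id is clamped to the last page, and the three new options are forwarded into the model cache key.

with open("demo.pdf", "rb") as f:
    pdf_bytes = f.read()

model_json = doc_analyze(
    pdf_bytes,
    ocr=True,
    start_page_id=0,
    end_page_id=9,        # clamped to pdf_page_num - 1 if out of range
    lang="",              # normalized to None before the model lookup
    layout_model=None,    # None keeps the configured default
    formula_enable=True,
    table_enable=False,
)
print(len(model_json))    # one page_dict per analyzed page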
import json

from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
                                    bbox_relative_pos, box_area, calculate_iou,
                                    calculate_overlap_area_in_bbox1_area_ratio,
@@ -9,6 +10,7 @@ from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.local_math import float_gt
from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
@@ -24,7 +26,7 @@ class MagicModel:
        need_remove_list = []
        page_no = model_page_info['page_info']['page_no']
        horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
            model_page_info, self.__docs.get_page(page_no)
        )
        layout_dets = model_page_info['layout_dets']
        for layout_det in layout_dets:
@@ -99,7 +101,7 @@ class MagicModel:
        for need_remove in need_remove_list:
            layout_dets.remove(need_remove)

    def __init__(self, model_list: list, docs: Dataset):
        self.__model_list = model_list
        self.__docs = docs
        """Add bbox info to all model data (scaling, poly -> bbox)."""
@@ -119,15 +121,13 @@ class MagicModel:
        if left or right:
            l1 = bbox1[3] - bbox1[1]
            l2 = bbox2[3] - bbox2[1]
        else:
            l1 = bbox1[2] - bbox1[0]
            l2 = bbox2[2] - bbox2[0]

        if l2 > l1 and (l2 - l1) / l1 > 0.3:
            return float('inf')

        return bbox_distance(bbox1, bbox2)
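The rewritten check replaces the old symmetric 50% length-ratio filter with a one-sided 30% rule: a pair is discarded only when the second box is more than 30% longer than the first along the measured axis. Worked numbers (plain arithmetic, nothing project-specific):

l1, l2 = 100.0, 140.0                    # candidate edge lengths
print(l2 > l1 and (l2 - l1) / l1 > 0.3)  # True: 40% longer -> distance becomes inf
l1, l2 = 140.0, 100.0
print(l2 > l1 and (l2 - l1) / l1 > 0.3)  # False: a shorter second box is never rejected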
    def __fix_footnote(self):
@@ -215,9 +215,8 @@ class MagicModel:
        Filter out all subjects whose overlap with the merged bbox is larger than the object's own area,
        then compute the shortest distance between those filtered subjects and the object.
        """
        def search_overlap_between_boxes(subject_idx, object_idx):
            idxes = [subject_idx, object_idx]
            x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
            y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
@@ -245,9 +244,9 @@ class MagicModel:
            for other_object in other_objects:
                ratio = max(
                    ratio,
                    get_overlap_area(merged_bbox, other_object['bbox'])
                    * 1.0
                    / box_area(all_bboxes[object_idx]['bbox']),
                )
                if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
                    break
@@ -365,12 +364,17 @@ class MagicModel:
                if all_bboxes[j]['category_id'] == subject_category_id:
                    subject_idx, object_idx = j, i

                    if (
                        search_overlap_between_boxes(subject_idx, object_idx)
                        >= MERGE_BOX_OVERLAP_AREA_RATIO
                    ):
                        dis[i][j] = float('inf')
                        dis[j][i] = dis[i][j]
                        continue

                dis[i][j] = self._bbox_distance(
                    all_bboxes[subject_idx]['bbox'], all_bboxes[object_idx]['bbox']
                )
                dis[j][i] = dis[i][j]

        used = set()
@@ -461,7 +465,7 @@ class MagicModel:
                if is_nearest:
                    nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k])
                    n_dis = bbox_distance(
                        all_bboxes[i]['bbox'], [nx0, ny0, nx1, ny1]
                    )
                    if float_gt(dis[i][j], n_dis):
@@ -557,7 +561,7 @@ class MagicModel:
        # Compute the distance for pairs that are already matched
        for i in subject_object_relation_map.keys():
            for j in subject_object_relation_map[i]:
                total_subject_object_dis += bbox_distance(
                    all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
                )
@@ -586,6 +590,245 @@ class MagicModel:
                    with_caption_subject.add(j)
        return ret, total_subject_object_dis
def __tie_up_category_by_distance_v2(
self, page_no, subject_category_id, object_category_id
):
AXIS_MULPLICITY = 0.5
subjects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == subject_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
objects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == object_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
M = len(objects)
subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
sub_obj_map_h = {i: [] for i in range(len(subjects))}
dis_by_directions = {
'top': [[-1, float('inf')]] * M,
'bottom': [[-1, float('inf')]] * M,
'left': [[-1, float('inf')]] * M,
'right': [[-1, float('inf')]] * M,
}
for i, obj in enumerate(objects):
l_x_axis, l_y_axis = (
obj['bbox'][2] - obj['bbox'][0],
obj['bbox'][3] - obj['bbox'][1],
)
axis_unit = min(l_x_axis, l_y_axis)
for j, sub in enumerate(subjects):
bbox1, bbox2, _ = _remove_overlap_between_bbox(
objects[i]['bbox'], subjects[j]['bbox']
)
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
flags = [left, right, bottom, top]
if sum([1 if v else 0 for v in flags]) > 1:
continue
if left:
if dis_by_directions['left'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['left'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if right:
if dis_by_directions['right'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['right'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if bottom:
if dis_by_directions['bottom'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['bottom'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if top:
if dis_by_directions['top'][i][1] > bbox_distance(
obj['bbox'], sub['bbox']
):
dis_by_directions['top'][i] = [
j,
bbox_distance(obj['bbox'], sub['bbox']),
]
if dis_by_directions['left'][i][1] != float('inf') or dis_by_directions[
'right'
][i][1] != float('inf'):
if dis_by_directions['left'][i][1] != float(
'inf'
) and dis_by_directions['right'][i][1] != float('inf'):
if AXIS_MULPLICITY * axis_unit >= abs(
dis_by_directions['left'][i][1]
- dis_by_directions['right'][i][1]
):
left_sub_bbox = subjects[dis_by_directions['left'][i][0]][
'bbox'
]
right_sub_bbox = subjects[dis_by_directions['right'][i][0]][
'bbox'
]
left_sub_bbox_y_axis = left_sub_bbox[3] - left_sub_bbox[1]
right_sub_bbox_y_axis = right_sub_bbox[3] - right_sub_bbox[1]
if (
abs(left_sub_bbox_y_axis - l_y_axis)
+ dis_by_directions['left'][i][0]
> abs(right_sub_bbox_y_axis - l_y_axis)
+ dis_by_directions['right'][i][0]
):
left_or_right = dis_by_directions['right'][i]
else:
left_or_right = dis_by_directions['left'][i]
else:
left_or_right = dis_by_directions['left'][i]
if left_or_right[1] > dis_by_directions['right'][i][1]:
left_or_right = dis_by_directions['right'][i]
else:
left_or_right = dis_by_directions['left'][i]
if left_or_right[1] == float('inf'):
left_or_right = dis_by_directions['right'][i]
else:
left_or_right = [-1, float('inf')]
if dis_by_directions['top'][i][1] != float('inf') or dis_by_directions[
'bottom'
][i][1] != float('inf'):
if dis_by_directions['top'][i][1] != float('inf') and dis_by_directions[
'bottom'
][i][1] != float('inf'):
if AXIS_MULPLICITY * axis_unit >= abs(
dis_by_directions['top'][i][1]
- dis_by_directions['bottom'][i][1]
):
top_bottom = subjects[dis_by_directions['bottom'][i][0]]['bbox']
bottom_top = subjects[dis_by_directions['top'][i][0]]['bbox']
top_bottom_x_axis = top_bottom[2] - top_bottom[0]
bottom_top_x_axis = bottom_top[2] - bottom_top[0]
if abs(top_bottom_x_axis - l_x_axis) + dis_by_directions['bottom'][i][1] > abs(
bottom_top_x_axis - l_x_axis
) + dis_by_directions['top'][i][1]:
top_or_bottom = dis_by_directions['top'][i]
else:
top_or_bottom = dis_by_directions['bottom'][i]
else:
top_or_bottom = dis_by_directions['top'][i]
if top_or_bottom[1] > dis_by_directions['bottom'][i][1]:
top_or_bottom = dis_by_directions['bottom'][i]
else:
top_or_bottom = dis_by_directions['top'][i]
if top_or_bottom[1] == float('inf'):
top_or_bottom = dis_by_directions['bottom'][i]
else:
top_or_bottom = [-1, float('inf')]
if left_or_right[1] != float('inf') or top_or_bottom[1] != float('inf'):
if left_or_right[1] != float('inf') and top_or_bottom[1] != float(
'inf'
):
if AXIS_MULPLICITY * axis_unit >= abs(
left_or_right[1] - top_or_bottom[1]
):
y_axis_bbox = subjects[left_or_right[0]]['bbox']
x_axis_bbox = subjects[top_or_bottom[0]]['bbox']
if (
abs((x_axis_bbox[2] - x_axis_bbox[0]) - l_x_axis) / l_x_axis
> abs((y_axis_bbox[3] - y_axis_bbox[1]) - l_y_axis)
/ l_y_axis
):
sub_obj_map_h[left_or_right[0]].append(i)
else:
sub_obj_map_h[top_or_bottom[0]].append(i)
else:
if left_or_right[1] > top_or_bottom[1]:
sub_obj_map_h[top_or_bottom[0]].append(i)
else:
sub_obj_map_h[left_or_right[0]].append(i)
else:
if left_or_right[1] != float('inf'):
sub_obj_map_h[left_or_right[0]].append(i)
else:
sub_obj_map_h[top_or_bottom[0]].append(i)
ret = []
for i in sub_obj_map_h.keys():
ret.append(
{
'sub_bbox': {
'bbox': subjects[i]['bbox'],
'score': subjects[i]['score'],
},
'obj_bboxes': [
{'score': objects[j]['score'], 'bbox': objects[j]['bbox']}
for j in sub_obj_map_h[i]
],
'sub_idx': i,
}
)
return ret
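In the v2 pairing above, AXIS_MULPLICITY (0.5; the spelling follows the source) scales the object's shorter edge to decide when two candidate directions count as a tie; on a tie, the subject whose shape better matches the object wins rather than the raw nearer one. A worked tie-break, arithmetic only:

axis_unit = min(80, 20)                                # object is 80 wide and 20 tall
left_dis, right_dis = 14.0, 22.0                       # best subject distance on each side
is_tie = 0.5 * axis_unit >= abs(left_dis - right_dis)  # 10 >= 8
print(is_tie)  # True: shape similarity, not raw distance, breaks the tie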
def get_imgs_v2(self, page_no: int):
with_captions = self.__tie_up_category_by_distance_v2(page_no, 3, 4)
with_footnotes = self.__tie_up_category_by_distance_v2(
page_no, 3, CategoryId.ImageFootnote
)
ret = []
for v in with_captions:
record = {
'image_body': v['sub_bbox'],
'image_caption_list': v['obj_bboxes'],
}
filter_idx = v['sub_idx']
d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
record['image_footnote_list'] = d['obj_bboxes']
ret.append(record)
return ret
def get_tables_v2(self, page_no: int) -> list:
with_captions = self.__tie_up_category_by_distance_v2(page_no, 5, 6)
with_footnotes = self.__tie_up_category_by_distance_v2(page_no, 5, 7)
ret = []
for v in with_captions:
record = {
'table_body': v['sub_bbox'],
'table_caption_list': v['obj_bboxes'],
}
filter_idx = v['sub_idx']
d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
record['table_footnote_list'] = d['obj_bboxes']
ret.append(record)
return ret
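A sketch of how the v2 getters might be consumed, assuming magic_model is a MagicModel built from a model_list and a Dataset; the category ids follow the calls above (3/4 image body/caption, 5/6/7 table body/caption/footnote):

for page_no in range(2):  # a hypothetical two-page document
    for img in magic_model.get_imgs_v2(page_no):
        print(img['image_body']['bbox'],
              len(img['image_caption_list']),
              len(img['image_footnote_list']))
    for tbl in magic_model.get_tables_v2(page_no):
        print(tbl['table_body']['bbox'],
              len(tbl['table_caption_list']),
              len(tbl['table_footnote_list']))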
    def get_imgs(self, page_no: int):
        with_captions, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
        with_footnotes, _ = self.__tie_up_category_by_distance(
@@ -719,10 +962,10 @@ class MagicModel:
    def get_page_size(self, page_no: int):  # Get page width and height
        # Get the page object of the current page
        page = self.__docs.get_page(page_no).get_page_info()
        # Get the width and height of the current page
        page_w = page.w
        page_h = page.h
        return page_w, page_h

    def __get_blocks_by_type(
...
@@ -26,6 +26,7 @@ try:
    from unimernet.common.config import Config
    import unimernet.tasks as tasks
    from unimernet.processors import load_processor
    from doclayout_yolo import YOLOv10
except ImportError as e:
    logger.exception(e)
@@ -42,7 +43,7 @@ from magic_pdf.model.ppTableModel import ppTableModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
        table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
    else:
        config = {
@@ -83,11 +84,16 @@ def layout_model_init(weight, config_file, device):
    return model


def doclayout_yolo_model_init(weight):
    model = YOLOv10(weight)
    return model


def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
    if lang is not None:
        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
    else:
        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
    return model
@@ -120,19 +126,27 @@ class AtomModelSingleton:
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        lang = kwargs.get("lang", None)
        layout_model_name = kwargs.get("layout_model_name", None)
        key = (atom_model_name, layout_model_name, lang)
        if key not in self._models:
            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
        return self._models[key]
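Keying the cache on (atom_model_name, layout_model_name, lang) lets one process hold several variants of the same atomic model at once, for example both layout backends. A sketch assuming the weight files exist (paths are illustrative, not from the diff):

atom_manager = AtomModelSingleton()
layout_v3 = atom_manager.get_atom_model(
    atom_model_name=AtomicModel.Layout,
    layout_model_name=MODEL_NAME.LAYOUTLMv3,
    layout_weights="weights/layoutlmv3",                  # illustrative path
    layout_config_file="layoutlmv3_base_inference.yaml",  # illustrative path
    device="cpu",
)
layout_yolo = atom_manager.get_atom_model(
    atom_model_name=AtomicModel.Layout,
    layout_model_name=MODEL_NAME.DocLayout_YOLO,
    doclayout_yolo_weights="weights/doclayout_yolo.pt",   # illustrative path
)
assert layout_v3 is not layout_yolo  # different layout_model_name -> different cache slot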
def atom_model_init(model_name: str, **kwargs):
    if model_name == AtomicModel.Layout:
        if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
            atom_model = layout_model_init(
                kwargs.get("layout_weights"),
                kwargs.get("layout_config_file"),
                kwargs.get("device")
            )
        elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
            atom_model = doclayout_yolo_model_init(
                kwargs.get("doclayout_yolo_weights"),
            )
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get("mfd_weights")
@@ -151,7 +165,7 @@ def atom_model_init(model_name: str, **kwargs):
        )
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get("table_model_name"),
            kwargs.get("table_model_path"),
            kwargs.get("table_max_time"),
            kwargs.get("device")
@@ -199,23 +213,35 @@ class CustomPEKModel:
        with open(config_path, "r", encoding='utf-8') as f:
            self.configs = yaml.load(f, Loader=yaml.FullLoader)

        # Initialize the parsing configuration

        # layout config
        self.layout_config = kwargs.get("layout_config")
        self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO)

        # formula config
        self.formula_config = kwargs.get("formula_config")
        self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD)
        self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small)
        self.apply_formula = self.formula_config.get("enable", True)

        # table config
        self.table_config = kwargs.get("table_config")
        self.apply_table = self.table_config.get("enable", False)
        self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
        self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)

        # ocr config
        self.apply_ocr = ocr
        self.lang = kwargs.get("lang", None)
        logger.info(
            "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
            "apply_table: {}, table_model: {}, lang: {}".format(
                self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
            )
        )
        # Initialize the parsing scheme
        self.device = kwargs.get("device", "cpu")
        logger.info("using device: {}".format(self.device))
        models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
        logger.info("using models_dir: {}".format(models_dir))
@@ -224,17 +250,16 @@ class CustomPEKModel:
        # Initialize formula recognition
        if self.apply_formula:
            # Initialize the formula detection model
            self.mfd_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFD,
                mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
            )

            # Initialize the formula parsing model
            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
            mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))

            self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFR,
                mfr_weight_dir=mfr_weight_dir,
@@ -243,17 +268,20 @@ class CustomPEKModel:
            )

        # Initialize the layout model
        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            self.layout_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                layout_model_name=MODEL_NAME.LAYOUTLMv3,
                layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
                layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
                device=self.device
            )
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            self.layout_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                layout_model_name=MODEL_NAME.DocLayout_YOLO,
                doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
            )
        # Initialize OCR
        if self.apply_ocr:
@@ -266,12 +294,10 @@ class CustomPEKModel:
            )

        # init table model
        if self.apply_table:
            table_model_dir = self.configs["weights"][self.table_model_name]
            self.table_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Table,
                table_model_name=self.table_model_name,
                table_model_path=str(os.path.join(models_dir, table_model_dir)),
                table_max_time=self.table_max_time,
                device=self.device
@@ -288,7 +314,21 @@ class CustomPEKModel:
        # Layout detection
        layout_start = time.time()
        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            # layoutlmv3
            layout_res = self.layout_model(image, ignore_catids=[])
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            # doclayout_yolo
            layout_res = []
            doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
            for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
                new_item = {
                    'category_id': int(cla.item()),
                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                    'score': round(float(conf.item()), 3),
                }
                layout_res.append(new_item)
        layout_cost = round(time.time() - layout_start, 2)
        logger.info(f"layout detection time: {layout_cost}")
@@ -297,7 +337,7 @@ class CustomPEKModel:
        if self.apply_formula:
            # Formula detection
            mfd_start = time.time()
            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
            logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
            for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
@@ -309,7 +349,6 @@ class CustomPEKModel:
                }
                layout_res.append(new_item)
                latex_filling_list.append(new_item)
                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
                mf_image_list.append(bbox_img)
@@ -346,7 +385,7 @@ class CustomPEKModel:
        if torch.cuda.is_available():
            properties = torch.cuda.get_device_properties(self.device)
            total_memory = properties.total_memory / (1024 ** 3)  # convert bytes to GB
            if total_memory <= 10:
                gc_start = time.time()
                clean_memory()
                gc_time = round(time.time() - gc_start, 2)
@@ -411,7 +450,7 @@ class CustomPEKModel:
                # logger.info("------------------table recognition processing begins-----------------")
                latex_code = None
                html_code = None
                if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                    with torch.no_grad():
                        latex_code = self.table_model.image2latex(new_image)[0]
                else:
...
@@ -52,11 +52,11 @@ class ppTableModel(object):
        rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)
        rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)
        device = kwargs.get("device", "cpu")
        use_gpu = True if device.startswith("cuda") else False
        config = {
            "use_gpu": use_gpu,
            "table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
            "table_algorithm": "TableMaster",
            "table_model_dir": table_model_dir,
            "table_char_dict_path": table_char_dict_path,
            "det_model_dir": det_model_dir,
...
import copy
from loguru import logger
from magic_pdf.libs.Constants import LINES_DELETED, CROSS_PAGE
from magic_pdf.libs.ocr_content_type import BlockType, ContentType
LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';')
LIST_END_FLAG = ('.', '。', ';', ';')
class ListLineTag:
IS_LIST_START_LINE = "is_list_start_line"
IS_LIST_END_LINE = "is_list_end_line"
def __process_blocks(blocks):
    # Preprocess all blocks:
    # 1. Group blocks by title and interline_equation
    # 2. Reset bbox boundaries based on line info
result = []
current_group = []
for i in range(len(blocks)):
current_block = blocks[i]
        # If the current block is a text block
if current_block['type'] == 'text':
current_block["bbox_fs"] = copy.deepcopy(current_block["bbox"])
if 'lines' in current_block and len(current_block["lines"]) > 0:
current_block['bbox_fs'] = [min([line['bbox'][0] for line in current_block['lines']]),
min([line['bbox'][1] for line in current_block['lines']]),
max([line['bbox'][2] for line in current_block['lines']]),
max([line['bbox'][3] for line in current_block['lines']])]
current_group.append(current_block)
        # Check whether the next block exists
if i + 1 < len(blocks):
next_block = blocks[i + 1]
            # If the next block is not a text block but a title or interline_equation block
if next_block['type'] in ['title', 'interline_equation']:
result.append(current_group)
current_group = []
    # Handle the last group
if current_group:
result.append(current_group)
return result
def __is_list_or_index_block(block):
    # A block is a list block if it satisfies all of one of the following feature sets:
    # 1. multiple lines 2. multiple lines flush with the left edge 3. multiple lines not flush on the right (sawtooth shape)
    # 1. multiple lines 2. multiple lines flush with the left edge 3. multiple lines ending with an end flag
    # 1. multiple lines 2. multiple lines flush with the left edge 3. multiple lines not flush on the left
    # An index block is a special kind of list block.
    # A block is an index block if it satisfies all of the following:
    # 1. multiple lines 2. multiple lines flush with both edges 3. lines start or end with a digit
if len(block['lines']) >= 2:
first_line = block['lines'][0]
line_height = first_line['bbox'][3] - first_line['bbox'][1]
block_weight = block['bbox_fs'][2] - block['bbox_fs'][0]
left_close_num = 0
left_not_close_num = 0
right_not_close_num = 0
right_close_num = 0
lines_text_list = []
multiple_para_flag = False
last_line = block['lines'][-1]
        # If the first line is not flush left but flush right, and the last line is flush left
        # but not flush right (the first line may be allowed to be not flush right)
if (first_line['bbox'][0] - block['bbox_fs'][0] > line_height / 2 and
# block['bbox_fs'][2] - first_line['bbox'][2] < line_height and
abs(last_line['bbox'][0] - block['bbox_fs'][0]) < line_height / 2 and
block['bbox_fs'][2] - last_line['bbox'][2] > line_height
):
multiple_para_flag = True
for line in block['lines']:
line_text = ""
for span in line['spans']:
span_type = span['type']
if span_type == ContentType.Text:
line_text += span['content'].strip()
lines_text_list.append(line_text)
            # Count lines flush with the left edge; flush means abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height/2
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
left_close_num += 1
elif line['bbox'][0] - block['bbox_fs'][0] > line_height:
# logger.info(f"{line_text}, {block['bbox_fs']}, {line['bbox']}")
left_not_close_num += 1
            # Check whether the right side is flush
if abs(block['bbox_fs'][2] - line['bbox'][2]) < line_height:
right_close_num += 1
else:
                # When the right side is not flush, check whether the gap is sizable;
                # 0.3 of the block width is used as a rule-of-thumb threshold
                closed_area = 0.3 * block_weight
                # closed_area = 5 * line_height
if block['bbox_fs'][2] - line['bbox'][2] > closed_area:
right_not_close_num += 1
        # Check whether more than 80% of the lines end with LIST_END_FLAG
line_end_flag = False
        # Check whether more than 80% of the lines start with a digit or end with a digit
line_num_flag = False
num_start_count = 0
num_end_count = 0
flag_end_count = 0
if len(lines_text_list) > 0:
for line_text in lines_text_list:
if len(line_text) > 0:
if line_text[-1] in LIST_END_FLAG:
flag_end_count += 1
if line_text[0].isdigit():
num_start_count += 1
if line_text[-1].isdigit():
num_end_count += 1
if flag_end_count / len(lines_text_list) >= 0.8:
line_end_flag = True
if num_start_count / len(lines_text_list) >= 0.8 or num_end_count / len(lines_text_list) >= 0.8:
line_num_flag = True
        # Some tables of contents are not flush on the right; for now a block counts as an index
        # if either the left or the right side is fully flush and the digit rule holds
if ((left_close_num/len(block['lines']) >= 0.8 or right_close_num/len(block['lines']) >= 0.8)
and line_num_flag
):
for line in block['lines']:
line[ListLineTag.IS_LIST_START_LINE] = True
return BlockType.Index
elif left_close_num >= 2 and (
right_not_close_num >= 2 or line_end_flag or left_not_close_num >= 2) and not multiple_para_flag:
            # Handle a special un-indented list where every line is flush left; use the right-side gap to detect item ends
if left_close_num / len(block['lines']) > 0.9:
                # A list of short one-line items that are all flush left
if flag_end_count == 0 and right_close_num / len(block['lines']) < 0.5:
for line in block['lines']:
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
line[ListLineTag.IS_LIST_START_LINE] = True
                # Most line items end with an end flag; split items by the end flag
elif line_end_flag:
for i, line in enumerate(block['lines']):
if lines_text_list[i][-1] in LIST_END_FLAG:
line[ListLineTag.IS_LIST_END_LINE] = True
if i + 1 < len(block['lines']):
block['lines'][i+1][ListLineTag.IS_LIST_START_LINE] = True
                # Items have almost no end flags and no indentation; use the right-side gap to decide which lines are item ends
else:
line_start_flag = False
for i, line in enumerate(block['lines']):
if line_start_flag:
line[ListLineTag.IS_LIST_START_LINE] = True
line_start_flag = False
elif abs(block['bbox_fs'][2] - line['bbox'][2]) > line_height:
line[ListLineTag.IS_LIST_END_LINE] = True
line_start_flag = True
            # A special indented ordered list: start lines are not flush left and begin with a digit,
            # and end lines carry IS_LIST_END_LINE, matching the start lines in number
            elif num_start_count >= 2 and num_start_count == flag_end_count:  # keep it simple: ignore the not-flush-left case for now
for i, line in enumerate(block['lines']):
if lines_text_list[i][0].isdigit():
line[ListLineTag.IS_LIST_START_LINE] = True
if lines_text_list[i][-1] in LIST_END_FLAG:
line[ListLineTag.IS_LIST_END_LINE] = True
else:
                # Normal handling for indented lists
for line in block['lines']:
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
line[ListLineTag.IS_LIST_START_LINE] = True
if abs(block['bbox_fs'][2] - line['bbox'][2]) > line_height:
line[ListLineTag.IS_LIST_END_LINE] = True
return BlockType.List
else:
return BlockType.Text
else:
return BlockType.Text
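A synthetic example of the classifier above (the field layout mimics the real block dicts; all coordinates are made up): three lines flush left, ragged on the right, each ending with a LIST_END_FLAG character, which lands in the List branch.

block = {
    'bbox_fs': [50, 100, 400, 160],
    'lines': [
        {'bbox': [50, 100, 395, 120], 'spans': [{'type': ContentType.Text, 'content': 'first item;'}]},
        {'bbox': [50, 120, 180, 140], 'spans': [{'type': ContentType.Text, 'content': 'second item;'}]},
        {'bbox': [50, 140, 200, 160], 'spans': [{'type': ContentType.Text, 'content': 'third item.'}]},
    ],
}
# All lines are flush left, two of three leave a large right-side gap, and
# 100% end with ';' or '.' -> expected result: BlockType.List
print(__is_list_or_index_block(block))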
def __merge_2_text_blocks(block1, block2):
if len(block1['lines']) > 0:
first_line = block1['lines'][0]
line_height = first_line['bbox'][3] - first_line['bbox'][1]
block1_weight = block1['bbox'][2] - block1['bbox'][0]
block2_weight = block2['bbox'][2] - block2['bbox'][0]
min_block_weight = min(block1_weight, block2_weight)
if abs(block1['bbox_fs'][0] - first_line['bbox'][0]) < line_height / 2:
last_line = block2['lines'][-1]
if len(last_line['spans']) > 0:
last_span = last_line['spans'][-1]
line_height = last_line['bbox'][3] - last_line['bbox'][1]
if (abs(block2['bbox_fs'][2] - last_line['bbox'][2]) < line_height and
not last_span['content'].endswith(LINE_STOP_FLAG) and
                    # Do not merge if the two block widths differ by more than a factor of 2
abs(block1_weight - block2_weight) < min_block_weight
):
if block1['page_num'] != block2['page_num']:
for line in block1['lines']:
for span in line['spans']:
span[CROSS_PAGE] = True
block2['lines'].extend(block1['lines'])
block1['lines'] = []
block1[LINES_DELETED] = True
return block1, block2
def __merge_2_list_blocks(block1, block2):
if block1['page_num'] != block2['page_num']:
for line in block1['lines']:
for span in line['spans']:
span[CROSS_PAGE] = True
block2['lines'].extend(block1['lines'])
block1['lines'] = []
block1[LINES_DELETED] = True
return block1, block2
def __is_list_group(text_blocks_group):
    # A list group is one where every block in the group satisfies:
    # 1. each block has no more than 3 lines 2. the blocks' left edges are close to each other (rule skipped for now to keep the logic simple)
for block in text_blocks_group:
if len(block['lines']) > 3:
return False
return True
def __para_merge_page(blocks):
page_text_blocks_groups = __process_blocks(blocks)
for text_blocks_group in page_text_blocks_groups:
if len(text_blocks_group) > 0:
            # Before merging, first classify every block as a list or index block
for block in text_blocks_group:
block_type = __is_list_or_index_block(block)
block['type'] = block_type
# logger.info(f"{block['type']}:{block}")
if len(text_blocks_group) > 1:
            # Before merging, decide whether this group is a list group
is_list_group = __is_list_group(text_blocks_group)
            # Iterate in reverse order
for i in range(len(text_blocks_group) - 1, -1, -1):
current_block = text_blocks_group[i]
                # Check whether a previous block exists
if i - 1 >= 0:
prev_block = text_blocks_group[i - 1]
if current_block['type'] == 'text' and prev_block['type'] == 'text' and not is_list_group:
__merge_2_text_blocks(current_block, prev_block)
elif (
(current_block['type'] == BlockType.List and prev_block['type'] == BlockType.List) or
(current_block['type'] == BlockType.Index and prev_block['type'] == BlockType.Index)
):
__merge_2_list_blocks(current_block, prev_block)
else:
continue
def para_split(pdf_info_dict, debug_mode=False):
all_blocks = []
for page_num, page in pdf_info_dict.items():
blocks = copy.deepcopy(page['preproc_blocks'])
for block in blocks:
block['page_num'] = page_num
all_blocks.extend(blocks)
__para_merge_page(all_blocks)
for page_num, page in pdf_info_dict.items():
page['para_blocks'] = []
for block in all_blocks:
if block['page_num'] == page_num:
page['para_blocks'].append(block)
if __name__ == '__main__':
input_blocks = []
    # Call the function
groups = __process_blocks(input_blocks)
for group_index, group in enumerate(groups):
print(f"Group {group_index}: {group}")
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
@@ -8,10 +10,11 @@ def parse_pdf_by_ocr(pdf_bytes,
                     end_page_id=None,
                     debug_mode=False,
                     ):
    dataset = PymuDocDataset(pdf_bytes)
    return pdf_parse_union(dataset,
                           model_list,
                           imageWriter,
                           SupportedPdfParseMethod.OCR,
                           start_page_id=start_page_id,
                           end_page_id=end_page_id,
                           debug_mode=debug_mode,
...
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
@@ -9,10 +11,11 @@ def parse_pdf_by_txt(
        end_page_id=None,
        debug_mode=False,
):
    dataset = PymuDocDataset(pdf_bytes)
    return pdf_parse_union(dataset,
                           model_list,
                           imageWriter,
                           SupportedPdfParseMethod.TXT,
                           start_page_id=start_page_id,
                           end_page_id=end_page_id,
                           debug_mode=debug_mode,
...
import copy
import os
import statistics
import time
from typing import List

import torch
from loguru import logger

from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import Dataset, PageableData
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.commons import fitz, get_delta_time
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.local_math import float_equal
from magic_pdf.libs.ocr_content_type import ContentType, BlockType
from magic_pdf.model.magic_model import MagicModel
from magic_pdf.para.para_split_v3 import para_split
from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
from magic_pdf.pre_proc.construct_page_dict import \
    ocr_construct_page_component_v2
from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
from magic_pdf.pre_proc.equations_replace import (
    combine_chars_to_pymudict, remove_chars_in_text_blocks,
    replace_equations_in_textblock)
from magic_pdf.pre_proc.ocr_detect_all_bboxes import \
    ocr_prepare_bboxes_for_layout_split_v2
from magic_pdf.pre_proc.ocr_dict_merge import (fill_spans_in_blocks,
                                               fix_block_spans,
                                               fix_discarded_block, fix_block_spans_v2)
from magic_pdf.pre_proc.ocr_span_list_modify import (
    get_qa_need_list_v2, remove_overlaps_low_confidence_spans,
    remove_overlaps_min_spans)
from magic_pdf.pre_proc.resolve_bbox_conflict import \
    check_useful_block_horizontal_overlap
def remove_horizontal_overlap_block_which_smaller(all_bboxes):
    useful_blocks = []
    for bbox in all_bboxes:
        useful_blocks.append({'bbox': bbox[:4]})
    is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = (
        check_useful_block_horizontal_overlap(useful_blocks)
    )
    if is_useful_block_horz_overlap:
        logger.warning(
            f'skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}'
        )  # noqa: E501
        for bbox in all_bboxes.copy():
            if smaller_bbox == bbox[:4]:
                all_bboxes.remove(bbox)
@@ -44,27 +56,27 @@ def remove_horizontal_overlap_block_which_smaller(all_bboxes):
    return is_useful_block_horz_overlap, all_bboxes
def __replace_STX_ETX(text_str: str):
    """Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
    Drawback: This issue is only observed in English text; it has not been found in Chinese text so far.

    Args:
        text_str (str): raw text

    Returns:
        _type_: replaced text
    """  # noqa: E501
    if text_str:
        s = text_str.replace('\u0002', "'")
        s = s.replace('\u0003', "'")
        return s
    return text_str
def txt_spans_extract(pdf_page, inline_equations, interline_equations):
    text_raw_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
    char_level_text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)[
        'blocks'
    ]
    text_blocks = combine_chars_to_pymudict(text_raw_blocks, char_level_text_blocks)
    text_blocks = replace_equations_in_textblock(
@@ -74,50 +86,63 @@ def txt_spans_extract(pdf_page, inline_equations, interline_equations):
    text_blocks = remove_chars_in_text_blocks(text_blocks)
    spans = []
    for v in text_blocks:
        for line in v['lines']:
            for span in line['spans']:
                bbox = span['bbox']
                if float_equal(bbox[0], bbox[2]) or float_equal(bbox[1], bbox[3]):
                    continue
                if span.get('type') not in (
                    ContentType.InlineEquation,
                    ContentType.InterlineEquation,
                ):
                    spans.append(
                        {
                            'bbox': list(span['bbox']),
                            'content': __replace_STX_ETX(span['text']),
                            'type': ContentType.Text,
                            'score': 1.0,
                        }
                    )
    return spans
def replace_text_span(pymu_spans, ocr_spans):
    return list(filter(lambda x: x['type'] != ContentType.Text, ocr_spans)) + pymu_spans
def model_init(model_name: str):
    from transformers import LayoutLMv3ForTokenClassification

    if torch.cuda.is_available():
        device = torch.device('cuda')
        if torch.cuda.is_bf16_supported():
            supports_bfloat16 = True
        else:
            supports_bfloat16 = False
    else:
        device = torch.device('cpu')
        supports_bfloat16 = False

    if model_name == 'layoutreader':
        # Check whether the modelscope cache directory exists
        layoutreader_model_dir = get_local_layoutreader_model_dir()
        if os.path.exists(layoutreader_model_dir):
            model = LayoutLMv3ForTokenClassification.from_pretrained(
                layoutreader_model_dir
            )
        else:
            logger.warning(
                'local layoutreader model not exists, use online model from huggingface'
            )
            model = LayoutLMv3ForTokenClassification.from_pretrained(
                'hantian/layoutreader'
            )
        # Check whether the device supports bfloat16
        if supports_bfloat16:
            model.bfloat16()
        model.to(device).eval()
    else:
        logger.error('model name not allow')
        exit(1)
    return model
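A sketch of resolving the reading-order model with the new local-first logic and feeding it boxes via the ModelSingleton and do_predict helpers defined just below. Boxes are assumed to already be scaled into layoutreader's page space:

model = ModelSingleton().get_model('layoutreader')  # local modelscope cache, else hantian/layoutreader
boxes = [[100, 50, 900, 120],   # e.g. a title line
         [100, 140, 480, 800],  # left column
         [520, 140, 900, 800]]  # right column
orders = do_predict(boxes, model)
print(orders)  # a permutation of [0, 1, 2] in predicted reading order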
@@ -131,17 +156,16 @@ class ModelSingleton:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(self, model_name: str):
        if model_name not in self._models:
            self._models[model_name] = model_init(model_name=model_name)
        return self._models[model_name]
def do_predict(boxes: List[List[int]], model) -> List[int]:
    from magic_pdf.model.v3.helpers import (boxes2inputs, parse_logits,
                                            prepare_inputs)

    inputs = boxes2inputs(boxes)
    inputs = prepare_inputs(inputs, model)
    logits = model(**inputs).logits.cpu().squeeze(0)
...@@ -150,19 +174,6 @@ def do_predict(boxes: List[List[int]], model) -> List[int]: ...@@ -150,19 +174,6 @@ def do_predict(boxes: List[List[int]], model) -> List[int]:
def cal_block_index(fix_blocks, sorted_bboxes):
    for block in fix_blocks:

        line_index_list = []
        if len(block['lines']) == 0:
...@@ -174,9 +185,11 @@ def cal_block_index(fix_blocks, sorted_bboxes):
            median_value = statistics.median(line_index_list)
            block['index'] = median_value

        # Remove the virtual line info from image/table body blocks and backfill it with the real_lines info
        if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
            block['virtual_lines'] = copy.deepcopy(block['lines'])
            block['lines'] = copy.deepcopy(block['real_lines'])
            del block['real_lines']

    return fix_blocks
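A block's position in the reading order is the median of its line indices, which keeps a block stable even if a single line of it sorts oddly; for example:

import statistics
line_index_list = [4, 7, 9]   # positions of the block's lines in sorted_bboxes
block_index = statistics.median(line_index_list)  # -> 7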
...@@ -189,21 +202,22 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
    block_weight = x1 - x0

    # If the block height is less than n lines of body text, return the block bbox directly
    if line_height * 3 < block_height:
        if (
            block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
        ):  # probably a two-column structure, so it can be split finer
            lines = int(block_height / line_height) + 1
        else:
            # If the block width exceeds 0.4 of the page width, split the block into 3 lines (a complex layout where figures must not be cut too finely)
            if block_weight > page_w * 0.4:
                line_height = (y1 - y0) / 3
                lines = 3
            elif block_weight > page_w * 0.25:  # (probably a three-column structure, also split finer)
                lines = int(block_height / line_height) + 1
            else:  # check the aspect ratio
                if block_height / block_weight > 1.2:  # tall, narrow blocks are not split
                    return [[x0, y0, x1, y1]]
                else:  # otherwise still split into two lines
                    line_height = (y1 - y0) / 2
                    lines = 2
...@@ -225,7 +239,11 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
    page_line_list = []
    for block in fix_blocks:
        if block['type'] in [
            BlockType.Text, BlockType.Title, BlockType.InterlineEquation,
            BlockType.ImageCaption, BlockType.ImageFootnote,
            BlockType.TableCaption, BlockType.TableFootnote
        ]:
            if len(block['lines']) == 0:
                bbox = block['bbox']
                lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
...@@ -236,8 +254,9 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
            for line in block['lines']:
                bbox = line['bbox']
                page_line_list.append(bbox)
        elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
            bbox = block['bbox']
            block["real_lines"] = copy.deepcopy(block['lines'])
            lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
            block['lines'] = []
            for line in lines:
...@@ -252,19 +271,23 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
    for left, top, right, bottom in page_line_list:
        if left < 0:
            logger.warning(
                f'left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
            )  # noqa: E501
            left = 0
        if right > page_w:
            logger.warning(
                f'right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
            )  # noqa: E501
            right = page_w
        if top < 0:
            logger.warning(
                f'top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
            )  # noqa: E501
            top = 0
        if bottom > page_h:
            logger.warning(
                f'bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
            )  # noqa: E501
            bottom = page_h

        left = round(left * x_scale)
...@@ -272,11 +295,11 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
        right = round(right * x_scale)
        bottom = round(bottom * y_scale)
        assert (
            1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
        ), f'Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}'  # noqa: E126, E121
        boxes.append([left, top, right, bottom])
    model_manager = ModelSingleton()
    model = model_manager.get_model('layoutreader')
    with torch.no_grad():
        orders = do_predict(boxes, model)
    sorted_bboxes = [page_line_list[i] for i in orders]
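LayoutReader expects integer boxes on a 1000x1000 grid, which is why every line is clamped to the page and then scaled; a sketch of the transform with x_scale = 1000 / page_w and y_scale = 1000 / page_h (values illustrative):

page_w, page_h = 612, 792
x_scale, y_scale = 1000 / page_w, 1000 / page_h
left, top, right, bottom = 30, 40, 300, 60
box = [round(left * x_scale), round(top * y_scale),
       round(right * x_scale), round(bottom * y_scale)]
# box == [49, 51, 490, 76], every coordinate within [0, 1000]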
...@@ -287,159 +310,282 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
def get_line_height(blocks):
    page_line_height_list = []
    for block in blocks:
        if block['type'] in [
            BlockType.Text, BlockType.Title,
            BlockType.ImageCaption, BlockType.ImageFootnote,
            BlockType.TableCaption, BlockType.TableFootnote
        ]:
            for line in block['lines']:
                bbox = line['bbox']
                page_line_height_list.append(int(bbox[3] - bbox[1]))
    if len(page_line_height_list) > 0:
        return statistics.median(page_line_height_list)
    else:
        return 10
def process_groups(groups, body_key, caption_key, footnote_key):
    body_blocks = []
    caption_blocks = []
    footnote_blocks = []
    for i, group in enumerate(groups):
        group[body_key]['group_id'] = i
        body_blocks.append(group[body_key])
        for caption_block in group[caption_key]:
            caption_block['group_id'] = i
            caption_blocks.append(caption_block)
        for footnote_block in group[footnote_key]:
            footnote_block['group_id'] = i
            footnote_blocks.append(footnote_block)
    return body_blocks, caption_blocks, footnote_blocks
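process_groups flattens the grouped image/table structures into three parallel lists while stamping every block with its group_id, so revert_group_blocks can reassemble the groups after sorting; a tiny illustration:

groups = [{
    'image_body': {'bbox': [0, 0, 100, 80]},
    'image_caption_list': [{'bbox': [0, 82, 100, 95]}],
    'image_footnote_list': [],
}]
bodies, captions, footnotes = process_groups(
    groups, 'image_body', 'image_caption_list', 'image_footnote_list')
assert bodies[0]['group_id'] == captions[0]['group_id'] == 0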
def process_block_list(blocks, body_type, block_type):
    indices = [block['index'] for block in blocks]
    median_index = statistics.median(indices)
    body_bbox = next((block['bbox'] for block in blocks if block.get('type') == body_type), [])

    return {
        'type': block_type,
        'bbox': body_bbox,
        'blocks': blocks,
        'index': median_index,
    }
def revert_group_blocks(blocks):
    image_groups = {}
    table_groups = {}
    new_blocks = []
    for block in blocks:
        if block['type'] in [BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote]:
            group_id = block['group_id']
            if group_id not in image_groups:
                image_groups[group_id] = []
            image_groups[group_id].append(block)
        elif block['type'] in [BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote]:
            group_id = block['group_id']
            if group_id not in table_groups:
                table_groups[group_id] = []
            table_groups[group_id].append(block)
        else:
            new_blocks.append(block)

    for group_id, blocks in image_groups.items():
        new_blocks.append(process_block_list(blocks, BlockType.ImageBody, BlockType.Image))

    for group_id, blocks in table_groups.items():
        new_blocks.append(process_block_list(blocks, BlockType.TableBody, BlockType.Table))

    return new_blocks
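Together with process_block_list, this is the inverse of process_groups: blocks sharing a group_id collapse back into a single Image or Table block whose index is the median of its members, so the whole group sorts as one unit. A sketch, assuming indices were already assigned by cal_block_index:

blocks = [
    {'type': BlockType.ImageBody, 'group_id': 0, 'bbox': [0, 0, 100, 80], 'index': 3},
    {'type': BlockType.ImageCaption, 'group_id': 0, 'bbox': [0, 82, 100, 95], 'index': 4},
    {'type': BlockType.Text, 'bbox': [0, 100, 100, 120], 'index': 5},
]
merged = revert_group_blocks(blocks)
# -> the text block plus one Image block with index 3.5 wrapping both members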
def parse_page_core(
    page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
):
    need_drop = False
    drop_reason = []

    """Fetch the block info used later from the magic_model object"""
    # img_blocks = magic_model.get_imgs(page_id)
    # table_blocks = magic_model.get_tables(page_id)

    img_groups = magic_model.get_imgs_v2(page_id)
    table_groups = magic_model.get_tables_v2(page_id)
    img_body_blocks, img_caption_blocks, img_footnote_blocks = process_groups(
        img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
    )
    table_body_blocks, table_caption_blocks, table_footnote_blocks = process_groups(
        table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
    )

    discarded_blocks = magic_model.get_discarded(page_id)
    text_blocks = magic_model.get_text_blocks(page_id)
    title_blocks = magic_model.get_title_blocks(page_id)
    inline_equations, interline_equations, interline_equation_blocks = (
        magic_model.get_equations(page_id)
    )

    page_w, page_h = magic_model.get_page_size(page_id)

    spans = magic_model.get_all_spans(page_id)

    """Build spans according to parse_mode"""
    if parse_mode == SupportedPdfParseMethod.TXT:
        """Replace text-type spans in the OCR result with pymu spans!"""
        pymu_spans = txt_spans_extract(page_doc, inline_equations, interline_equations)
        spans = replace_text_span(pymu_spans, spans)
    elif parse_mode == SupportedPdfParseMethod.OCR:
        pass
    else:
        raise Exception('parse_mode must be txt or ocr')

    """Among overlapping spans, drop the lower-confidence ones"""
    spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
    """Among overlapping spans, drop the smaller ones"""
    spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)

    """Crop screenshots of images and tables"""
    spans = ocr_cut_image_and_table(
        spans, page_doc, page_id, pdf_bytes_md5, imageWriter
    )

    """Gather the bboxes of all blocks"""
    # interline_equation_blocks is not accurate enough; switch to interline_equations later
    interline_equation_blocks = []
    if len(interline_equation_blocks) > 0:
        all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
            img_body_blocks, img_caption_blocks, img_footnote_blocks,
            table_body_blocks, table_caption_blocks, table_footnote_blocks,
            discarded_blocks,
            text_blocks,
            title_blocks,
            interline_equation_blocks,
            page_w,
            page_h,
        )
    else:
        all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
            img_body_blocks, img_caption_blocks, img_footnote_blocks,
            table_body_blocks, table_caption_blocks, table_footnote_blocks,
            discarded_blocks,
            text_blocks,
            title_blocks,
            interline_equations,
            page_w,
            page_h,
        )

    """First handle the discarded_blocks, which need no layout"""
    discarded_block_with_spans, spans = fill_spans_in_blocks(
        all_discarded_blocks, spans, 0.4
    )
    fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)

    """Skip the page if it has no bbox"""
    if len(all_bboxes) == 0:
        logger.warning(f'skip this page, not found useful bbox, page_id: {page_id}')
        return ocr_construct_page_component_v2(
            [],
            [],
            page_id,
            page_w,
            page_h,
            [],
            [],
            [],
            interline_equations,
            fix_discarded_blocks,
            need_drop,
            drop_reason,
        )

    """Fill spans into the blocks"""
    block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)

    """Fix the blocks"""
    fix_blocks = fix_block_spans_v2(block_with_spans)

    """Collect all lines and compute the body-text line height"""
    line_height = get_line_height(fix_blocks)

    """Collect all lines and sort them"""
    sorted_bboxes = sort_lines_by_model(fix_blocks, page_w, page_h, line_height)

    """Derive block ordering from the median of line indices"""
    fix_blocks = cal_block_index(fix_blocks, sorted_bboxes)

    """Revert image and table blocks back to group form for the rest of the pipeline"""
    fix_blocks = revert_group_blocks(fix_blocks)

    """Re-sort the blocks"""
    sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])

    """Get the lists QA needs externalized"""
    images, tables, interline_equations = get_qa_need_list_v2(sorted_blocks)

    """Build pdf_info_dict"""
    page_info = ocr_construct_page_component_v2(
        sorted_blocks,
        [],
        page_id,
        page_w,
        page_h,
        [],
        images,
        tables,
        interline_equations,
        fix_discarded_blocks,
        need_drop,
        drop_reason,
    )
    return page_info
def pdf_parse_union(
    dataset: Dataset,
    model_list,
    imageWriter,
    parse_mode,
    start_page_id=0,
    end_page_id=None,
    debug_mode=False,
):
    pdf_bytes_md5 = compute_md5(dataset.data_bits())

    """Initialize an empty pdf_info_dict"""
    pdf_info_dict = {}

    """Initialize magic_model with model_list and the docs object"""
    magic_model = MagicModel(model_list, dataset)

    """Parse the pdf within the given page range"""
    # end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
    end_page_id = (
        end_page_id
        if end_page_id is not None and end_page_id >= 0
        else len(dataset) - 1
    )

    if end_page_id > len(dataset) - 1:
        logger.warning('end_page_id is out of range, use pdf_docs length')
        end_page_id = len(dataset) - 1

    """Record the start time"""
    start_time = time.time()

    for page_id, page in enumerate(dataset):
        """In debug mode, log each page's parse time."""
        if debug_mode:
            time_now = time.time()
            logger.info(
                f'page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}'
            )
            start_time = time_now

        """Parse each page of the pdf"""
        if start_page_id <= page_id <= end_page_id:
            page_info = parse_page_core(
                page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
            )
        else:
            page_info = page.get_page_info()
            page_w = page_info.w
            page_h = page_info.h
            page_info = ocr_construct_page_component_v2(
                [], [], page_id, page_w, page_h, [], [], [], [], [], True, 'skip page'
            )
        pdf_info_dict[f'page_{page_id}'] = page_info

    """Paragraph splitting"""
    para_split(pdf_info_dict, debug_mode=debug_mode)

    """Convert the dict to a list"""
    pdf_info_list = dict_to_list(pdf_info_dict)
    new_pdf_info_dict = {
        'pdf_info': pdf_info_list,
    }

    clean_memory()
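A hedged usage sketch of the refactored entry point; the Dataset import path, constructor, and the SupportedPdfParseMethod location are assumptions here, since the concrete dataset API may differ at this commit:

from magic_pdf.data.dataset import PymuDocDataset  # assumed location
from magic_pdf.config.enums import SupportedPdfParseMethod  # assumed location

with open('demo.pdf', 'rb') as f:
    dataset = PymuDocDataset(f.read())  # assumed constructor

pdf_info = pdf_parse_union(
    dataset,
    model_list=model_json,     # inference output from doc_analyze
    imageWriter=image_writer,  # an AbsReaderWriter for cropped images
    parse_mode=SupportedPdfParseMethod.OCR,
    debug_mode=True,
)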
...
...@@ -17,7 +17,7 @@ class AbsPipe(ABC):
    PIP_TXT = "txt"

    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
                 start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
        self.pdf_bytes = pdf_bytes
        self.model_list = model_list
        self.image_writer = image_writer
...@@ -26,6 +26,9 @@ class AbsPipe(ABC):
        self.start_page_id = start_page_id
        self.end_page_id = end_page_id
        self.lang = lang
        self.layout_model = layout_model
        self.formula_enable = formula_enable
        self.table_enable = table_enable

    def get_compress_pdf_mid_data(self):
        return JsonCompressor.compress_json(self.pdf_mid_data)
...@@ -95,9 +98,7 @@ class AbsPipe(ABC):
        """
        pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
        pdf_info_list = pdf_mid_data["pdf_info"]
        content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_buket_path)
        return content_list

    @staticmethod
...@@ -107,9 +108,7 @@ class AbsPipe(ABC):
        """
        pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
        pdf_info_list = pdf_mid_data["pdf_info"]
        md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_buket_path)
        return md_content
...@@ -10,8 +10,10 @@ from magic_pdf.user_api import parse_ocr_pdf
class OCRPipe(AbsPipe):

    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
                 start_page_id=0, end_page_id=None, lang=None,
                 layout_model=None, formula_enable=None, table_enable=None):
        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
                         layout_model, formula_enable, table_enable)

    def pipe_classify(self):
        pass
...@@ -19,12 +21,14 @@ class OCRPipe(AbsPipe):
    def pipe_analyze(self):
        self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id,
                                      lang=self.lang, layout_model=self.layout_model,
                                      formula_enable=self.formula_enable, table_enable=self.table_enable)

    def pipe_parse(self):
        self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
                                          lang=self.lang, layout_model=self.layout_model,
                                          formula_enable=self.formula_enable, table_enable=self.table_enable)

    def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
        result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
...
...@@ -11,8 +11,10 @@ from magic_pdf.user_api import parse_txt_pdf
class TXTPipe(AbsPipe):

    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
                 start_page_id=0, end_page_id=None, lang=None,
                 layout_model=None, formula_enable=None, table_enable=None):
        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
                         layout_model, formula_enable, table_enable)

    def pipe_classify(self):
        pass
...@@ -20,12 +22,14 @@ class TXTPipe(AbsPipe):
    def pipe_analyze(self):
        self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id,
                                      lang=self.lang, layout_model=self.layout_model,
                                      formula_enable=self.formula_enable, table_enable=self.table_enable)

    def pipe_parse(self):
        self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
                                          lang=self.lang, layout_model=self.layout_model,
                                          formula_enable=self.formula_enable, table_enable=self.table_enable)

    def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
        result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
...
...@@ -14,9 +14,11 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
class UNIPipe(AbsPipe):

    def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
                 start_page_id=0, end_page_id=None, lang=None,
                 layout_model=None, formula_enable=None, table_enable=None):
        self.pdf_type = jso_useful_key["_pdf_type"]
        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id,
                         lang, layout_model, formula_enable, table_enable)
        if len(self.model_list) == 0:
            self.input_model_is_empty = True
        else:
...@@ -29,18 +31,21 @@ class UNIPipe(AbsPipe):
        if self.pdf_type == self.PIP_TXT:
            self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
                                          lang=self.lang, layout_model=self.layout_model,
                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
        elif self.pdf_type == self.PIP_OCR:
            self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
                                          lang=self.lang, layout_model=self.layout_model,
                                          formula_enable=self.formula_enable, table_enable=self.table_enable)

    def pipe_parse(self):
        if self.pdf_type == self.PIP_TXT:
            self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer,
                                                is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty,
                                                start_page_id=self.start_page_id, end_page_id=self.end_page_id,
                                                lang=self.lang, layout_model=self.layout_model,
                                                formula_enable=self.formula_enable, table_enable=self.table_enable)
        elif self.pdf_type == self.PIP_OCR:
            self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
                                              is_debug=self.is_debug,
...
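All three pipes now accept the same optional knobs (layout_model, formula_enable, table_enable) and forward them through doc_analyze and the user_api parsers. A sketch of driving the OCR pipe with them; the layout model key and the pipe_mk_markdown signature are assumptions:

pipe = OCRPipe(pdf_bytes, [], image_writer, is_debug=True,
               layout_model='doclayout_yolo',  # assumed model key
               formula_enable=True, table_enable=False)
pipe.pipe_analyze()                   # runs doc_analyze with the knobs above
pipe.pipe_parse()                     # fills pipe.pdf_mid_data
md = pipe.pipe_mk_markdown('images')  # assumed signature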
from loguru import logger

from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio, calculate_overlap_area_in_bbox1_area_ratio, \
    calculate_iou, calculate_vertical_projection_overlap_ratio
from magic_pdf.libs.drop_tag import DropTag
from magic_pdf.libs.ocr_content_type import BlockType
from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox_for_block
...@@ -60,29 +60,34 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
    return all_bboxes, all_discarded_blocks, drop_reasons
def add_bboxes(blocks, block_type, bboxes):
    for block in blocks:
        x0, y0, x1, y1 = block['bbox']
        if block_type in [
            BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
            BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
        ]:
            bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"], block["group_id"]])
        else:
            bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"]])
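Each appended entry is a flat row in the legacy layout-split format, with a trailing group_id only for the grouped image/table types; one image-body row might look like this (values illustrative):

row = [50, 100, 400, 300,        # x0, y0, x1, y1
       None, None, None,         # legacy placeholder slots
       BlockType.ImageBody,      # block type
       None, None, None, None,   # more legacy placeholders
       0.98,                     # detection score
       0]                        # group_id (grouped types only)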
def ocr_prepare_bboxes_for_layout_split_v2(
        img_body_blocks, img_caption_blocks, img_footnote_blocks,
        table_body_blocks, table_caption_blocks, table_footnote_blocks,
        discarded_blocks, text_blocks, title_blocks, interline_equation_blocks, page_w, page_h
):
    all_bboxes = []

    add_bboxes(img_body_blocks, BlockType.ImageBody, all_bboxes)
    add_bboxes(img_caption_blocks, BlockType.ImageCaption, all_bboxes)
    add_bboxes(img_footnote_blocks, BlockType.ImageFootnote, all_bboxes)
    add_bboxes(table_body_blocks, BlockType.TableBody, all_bboxes)
    add_bboxes(table_caption_blocks, BlockType.TableCaption, all_bboxes)
    add_bboxes(table_footnote_blocks, BlockType.TableFootnote, all_bboxes)
    add_bboxes(text_blocks, BlockType.Text, all_bboxes)
    add_bboxes(title_blocks, BlockType.Title, all_bboxes)
    add_bboxes(interline_equation_blocks, BlockType.InterlineEquation, all_bboxes)
    '''Resolve block nesting'''
    '''When a text box overlaps a title box, trust the text box first'''
...@@ -96,23 +101,47 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
    '''When an interline_equation box is contained in a text-type box and is much smaller than the text block, trust the text box and discard the equation box'''
    # removed later by the big-box-contains-small-box logic
    '''discarded_blocks'''
    all_discarded_blocks = []
    add_bboxes(discarded_blocks, BlockType.Discarded, all_discarded_blocks)

    '''Footnote detection: wider than 1/3 of the page width, taller than 10, located in the lower 50% of the page'''
    footnote_blocks = []
    for discarded in discarded_blocks:
        x0, y0, x1, y1 = discarded['bbox']
        if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
            footnote_blocks.append([x0, y0, x1, y1])

    '''Remove any box that sits below a footnote'''
    need_remove_blocks = find_blocks_under_footnote(all_bboxes, footnote_blocks)
    if len(need_remove_blocks) > 0:
        for block in need_remove_blocks:
            all_bboxes.remove(block)
            all_discarded_blocks.append(block)

    '''If big-box-contains-small-box cases remain after the above, delete the small boxes'''
    all_bboxes = remove_overlaps_min_blocks(all_bboxes)
    all_discarded_blocks = remove_overlaps_min_blocks(all_discarded_blocks)

    '''Separate the remaining bboxes to avoid errors in the later layout split'''
    all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes)

    return all_bboxes, all_discarded_blocks
def find_blocks_under_footnote(all_bboxes, footnote_blocks):
    need_remove_blocks = []
    for block in all_bboxes:
        block_x0, block_y0, block_x1, block_y1 = block[:4]
        for footnote_bbox in footnote_blocks:
            footnote_x0, footnote_y0, footnote_x1, footnote_y1 = footnote_bbox
            # If the footnote's vertical projection covers at least 80% of the block's vertical projection, and the block's y0 is >= the footnote's y1
            if block_y0 >= footnote_y1 and calculate_vertical_projection_overlap_ratio((block_x0, block_y0, block_x1, block_y1), footnote_bbox) >= 0.8:
                if block not in need_remove_blocks:
                    need_remove_blocks.append(block)
                    break
    return need_remove_blocks
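The test removes a block only when it starts at or below the footnote and the footnote's x-interval covers at least 80% of the block's; a sketch of the overlap ratio being computed (the boxbase helper may differ in edge cases):

block = (120, 700, 400, 730)      # x0, y0, x1, y1
footnote = (100, 650, 380, 680)
overlap = max(0, min(block[2], footnote[2]) - max(block[0], footnote[0]))
ratio = overlap / (block[2] - block[0])   # 260 / 280 ≈ 0.93 -> block is removed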
def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes):
    # First collect all text and interline blocks
    text_blocks = []
...
...@@ -49,8 +49,7 @@ def merge_spans_to_line(spans):
            continue

        # If the current span overlaps the last span of the current line on the y-axis, add it to the current line
        if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], 0.5):
            current_line.append(span)
        else:
            # Otherwise, start a new line
...@@ -154,6 +153,11 @@ def fill_spans_in_blocks(blocks, spans, radio):
            'type': block_type,
            'bbox': block_bbox,
        }
        if block_type in [
            BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
            BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
        ]:
            block_dict["group_id"] = block[-1]
        block_spans = []
        for span in spans:
            span_bbox = span['bbox']
...@@ -202,6 +206,27 @@ def fix_block_spans(block_with_spans, img_blocks, table_blocks):
    return fix_blocks


def fix_block_spans_v2(block_with_spans):
    """1. Because img_blocks and table_blocks contain captions and footnotes, blocks
    are nested: the text spans of captions and footnotes must be placed into the
    caption_block and footnote_block inside the corresponding img_block and
    table_block. 2. The spans field must also be removed from each block."""
    fix_blocks = []
    for block in block_with_spans:
        block_type = block['type']

        if block_type in [BlockType.Text, BlockType.Title,
                          BlockType.ImageCaption, BlockType.ImageFootnote,
                          BlockType.TableCaption, BlockType.TableFootnote
                          ]:
            block = fix_text_block(block)
        elif block_type in [BlockType.InterlineEquation, BlockType.ImageBody, BlockType.TableBody]:
            block = fix_interline_block(block)
        else:
            continue
        fix_blocks.append(block)
    return fix_blocks
def fix_discarded_block(discarded_block_with_spans):
    fix_discarded_blocks = []
    for block in discarded_block_with_spans:
...
weights:
  layoutlmv3: Layout/LayoutLMv3/model_final.pth
  doclayout_yolo: Layout/YOLO/doclayout_yolo_ft.pt
  yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
  unimernet_small: MFR/unimernet_small
  struct_eqtable: TabRec/StructEqTable
  tablemaster: TabRec/TableMaster
\ No newline at end of file
...@@ -52,7 +52,7 @@ without method specified, auto will be used by default.""",
    help="""
    Input the languages in the pdf (if known) to improve OCR accuracy. Optional.
    You should input the "Abbreviation" from the language list at this url:
    https://paddlepaddle.github.io/PaddleOCR/latest/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
    """,
    default=None,
)
...
...@@ -6,8 +6,8 @@ import click
from loguru import logger

import magic_pdf.model as model_config
from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox,
                                      draw_model_bbox, draw_span_bbox)
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
from magic_pdf.pipe.OCRPipe import OCRPipe
from magic_pdf.pipe.TXTPipe import TXTPipe
...@@ -46,10 +46,12 @@ def do_parse(
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    if debug_able:
        logger.warning('debug mode is on')
        f_draw_model_bbox = True
        f_draw_line_sort_bbox = True
...@@ -64,13 +66,16 @@ def do_parse(
    if parse_method == 'auto':
        jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
        pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
    elif parse_method == 'txt':
        pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
    elif parse_method == 'ocr':
        pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
    else:
        logger.error('unknown parse method')
        exit(1)
...
...@@ -101,11 +101,19 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
    if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False):
        logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
        if input_model_is_empty:
            layout_model = kwargs.get("layout_model", None)
            formula_enable = kwargs.get("formula_enable", None)
            table_enable = kwargs.get("table_enable", None)
            pdf_models = doc_analyze(
                pdf_bytes,
                ocr=True,
                start_page_id=start_page_id,
                end_page_id=end_page_id,
                lang=lang,
                layout_model=layout_model,
                formula_enable=formula_enable,
                table_enable=table_enable,
            )
        pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
        if pdf_info_dict is None:
            raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")
...