Commit 7d2dfc80 authored by liukaiwen's avatar liukaiwen
Browse files

Merge branch 'dev' into dev-table-model-update

parents a0eff3be 6d571e2e
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
...@@ -8,10 +10,11 @@ def parse_pdf_by_ocr(pdf_bytes, ...@@ -8,10 +10,11 @@ def parse_pdf_by_ocr(pdf_bytes,
end_page_id=None, end_page_id=None,
debug_mode=False, debug_mode=False,
): ):
return pdf_parse_union(pdf_bytes, dataset = PymuDocDataset(pdf_bytes)
return pdf_parse_union(dataset,
model_list, model_list,
imageWriter, imageWriter,
"ocr", SupportedPdfParseMethod.OCR,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id, end_page_id=end_page_id,
debug_mode=debug_mode, debug_mode=debug_mode,
......
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
...@@ -9,10 +11,11 @@ def parse_pdf_by_txt( ...@@ -9,10 +11,11 @@ def parse_pdf_by_txt(
end_page_id=None, end_page_id=None,
debug_mode=False, debug_mode=False,
): ):
return pdf_parse_union(pdf_bytes, dataset = PymuDocDataset(pdf_bytes)
return pdf_parse_union(dataset,
model_list, model_list,
imageWriter, imageWriter,
"txt", SupportedPdfParseMethod.TXT,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id, end_page_id=end_page_id,
debug_mode=debug_mode, debug_mode=debug_mode,
......
import copy
import os import os
import statistics import statistics
import time import time
from loguru import logger
from typing import List from typing import List
import torch import torch
from loguru import logger
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import Dataset, PageableData
from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.commons import fitz, get_delta_time from magic_pdf.libs.commons import fitz, get_delta_time
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
...@@ -15,31 +16,39 @@ from magic_pdf.libs.convert_utils import dict_to_list ...@@ -15,31 +16,39 @@ from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.drop_reason import DropReason from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.libs.hash_utils import compute_md5 from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.local_math import float_equal from magic_pdf.libs.local_math import float_equal
from magic_pdf.libs.ocr_content_type import ContentType from magic_pdf.libs.ocr_content_type import ContentType, BlockType
from magic_pdf.model.magic_model import MagicModel from magic_pdf.model.magic_model import MagicModel
from magic_pdf.para.para_split_v3 import para_split from magic_pdf.para.para_split_v3 import para_split
from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2 from magic_pdf.pre_proc.construct_page_dict import \
ocr_construct_page_component_v2
from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
from magic_pdf.pre_proc.equations_replace import remove_chars_in_text_blocks, replace_equations_in_textblock, \ from magic_pdf.pre_proc.equations_replace import (
combine_chars_to_pymudict combine_chars_to_pymudict, remove_chars_in_text_blocks,
from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split_v2 replace_equations_in_textblock)
from magic_pdf.pre_proc.ocr_dict_merge import fill_spans_in_blocks, fix_block_spans, fix_discarded_block from magic_pdf.pre_proc.ocr_detect_all_bboxes import \
from magic_pdf.pre_proc.ocr_span_list_modify import remove_overlaps_min_spans, get_qa_need_list_v2, \ ocr_prepare_bboxes_for_layout_split_v2
remove_overlaps_low_confidence_spans from magic_pdf.pre_proc.ocr_dict_merge import (fill_spans_in_blocks,
from magic_pdf.pre_proc.resolve_bbox_conflict import check_useful_block_horizontal_overlap fix_block_spans,
fix_discarded_block, fix_block_spans_v2)
from magic_pdf.pre_proc.ocr_span_list_modify import (
get_qa_need_list_v2, remove_overlaps_low_confidence_spans,
remove_overlaps_min_spans)
from magic_pdf.pre_proc.resolve_bbox_conflict import \
check_useful_block_horizontal_overlap
def remove_horizontal_overlap_block_which_smaller(all_bboxes): def remove_horizontal_overlap_block_which_smaller(all_bboxes):
useful_blocks = [] useful_blocks = []
for bbox in all_bboxes: for bbox in all_bboxes:
useful_blocks.append({ useful_blocks.append({'bbox': bbox[:4]})
"bbox": bbox[:4] is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = (
}) check_useful_block_horizontal_overlap(useful_blocks)
is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = check_useful_block_horizontal_overlap(useful_blocks) )
if is_useful_block_horz_overlap: if is_useful_block_horz_overlap:
logger.warning( logger.warning(
f"skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}") f'skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}'
) # noqa: E501
for bbox in all_bboxes.copy(): for bbox in all_bboxes.copy():
if smaller_bbox == bbox[:4]: if smaller_bbox == bbox[:4]:
all_bboxes.remove(bbox) all_bboxes.remove(bbox)
...@@ -47,27 +56,27 @@ def remove_horizontal_overlap_block_which_smaller(all_bboxes): ...@@ -47,27 +56,27 @@ def remove_horizontal_overlap_block_which_smaller(all_bboxes):
return is_useful_block_horz_overlap, all_bboxes return is_useful_block_horz_overlap, all_bboxes
def __replace_STX_ETX(text_str:str): def __replace_STX_ETX(text_str: str):
""" Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks. """Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
Drawback: This issue is only observed in English text; it has not been found in Chinese text so far. Drawback: This issue is only observed in English text; it has not been found in Chinese text so far.
Args: Args:
text_str (str): raw text text_str (str): raw text
Returns: Returns:
_type_: replaced text _type_: replaced text
""" """ # noqa: E501
if text_str: if text_str:
s = text_str.replace('\u0002', "'") s = text_str.replace('\u0002', "'")
s = s.replace("\u0003", "'") s = s.replace('\u0003', "'")
return s return s
return text_str return text_str
def txt_spans_extract(pdf_page, inline_equations, interline_equations): def txt_spans_extract(pdf_page, inline_equations, interline_equations):
text_raw_blocks = pdf_page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"] text_raw_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
char_level_text_blocks = pdf_page.get_text("rawdict", flags=fitz.TEXTFLAGS_TEXT)[ char_level_text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)[
"blocks" 'blocks'
] ]
text_blocks = combine_chars_to_pymudict(text_raw_blocks, char_level_text_blocks) text_blocks = combine_chars_to_pymudict(text_raw_blocks, char_level_text_blocks)
text_blocks = replace_equations_in_textblock( text_blocks = replace_equations_in_textblock(
...@@ -77,54 +86,63 @@ def txt_spans_extract(pdf_page, inline_equations, interline_equations): ...@@ -77,54 +86,63 @@ def txt_spans_extract(pdf_page, inline_equations, interline_equations):
text_blocks = remove_chars_in_text_blocks(text_blocks) text_blocks = remove_chars_in_text_blocks(text_blocks)
spans = [] spans = []
for v in text_blocks: for v in text_blocks:
for line in v["lines"]: for line in v['lines']:
for span in line["spans"]: for span in line['spans']:
bbox = span["bbox"] bbox = span['bbox']
if float_equal(bbox[0], bbox[2]) or float_equal(bbox[1], bbox[3]): if float_equal(bbox[0], bbox[2]) or float_equal(bbox[1], bbox[3]):
continue continue
if span.get('type') not in (ContentType.InlineEquation, ContentType.InterlineEquation): if span.get('type') not in (
ContentType.InlineEquation,
ContentType.InterlineEquation,
):
spans.append( spans.append(
{ {
"bbox": list(span["bbox"]), 'bbox': list(span['bbox']),
"content": __replace_STX_ETX(span["text"]), 'content': __replace_STX_ETX(span['text']),
"type": ContentType.Text, 'type': ContentType.Text,
"score": 1.0, 'score': 1.0,
} }
) )
return spans return spans
def replace_text_span(pymu_spans, ocr_spans): def replace_text_span(pymu_spans, ocr_spans):
return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans return list(filter(lambda x: x['type'] != ContentType.Text, ocr_spans)) + pymu_spans
def model_init(model_name: str): def model_init(model_name: str):
from transformers import LayoutLMv3ForTokenClassification from transformers import LayoutLMv3ForTokenClassification
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device('cuda')
if torch.cuda.is_bf16_supported(): if torch.cuda.is_bf16_supported():
supports_bfloat16 = True supports_bfloat16 = True
else: else:
supports_bfloat16 = False supports_bfloat16 = False
else: else:
device = torch.device("cpu") device = torch.device('cpu')
supports_bfloat16 = False supports_bfloat16 = False
if model_name == "layoutreader": if model_name == 'layoutreader':
# 检测modelscope的缓存目录是否存在 # 检测modelscope的缓存目录是否存在
layoutreader_model_dir = get_local_layoutreader_model_dir() layoutreader_model_dir = get_local_layoutreader_model_dir()
if os.path.exists(layoutreader_model_dir): if os.path.exists(layoutreader_model_dir):
model = LayoutLMv3ForTokenClassification.from_pretrained(layoutreader_model_dir) model = LayoutLMv3ForTokenClassification.from_pretrained(
layoutreader_model_dir
)
else: else:
logger.warning( logger.warning(
f"local layoutreader model not exists, use online model from huggingface") 'local layoutreader model not exists, use online model from huggingface'
model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader") )
model = LayoutLMv3ForTokenClassification.from_pretrained(
'hantian/layoutreader'
)
# 检查设备是否支持 bfloat16 # 检查设备是否支持 bfloat16
if supports_bfloat16: if supports_bfloat16:
model.bfloat16() model.bfloat16()
model.to(device).eval() model.to(device).eval()
else: else:
logger.error("model name not allow") logger.error('model name not allow')
exit(1) exit(1)
return model return model
...@@ -145,7 +163,9 @@ class ModelSingleton: ...@@ -145,7 +163,9 @@ class ModelSingleton:
def do_predict(boxes: List[List[int]], model) -> List[int]: def do_predict(boxes: List[List[int]], model) -> List[int]:
from magic_pdf.model.v3.helpers import prepare_inputs, boxes2inputs, parse_logits from magic_pdf.model.v3.helpers import (boxes2inputs, parse_logits,
prepare_inputs)
inputs = boxes2inputs(boxes) inputs = boxes2inputs(boxes)
inputs = prepare_inputs(inputs, model) inputs = prepare_inputs(inputs, model)
logits = model(**inputs).logits.cpu().squeeze(0) logits = model(**inputs).logits.cpu().squeeze(0)
...@@ -154,19 +174,6 @@ def do_predict(boxes: List[List[int]], model) -> List[int]: ...@@ -154,19 +174,6 @@ def do_predict(boxes: List[List[int]], model) -> List[int]:
def cal_block_index(fix_blocks, sorted_bboxes): def cal_block_index(fix_blocks, sorted_bboxes):
for block in fix_blocks: for block in fix_blocks:
# if block['type'] in ['text', 'title', 'interline_equation']:
# line_index_list = []
# if len(block['lines']) == 0:
# block['index'] = sorted_bboxes.index(block['bbox'])
# else:
# for line in block['lines']:
# line['index'] = sorted_bboxes.index(line['bbox'])
# line_index_list.append(line['index'])
# median_value = statistics.median(line_index_list)
# block['index'] = median_value
#
# elif block['type'] in ['table', 'image']:
# block['index'] = sorted_bboxes.index(block['bbox'])
line_index_list = [] line_index_list = []
if len(block['lines']) == 0: if len(block['lines']) == 0:
...@@ -178,9 +185,11 @@ def cal_block_index(fix_blocks, sorted_bboxes): ...@@ -178,9 +185,11 @@ def cal_block_index(fix_blocks, sorted_bboxes):
median_value = statistics.median(line_index_list) median_value = statistics.median(line_index_list)
block['index'] = median_value block['index'] = median_value
# 删除图表block中的虚拟line信息 # 删除图表body block中的虚拟line信息, 并用real_lines信息回填
if block['type'] in ['table', 'image']: if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
del block['lines'] block['virtual_lines'] = copy.deepcopy(block['lines'])
block['lines'] = copy.deepcopy(block['real_lines'])
del block['real_lines']
return fix_blocks return fix_blocks
...@@ -193,21 +202,22 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h): ...@@ -193,21 +202,22 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
block_weight = x1 - x0 block_weight = x1 - x0
# 如果block高度小于n行正文,则直接返回block的bbox # 如果block高度小于n行正文,则直接返回block的bbox
if line_height*3 < block_height: if line_height * 3 < block_height:
if block_height > page_h*0.25 and page_w*0.5 > block_weight > page_w*0.25: # 可能是双列结构,可以切细点 if (
lines = int(block_height/line_height)+1 block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
): # 可能是双列结构,可以切细点
lines = int(block_height / line_height) + 1
else: else:
# 如果block的宽度超过0.4页面宽度,则将block分成3行 # 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细)
if block_weight > page_w*0.4: if block_weight > page_w * 0.4:
line_height = (y1 - y0) / 3 line_height = (y1 - y0) / 3
lines = 3 lines = 3
elif block_weight > page_w*0.25: # 否则将block分成两行 elif block_weight > page_w * 0.25: # (可能是三列结构,也切细点)
line_height = (y1 - y0) / 2 lines = int(block_height / line_height) + 1
lines = 2 else: # 判断长宽比
else: # 判断长宽比 if block_height / block_weight > 1.2: # 细长的不分
if block_height/block_weight > 1.2: # 细长的不分
return [[x0, y0, x1, y1]] return [[x0, y0, x1, y1]]
else: # 不细长的还是分成两行 else: # 不细长的还是分成两行
line_height = (y1 - y0) / 2 line_height = (y1 - y0) / 2
lines = 2 lines = 2
...@@ -229,7 +239,11 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h): ...@@ -229,7 +239,11 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
def sort_lines_by_model(fix_blocks, page_w, page_h, line_height): def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
page_line_list = [] page_line_list = []
for block in fix_blocks: for block in fix_blocks:
if block['type'] in ['text', 'title', 'interline_equation']: if block['type'] in [
BlockType.Text, BlockType.Title, BlockType.InterlineEquation,
BlockType.ImageCaption, BlockType.ImageFootnote,
BlockType.TableCaption, BlockType.TableFootnote
]:
if len(block['lines']) == 0: if len(block['lines']) == 0:
bbox = block['bbox'] bbox = block['bbox']
lines = insert_lines_into_block(bbox, line_height, page_w, page_h) lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
...@@ -240,8 +254,9 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height): ...@@ -240,8 +254,9 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
for line in block['lines']: for line in block['lines']:
bbox = line['bbox'] bbox = line['bbox']
page_line_list.append(bbox) page_line_list.append(bbox)
elif block['type'] in ['table', 'image']: elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
bbox = block['bbox'] bbox = block['bbox']
block["real_lines"] = copy.deepcopy(block['lines'])
lines = insert_lines_into_block(bbox, line_height, page_w, page_h) lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
block['lines'] = [] block['lines'] = []
for line in lines: for line in lines:
...@@ -256,19 +271,23 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height): ...@@ -256,19 +271,23 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
for left, top, right, bottom in page_line_list: for left, top, right, bottom in page_line_list:
if left < 0: if left < 0:
logger.warning( logger.warning(
f"left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}") f'left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
left = 0 left = 0
if right > page_w: if right > page_w:
logger.warning( logger.warning(
f"right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}") f'right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
right = page_w right = page_w
if top < 0: if top < 0:
logger.warning( logger.warning(
f"top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}") f'top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
top = 0 top = 0
if bottom > page_h: if bottom > page_h:
logger.warning( logger.warning(
f"bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}") f'bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
bottom = page_h bottom = page_h
left = round(left * x_scale) left = round(left * x_scale)
...@@ -276,11 +295,11 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height): ...@@ -276,11 +295,11 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
right = round(right * x_scale) right = round(right * x_scale)
bottom = round(bottom * y_scale) bottom = round(bottom * y_scale)
assert ( assert (
1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0 1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
), f"Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}" ), f'Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}' # noqa: E126, E121
boxes.append([left, top, right, bottom]) boxes.append([left, top, right, bottom])
model_manager = ModelSingleton() model_manager = ModelSingleton()
model = model_manager.get_model("layoutreader") model = model_manager.get_model('layoutreader')
with torch.no_grad(): with torch.no_grad():
orders = do_predict(boxes, model) orders = do_predict(boxes, model)
sorted_bboxes = [page_line_list[i] for i in orders] sorted_bboxes = [page_line_list[i] for i in orders]
...@@ -291,149 +310,274 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height): ...@@ -291,149 +310,274 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
def get_line_height(blocks): def get_line_height(blocks):
page_line_height_list = [] page_line_height_list = []
for block in blocks: for block in blocks:
if block['type'] in ['text', 'title', 'interline_equation']: if block['type'] in [
BlockType.Text, BlockType.Title,
BlockType.ImageCaption, BlockType.ImageFootnote,
BlockType.TableCaption, BlockType.TableFootnote
]:
for line in block['lines']: for line in block['lines']:
bbox = line['bbox'] bbox = line['bbox']
page_line_height_list.append(int(bbox[3]-bbox[1])) page_line_height_list.append(int(bbox[3] - bbox[1]))
if len(page_line_height_list) > 0: if len(page_line_height_list) > 0:
return statistics.median(page_line_height_list) return statistics.median(page_line_height_list)
else: else:
return 10 return 10
def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode): def process_groups(groups, body_key, caption_key, footnote_key):
body_blocks = []
caption_blocks = []
footnote_blocks = []
for i, group in enumerate(groups):
group[body_key]['group_id'] = i
body_blocks.append(group[body_key])
for caption_block in group[caption_key]:
caption_block['group_id'] = i
caption_blocks.append(caption_block)
for footnote_block in group[footnote_key]:
footnote_block['group_id'] = i
footnote_blocks.append(footnote_block)
return body_blocks, caption_blocks, footnote_blocks
def process_block_list(blocks, body_type, block_type):
indices = [block['index'] for block in blocks]
median_index = statistics.median(indices)
body_bbox = next((block['bbox'] for block in blocks if block.get('type') == body_type), [])
return {
'type': block_type,
'bbox': body_bbox,
'blocks': blocks,
'index': median_index,
}
def revert_group_blocks(blocks):
image_groups = {}
table_groups = {}
new_blocks = []
for block in blocks:
if block['type'] in [BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote]:
group_id = block['group_id']
if group_id not in image_groups:
image_groups[group_id] = []
image_groups[group_id].append(block)
elif block['type'] in [BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote]:
group_id = block['group_id']
if group_id not in table_groups:
table_groups[group_id] = []
table_groups[group_id].append(block)
else:
new_blocks.append(block)
for group_id, blocks in image_groups.items():
new_blocks.append(process_block_list(blocks, BlockType.ImageBody, BlockType.Image))
for group_id, blocks in table_groups.items():
new_blocks.append(process_block_list(blocks, BlockType.TableBody, BlockType.Table))
return new_blocks
def parse_page_core(
page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
):
need_drop = False need_drop = False
drop_reason = [] drop_reason = []
'''从magic_model对象中获取后面会用到的区块信息''' """从magic_model对象中获取后面会用到的区块信息"""
img_blocks = magic_model.get_imgs(page_id) # img_blocks = magic_model.get_imgs(page_id)
table_blocks = magic_model.get_tables(page_id) # table_blocks = magic_model.get_tables(page_id)
img_groups = magic_model.get_imgs_v2(page_id)
table_groups = magic_model.get_tables_v2(page_id)
img_body_blocks, img_caption_blocks, img_footnote_blocks = process_groups(
img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
)
table_body_blocks, table_caption_blocks, table_footnote_blocks = process_groups(
table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
)
discarded_blocks = magic_model.get_discarded(page_id) discarded_blocks = magic_model.get_discarded(page_id)
text_blocks = magic_model.get_text_blocks(page_id) text_blocks = magic_model.get_text_blocks(page_id)
title_blocks = magic_model.get_title_blocks(page_id) title_blocks = magic_model.get_title_blocks(page_id)
inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id) inline_equations, interline_equations, interline_equation_blocks = (
magic_model.get_equations(page_id)
)
page_w, page_h = magic_model.get_page_size(page_id) page_w, page_h = magic_model.get_page_size(page_id)
spans = magic_model.get_all_spans(page_id) spans = magic_model.get_all_spans(page_id)
'''根据parse_mode,构造spans''' """根据parse_mode,构造spans"""
if parse_mode == "txt": if parse_mode == SupportedPdfParseMethod.TXT:
"""ocr 中文本类的 span 用 pymu spans 替换!""" """ocr 中文本类的 span 用 pymu spans 替换!"""
pymu_spans = txt_spans_extract( pymu_spans = txt_spans_extract(page_doc, inline_equations, interline_equations)
pdf_docs[page_id], inline_equations, interline_equations
)
spans = replace_text_span(pymu_spans, spans) spans = replace_text_span(pymu_spans, spans)
elif parse_mode == "ocr": elif parse_mode == SupportedPdfParseMethod.OCR:
pass pass
else: else:
raise Exception("parse_mode must be txt or ocr") raise Exception('parse_mode must be txt or ocr')
'''删除重叠spans中置信度较低的那些''' """删除重叠spans中置信度较低的那些"""
spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans) spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
'''删除重叠spans中较小的那些''' """删除重叠spans中较小的那些"""
spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans) spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
'''对image和table截图''' """对image和table截图"""
spans = ocr_cut_image_and_table(spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter) spans = ocr_cut_image_and_table(
spans, page_doc, page_id, pdf_bytes_md5, imageWriter
)
'''将所有区块的bbox整理到一起''' """将所有区块的bbox整理到一起"""
# interline_equation_blocks参数不够准,后面切换到interline_equations上 # interline_equation_blocks参数不够准,后面切换到interline_equations上
interline_equation_blocks = [] interline_equation_blocks = []
if len(interline_equation_blocks) > 0: if len(interline_equation_blocks) > 0:
all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2( all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks, img_body_blocks, img_caption_blocks, img_footnote_blocks,
interline_equation_blocks, page_w, page_h) table_body_blocks, table_caption_blocks, table_footnote_blocks,
discarded_blocks,
text_blocks,
title_blocks,
interline_equation_blocks,
page_w,
page_h,
)
else: else:
all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2( all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks, img_body_blocks, img_caption_blocks, img_footnote_blocks,
interline_equations, page_w, page_h) table_body_blocks, table_caption_blocks, table_footnote_blocks,
discarded_blocks,
text_blocks,
title_blocks,
interline_equations,
page_w,
page_h,
)
'''先处理不需要排版的discarded_blocks''' """先处理不需要排版的discarded_blocks"""
discarded_block_with_spans, spans = fill_spans_in_blocks(all_discarded_blocks, spans, 0.4) discarded_block_with_spans, spans = fill_spans_in_blocks(
all_discarded_blocks, spans, 0.4
)
fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans) fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
'''如果当前页面没有bbox则跳过''' """如果当前页面没有bbox则跳过"""
if len(all_bboxes) == 0: if len(all_bboxes) == 0:
logger.warning(f"skip this page, not found useful bbox, page_id: {page_id}") logger.warning(f'skip this page, not found useful bbox, page_id: {page_id}')
return ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [], return ocr_construct_page_component_v2(
[], [], interline_equations, fix_discarded_blocks, [],
need_drop, drop_reason) [],
page_id,
page_w,
page_h,
[],
[],
[],
interline_equations,
fix_discarded_blocks,
need_drop,
drop_reason,
)
'''将span填入blocks中''' """将span填入blocks中"""
block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.3) block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
'''对block进行fix操作''' """对block进行fix操作"""
fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks) fix_blocks = fix_block_spans_v2(block_with_spans)
'''获取所有line并计算正文line的高度''' """获取所有line并计算正文line的高度"""
line_height = get_line_height(fix_blocks) line_height = get_line_height(fix_blocks)
'''获取所有line并对line排序''' """获取所有line并对line排序"""
sorted_bboxes = sort_lines_by_model(fix_blocks, page_w, page_h, line_height) sorted_bboxes = sort_lines_by_model(fix_blocks, page_w, page_h, line_height)
'''根据line的中位数算block的序列关系''' """根据line的中位数算block的序列关系"""
fix_blocks = cal_block_index(fix_blocks, sorted_bboxes) fix_blocks = cal_block_index(fix_blocks, sorted_bboxes)
'''重排block''' """将image和table的block还原回group形式参与后续流程"""
fix_blocks = revert_group_blocks(fix_blocks)
"""重排block"""
sorted_blocks = sorted(fix_blocks, key=lambda b: b['index']) sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])
'''获取QA需要外置的list''' """获取QA需要外置的list"""
images, tables, interline_equations = get_qa_need_list_v2(sorted_blocks) images, tables, interline_equations = get_qa_need_list_v2(sorted_blocks)
'''构造pdf_info_dict''' """构造pdf_info_dict"""
page_info = ocr_construct_page_component_v2(sorted_blocks, [], page_id, page_w, page_h, [], page_info = ocr_construct_page_component_v2(
images, tables, interline_equations, fix_discarded_blocks, sorted_blocks,
need_drop, drop_reason) [],
page_id,
page_w,
page_h,
[],
images,
tables,
interline_equations,
fix_discarded_blocks,
need_drop,
drop_reason,
)
return page_info return page_info
def pdf_parse_union(pdf_bytes, def pdf_parse_union(
model_list, dataset: Dataset,
imageWriter, model_list,
parse_mode, imageWriter,
start_page_id=0, parse_mode,
end_page_id=None, start_page_id=0,
debug_mode=False, end_page_id=None,
): debug_mode=False,
pdf_bytes_md5 = compute_md5(pdf_bytes) ):
pdf_docs = fitz.open("pdf", pdf_bytes) pdf_bytes_md5 = compute_md5(dataset.data_bits())
'''初始化空的pdf_info_dict''' """初始化空的pdf_info_dict"""
pdf_info_dict = {} pdf_info_dict = {}
'''用model_list和docs对象初始化magic_model''' """用model_list和docs对象初始化magic_model"""
magic_model = MagicModel(model_list, pdf_docs) magic_model = MagicModel(model_list, dataset)
'''根据输入的起始范围解析pdf''' """根据输入的起始范围解析pdf"""
# end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1 # end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(pdf_docs) - 1 end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else len(dataset) - 1
)
if end_page_id > len(pdf_docs) - 1: if end_page_id > len(dataset) - 1:
logger.warning("end_page_id is out of range, use pdf_docs length") logger.warning('end_page_id is out of range, use pdf_docs length')
end_page_id = len(pdf_docs) - 1 end_page_id = len(dataset) - 1
'''初始化启动时间''' """初始化启动时间"""
start_time = time.time() start_time = time.time()
for page_id, page in enumerate(pdf_docs): for page_id, page in enumerate(dataset):
'''debug时输出每页解析的耗时''' """debug时输出每页解析的耗时."""
if debug_mode: if debug_mode:
time_now = time.time() time_now = time.time()
logger.info( logger.info(
f"page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}" f'page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}'
) )
start_time = time_now start_time = time_now
'''解析pdf中的每一页''' """解析pdf中的每一页"""
if start_page_id <= page_id <= end_page_id: if start_page_id <= page_id <= end_page_id:
page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode) page_info = parse_page_core(
page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
)
else: else:
page_w = page.rect.width page_info = page.get_page_info()
page_h = page.rect.height page_w = page_info.w
page_info = ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [], page_h = page_info.h
[], [], [], [], page_info = ocr_construct_page_component_v2(
True, "skip page") [], [], page_id, page_w, page_h, [], [], [], [], [], True, 'skip page'
pdf_info_dict[f"page_{page_id}"] = page_info )
pdf_info_dict[f'page_{page_id}'] = page_info
"""分段""" """分段"""
para_split(pdf_info_dict, debug_mode=debug_mode) para_split(pdf_info_dict, debug_mode=debug_mode)
...@@ -441,7 +585,7 @@ def pdf_parse_union(pdf_bytes, ...@@ -441,7 +585,7 @@ def pdf_parse_union(pdf_bytes,
"""dict转list""" """dict转list"""
pdf_info_list = dict_to_list(pdf_info_dict) pdf_info_list = dict_to_list(pdf_info_dict)
new_pdf_info_dict = { new_pdf_info_dict = {
"pdf_info": pdf_info_list, 'pdf_info': pdf_info_list,
} }
clean_memory() clean_memory()
......
...@@ -17,7 +17,7 @@ class AbsPipe(ABC): ...@@ -17,7 +17,7 @@ class AbsPipe(ABC):
PIP_TXT = "txt" PIP_TXT = "txt"
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False, def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None): start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
self.pdf_bytes = pdf_bytes self.pdf_bytes = pdf_bytes
self.model_list = model_list self.model_list = model_list
self.image_writer = image_writer self.image_writer = image_writer
...@@ -26,6 +26,9 @@ class AbsPipe(ABC): ...@@ -26,6 +26,9 @@ class AbsPipe(ABC):
self.start_page_id = start_page_id self.start_page_id = start_page_id
self.end_page_id = end_page_id self.end_page_id = end_page_id
self.lang = lang self.lang = lang
self.layout_model = layout_model
self.formula_enable = formula_enable
self.table_enable = table_enable
def get_compress_pdf_mid_data(self): def get_compress_pdf_mid_data(self):
return JsonCompressor.compress_json(self.pdf_mid_data) return JsonCompressor.compress_json(self.pdf_mid_data)
...@@ -95,9 +98,7 @@ class AbsPipe(ABC): ...@@ -95,9 +98,7 @@ class AbsPipe(ABC):
""" """
pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data) pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
pdf_info_list = pdf_mid_data["pdf_info"] pdf_info_list = pdf_mid_data["pdf_info"]
parse_type = pdf_mid_data["_parse_type"] content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_buket_path)
lang = pdf_mid_data.get("_lang", None)
content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_buket_path, parse_type, lang)
return content_list return content_list
@staticmethod @staticmethod
...@@ -107,9 +108,7 @@ class AbsPipe(ABC): ...@@ -107,9 +108,7 @@ class AbsPipe(ABC):
""" """
pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data) pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
pdf_info_list = pdf_mid_data["pdf_info"] pdf_info_list = pdf_mid_data["pdf_info"]
parse_type = pdf_mid_data["_parse_type"] md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_buket_path)
lang = pdf_mid_data.get("_lang", None)
md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_buket_path, parse_type, lang)
return md_content return md_content
...@@ -10,8 +10,10 @@ from magic_pdf.user_api import parse_ocr_pdf ...@@ -10,8 +10,10 @@ from magic_pdf.user_api import parse_ocr_pdf
class OCRPipe(AbsPipe): class OCRPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False, def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None): start_page_id=0, end_page_id=None, lang=None,
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang) layout_model=None, formula_enable=None, table_enable=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
layout_model, formula_enable, table_enable)
def pipe_classify(self): def pipe_classify(self):
pass pass
...@@ -19,12 +21,14 @@ class OCRPipe(AbsPipe): ...@@ -19,12 +21,14 @@ class OCRPipe(AbsPipe):
def pipe_analyze(self): def pipe_analyze(self):
self.model_list = doc_analyze(self.pdf_bytes, ocr=True, self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
start_page_id=self.start_page_id, end_page_id=self.end_page_id, start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang) lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_parse(self): def pipe_parse(self):
self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug, self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
start_page_id=self.start_page_id, end_page_id=self.end_page_id, start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang) lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF): def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode) result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
......
...@@ -11,8 +11,10 @@ from magic_pdf.user_api import parse_txt_pdf ...@@ -11,8 +11,10 @@ from magic_pdf.user_api import parse_txt_pdf
class TXTPipe(AbsPipe): class TXTPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False, def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None): start_page_id=0, end_page_id=None, lang=None,
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang) layout_model=None, formula_enable=None, table_enable=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
layout_model, formula_enable, table_enable)
def pipe_classify(self): def pipe_classify(self):
pass pass
...@@ -20,12 +22,14 @@ class TXTPipe(AbsPipe): ...@@ -20,12 +22,14 @@ class TXTPipe(AbsPipe):
def pipe_analyze(self): def pipe_analyze(self):
self.model_list = doc_analyze(self.pdf_bytes, ocr=False, self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id, start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang) lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_parse(self): def pipe_parse(self):
self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug, self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
start_page_id=self.start_page_id, end_page_id=self.end_page_id, start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang) lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF): def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode) result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
......
...@@ -14,9 +14,11 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf ...@@ -14,9 +14,11 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
class UNIPipe(AbsPipe): class UNIPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False, def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None): start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
self.pdf_type = jso_useful_key["_pdf_type"] self.pdf_type = jso_useful_key["_pdf_type"]
super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id, lang) super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id,
lang, layout_model, formula_enable, table_enable)
if len(self.model_list) == 0: if len(self.model_list) == 0:
self.input_model_is_empty = True self.input_model_is_empty = True
else: else:
...@@ -29,18 +31,21 @@ class UNIPipe(AbsPipe): ...@@ -29,18 +31,21 @@ class UNIPipe(AbsPipe):
if self.pdf_type == self.PIP_TXT: if self.pdf_type == self.PIP_TXT:
self.model_list = doc_analyze(self.pdf_bytes, ocr=False, self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id, start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang) lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
elif self.pdf_type == self.PIP_OCR: elif self.pdf_type == self.PIP_OCR:
self.model_list = doc_analyze(self.pdf_bytes, ocr=True, self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
start_page_id=self.start_page_id, end_page_id=self.end_page_id, start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang) lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_parse(self): def pipe_parse(self):
if self.pdf_type == self.PIP_TXT: if self.pdf_type == self.PIP_TXT:
self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer, self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer,
is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty, is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty,
start_page_id=self.start_page_id, end_page_id=self.end_page_id, start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang) lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
elif self.pdf_type == self.PIP_OCR: elif self.pdf_type == self.PIP_OCR:
self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
is_debug=self.is_debug, is_debug=self.is_debug,
......
from loguru import logger from loguru import logger
from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio, calculate_overlap_area_in_bbox1_area_ratio, \ from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio, calculate_overlap_area_in_bbox1_area_ratio, \
calculate_iou calculate_iou, calculate_vertical_projection_overlap_ratio
from magic_pdf.libs.drop_tag import DropTag from magic_pdf.libs.drop_tag import DropTag
from magic_pdf.libs.ocr_content_type import BlockType from magic_pdf.libs.ocr_content_type import BlockType
from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox_for_block from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox_for_block
...@@ -60,29 +60,34 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc ...@@ -60,29 +60,34 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
return all_bboxes, all_discarded_blocks, drop_reasons return all_bboxes, all_discarded_blocks, drop_reasons
def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_blocks, text_blocks, def add_bboxes(blocks, block_type, bboxes):
title_blocks, interline_equation_blocks, page_w, page_h): for block in blocks:
all_bboxes = [] x0, y0, x1, y1 = block['bbox']
all_discarded_blocks = [] if block_type in [
for image in img_blocks: BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
x0, y0, x1, y1 = image['bbox'] BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Image, None, None, None, None, image["score"]]) ]:
bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"], block["group_id"]])
for table in table_blocks: else:
x0, y0, x1, y1 = table['bbox'] bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"]])
all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Table, None, None, None, None, table["score"]])
for text in text_blocks:
x0, y0, x1, y1 = text['bbox']
all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Text, None, None, None, None, text["score"]])
for title in title_blocks: def ocr_prepare_bboxes_for_layout_split_v2(
x0, y0, x1, y1 = title['bbox'] img_body_blocks, img_caption_blocks, img_footnote_blocks,
all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Title, None, None, None, None, title["score"]]) table_body_blocks, table_caption_blocks, table_footnote_blocks,
discarded_blocks, text_blocks, title_blocks, interline_equation_blocks, page_w, page_h
):
all_bboxes = []
for interline_equation in interline_equation_blocks: add_bboxes(img_body_blocks, BlockType.ImageBody, all_bboxes)
x0, y0, x1, y1 = interline_equation['bbox'] add_bboxes(img_caption_blocks, BlockType.ImageCaption, all_bboxes)
all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.InterlineEquation, None, None, None, None, interline_equation["score"]]) add_bboxes(img_footnote_blocks, BlockType.ImageFootnote, all_bboxes)
add_bboxes(table_body_blocks, BlockType.TableBody, all_bboxes)
add_bboxes(table_caption_blocks, BlockType.TableCaption, all_bboxes)
add_bboxes(table_footnote_blocks, BlockType.TableFootnote, all_bboxes)
add_bboxes(text_blocks, BlockType.Text, all_bboxes)
add_bboxes(title_blocks, BlockType.Title, all_bboxes)
add_bboxes(interline_equation_blocks, BlockType.InterlineEquation, all_bboxes)
'''block嵌套问题解决''' '''block嵌套问题解决'''
'''文本框与标题框重叠,优先信任文本框''' '''文本框与标题框重叠,优先信任文本框'''
...@@ -96,13 +101,23 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b ...@@ -96,13 +101,23 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
'''interline_equation框被包含在文本类型框内,且interline_equation比文本区块小很多时信任文本框,这时需要舍弃公式框''' '''interline_equation框被包含在文本类型框内,且interline_equation比文本区块小很多时信任文本框,这时需要舍弃公式框'''
# 通过后续大框套小框逻辑删除 # 通过后续大框套小框逻辑删除
'''discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的(限定footnote)''' '''discarded_blocks'''
all_discarded_blocks = []
add_bboxes(discarded_blocks, BlockType.Discarded, all_discarded_blocks)
'''footnote识别:宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的'''
footnote_blocks = []
for discarded in discarded_blocks: for discarded in discarded_blocks:
x0, y0, x1, y1 = discarded['bbox'] x0, y0, x1, y1 = discarded['bbox']
all_discarded_blocks.append([x0, y0, x1, y1, None, None, None, BlockType.Discarded, None, None, None, None, discarded["score"]]) if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
# 将footnote加入到all_bboxes中,用来计算layout footnote_blocks.append([x0, y0, x1, y1])
# if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
# all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Footnote, None, None, None, None, discarded["score"]]) '''移除在footnote下面的任何框'''
need_remove_blocks = find_blocks_under_footnote(all_bboxes, footnote_blocks)
if len(need_remove_blocks) > 0:
for block in need_remove_blocks:
all_bboxes.remove(block)
all_discarded_blocks.append(block)
'''经过以上处理后,还存在大框套小框的情况,则删除小框''' '''经过以上处理后,还存在大框套小框的情况,则删除小框'''
all_bboxes = remove_overlaps_min_blocks(all_bboxes) all_bboxes = remove_overlaps_min_blocks(all_bboxes)
...@@ -113,6 +128,20 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b ...@@ -113,6 +128,20 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
return all_bboxes, all_discarded_blocks return all_bboxes, all_discarded_blocks
def find_blocks_under_footnote(all_bboxes, footnote_blocks):
need_remove_blocks = []
for block in all_bboxes:
block_x0, block_y0, block_x1, block_y1 = block[:4]
for footnote_bbox in footnote_blocks:
footnote_x0, footnote_y0, footnote_x1, footnote_y1 = footnote_bbox
# 如果footnote的纵向投影覆盖了block的纵向投影的80%且block的y0大于等于footnote的y1
if block_y0 >= footnote_y1 and calculate_vertical_projection_overlap_ratio((block_x0, block_y0, block_x1, block_y1), footnote_bbox) >= 0.8:
if block not in need_remove_blocks:
need_remove_blocks.append(block)
break
return need_remove_blocks
def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes): def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes):
# 先提取所有text和interline block # 先提取所有text和interline block
text_blocks = [] text_blocks = []
......
...@@ -49,7 +49,7 @@ def merge_spans_to_line(spans): ...@@ -49,7 +49,7 @@ def merge_spans_to_line(spans):
continue continue
# 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行 # 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], 0.6): if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], 0.5):
current_line.append(span) current_line.append(span)
else: else:
# 否则,开始新行 # 否则,开始新行
...@@ -153,6 +153,11 @@ def fill_spans_in_blocks(blocks, spans, radio): ...@@ -153,6 +153,11 @@ def fill_spans_in_blocks(blocks, spans, radio):
'type': block_type, 'type': block_type,
'bbox': block_bbox, 'bbox': block_bbox,
} }
if block_type in [
BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
]:
block_dict["group_id"] = block[-1]
block_spans = [] block_spans = []
for span in spans: for span in spans:
span_bbox = span['bbox'] span_bbox = span['bbox']
...@@ -201,6 +206,27 @@ def fix_block_spans(block_with_spans, img_blocks, table_blocks): ...@@ -201,6 +206,27 @@ def fix_block_spans(block_with_spans, img_blocks, table_blocks):
return fix_blocks return fix_blocks
def fix_block_spans_v2(block_with_spans):
"""1、img_block和table_block因为包含caption和footnote的关系,存在block的嵌套关系
需要将caption和footnote的text_span放入相应img_block和table_block内的
caption_block和footnote_block中 2、同时需要删除block中的spans字段."""
fix_blocks = []
for block in block_with_spans:
block_type = block['type']
if block_type in [BlockType.Text, BlockType.Title,
BlockType.ImageCaption, BlockType.ImageFootnote,
BlockType.TableCaption, BlockType.TableFootnote
]:
block = fix_text_block(block)
elif block_type in [BlockType.InterlineEquation, BlockType.ImageBody, BlockType.TableBody]:
block = fix_interline_block(block)
else:
continue
fix_blocks.append(block)
return fix_blocks
def fix_discarded_block(discarded_block_with_spans): def fix_discarded_block(discarded_block_with_spans):
fix_discarded_blocks = [] fix_discarded_blocks = []
for block in discarded_block_with_spans: for block in discarded_block_with_spans:
......
config:
device: cpu
layout: True
formula: True
table_config:
model: TableMaster
is_table_recog_enable: False
max_time: 400
weights: weights:
layout: Layout/model_final.pth layoutlmv3: Layout/LayoutLMv3/model_final.pth
mfd: MFD/weights.pt doclayout_yolo: Layout/YOLO/doclayout_yolo_ft.pt
mfr: MFR/unimernet_small yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
unimernet_small: MFR/unimernet_small
struct_eqtable: TabRec/StructEqTable struct_eqtable: TabRec/StructEqTable
TableMaster: TabRec/TableMaster tablemaster: TabRec/TableMaster
\ No newline at end of file \ No newline at end of file
...@@ -52,7 +52,7 @@ without method specified, auto will be used by default.""", ...@@ -52,7 +52,7 @@ without method specified, auto will be used by default.""",
help=""" help="""
Input the languages in the pdf (if known) to improve OCR accuracy. Optional. Input the languages in the pdf (if known) to improve OCR accuracy. Optional.
You should input "Abbreviation" with language form url: You should input "Abbreviation" with language form url:
https://paddlepaddle.github.io/PaddleOCR/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations https://paddlepaddle.github.io/PaddleOCR/latest/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
""", """,
default=None, default=None,
) )
......
...@@ -6,8 +6,8 @@ import click ...@@ -6,8 +6,8 @@ import click
from loguru import logger from loguru import logger
import magic_pdf.model as model_config import magic_pdf.model as model_config
from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_span_bbox, from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox,
draw_model_bbox, draw_line_sort_bbox) draw_model_bbox, draw_span_bbox)
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
from magic_pdf.pipe.OCRPipe import OCRPipe from magic_pdf.pipe.OCRPipe import OCRPipe
from magic_pdf.pipe.TXTPipe import TXTPipe from magic_pdf.pipe.TXTPipe import TXTPipe
...@@ -46,10 +46,12 @@ def do_parse( ...@@ -46,10 +46,12 @@ def do_parse(
start_page_id=0, start_page_id=0,
end_page_id=None, end_page_id=None,
lang=None, lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
): ):
if debug_able: if debug_able:
logger.warning('debug mode is on') logger.warning('debug mode is on')
# f_dump_content_list = True
f_draw_model_bbox = True f_draw_model_bbox = True
f_draw_line_sort_bbox = True f_draw_line_sort_bbox = True
...@@ -64,13 +66,16 @@ def do_parse( ...@@ -64,13 +66,16 @@ def do_parse(
if parse_method == 'auto': if parse_method == 'auto':
jso_useful_key = {'_pdf_type': '', 'model_list': model_list} jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True, pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang) start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
elif parse_method == 'txt': elif parse_method == 'txt':
pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True, pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang) start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
elif parse_method == 'ocr': elif parse_method == 'ocr':
pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True, pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang) start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
else: else:
logger.error('unknown parse method') logger.error('unknown parse method')
exit(1) exit(1)
......
...@@ -101,11 +101,19 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr ...@@ -101,11 +101,19 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False): if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False):
logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr") logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
if input_model_is_empty: if input_model_is_empty:
pdf_models = doc_analyze(pdf_bytes, layout_model = kwargs.get("layout_model", None)
ocr=True, formula_enable = kwargs.get("formula_enable", None)
start_page_id=start_page_id, table_enable = kwargs.get("table_enable", None)
end_page_id=end_page_id, pdf_models = doc_analyze(
lang=lang) pdf_bytes,
ocr=True,
start_page_id=start_page_id,
end_page_id=end_page_id,
lang=lang,
layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
pdf_info_dict = parse_pdf(parse_pdf_by_ocr) pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
if pdf_info_dict is None: if pdf_info_dict is None:
raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.") raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")
......
from loguru import logger
def ImportPIL(f):
try:
import PIL # noqa: F401
except ImportError:
logger.error('Pillow not installed, please install by pip.')
exit(1)
return f
Data Api
------------------
.. toctree::
:maxdepth: 2
api/dataset.rst
api/data_reader_writer.rst
api/read_api.rst
Data Reader Writer
--------------------
.. autoclass:: magic_pdf.data.data_reader_writer.DataReader
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.data_reader_writer.DataWriter
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.data_reader_writer.S3DataReader
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.data_reader_writer.S3DataWriter
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.data_reader_writer.FileBasedDataReader
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.data_reader_writer.FileBasedDataWriter
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.data_reader_writer.S3DataReader
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.data_reader_writer.S3DataWriter
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.data_reader_writer.MultiBucketS3DataReader
:members:
:inherited-members:
.. autoclass:: magic_pdf.data.data_reader_writer.MultiBucketS3DataWriter
:members:
:inherited-members:
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment