Unverified Commit 3a42ebbf authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #838 from opendatalab/release-0.9.0

Release 0.9.0
parents 765c6d77 14024793
from struct_eqtable.model import StructTable from loguru import logger
try:
from struct_eqtable.model import StructTable
except ImportError:
logger.error("StructEqTable is under upgrade, the current version does not support it.")
from pypandoc import convert_text from pypandoc import convert_text
class StructTableModel: class StructTableModel:
def __init__(self, model_path, max_new_tokens=2048, max_time=400, device = 'cpu'): def __init__(self, model_path, max_new_tokens=2048, max_time=400, device = 'cpu'):
# init # init
......
...@@ -52,11 +52,11 @@ class ppTableModel(object): ...@@ -52,11 +52,11 @@ class ppTableModel(object):
rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR) rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)
rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT) rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)
device = kwargs.get("device", "cpu") device = kwargs.get("device", "cpu")
use_gpu = True if device == "cuda" else False use_gpu = True if device.startswith("cuda") else False
config = { config = {
"use_gpu": use_gpu, "use_gpu": use_gpu,
"table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN), "table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
"table_algorithm": TABLE_MASTER, "table_algorithm": "TableMaster",
"table_model_dir": table_model_dir, "table_model_dir": table_model_dir,
"table_char_dict_path": table_char_dict_path, "table_char_dict_path": table_char_dict_path,
"det_model_dir": det_model_dir, "det_model_dir": det_model_dir,
......
...@@ -18,8 +18,11 @@ def region_to_bbox(region): ...@@ -18,8 +18,11 @@ def region_to_bbox(region):
class CustomPaddleModel: class CustomPaddleModel:
def __init__(self, ocr: bool = False, show_log: bool = False): def __init__(self, ocr: bool = False, show_log: bool = False, lang=None):
self.model = PPStructure(table=False, ocr=ocr, show_log=show_log) if lang is not None:
self.model = PPStructure(table=False, ocr=ocr, show_log=show_log, lang=lang)
else:
self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
def __call__(self, img): def __call__(self, img):
try: try:
......
from collections import defaultdict
from typing import List, Dict
import torch
from transformers import LayoutLMv3ForTokenClassification
MAX_LEN = 510
CLS_TOKEN_ID = 0
UNK_TOKEN_ID = 3
EOS_TOKEN_ID = 2
class DataCollator:
def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
bbox = []
labels = []
input_ids = []
attention_mask = []
# clip bbox and labels to max length, build input_ids and attention_mask
for feature in features:
_bbox = feature["source_boxes"]
if len(_bbox) > MAX_LEN:
_bbox = _bbox[:MAX_LEN]
_labels = feature["target_index"]
if len(_labels) > MAX_LEN:
_labels = _labels[:MAX_LEN]
_input_ids = [UNK_TOKEN_ID] * len(_bbox)
_attention_mask = [1] * len(_bbox)
assert len(_bbox) == len(_labels) == len(_input_ids) == len(_attention_mask)
bbox.append(_bbox)
labels.append(_labels)
input_ids.append(_input_ids)
attention_mask.append(_attention_mask)
# add CLS and EOS tokens
for i in range(len(bbox)):
bbox[i] = [[0, 0, 0, 0]] + bbox[i] + [[0, 0, 0, 0]]
labels[i] = [-100] + labels[i] + [-100]
input_ids[i] = [CLS_TOKEN_ID] + input_ids[i] + [EOS_TOKEN_ID]
attention_mask[i] = [1] + attention_mask[i] + [1]
# padding to max length
max_len = max(len(x) for x in bbox)
for i in range(len(bbox)):
bbox[i] = bbox[i] + [[0, 0, 0, 0]] * (max_len - len(bbox[i]))
labels[i] = labels[i] + [-100] * (max_len - len(labels[i]))
input_ids[i] = input_ids[i] + [EOS_TOKEN_ID] * (max_len - len(input_ids[i]))
attention_mask[i] = attention_mask[i] + [0] * (
max_len - len(attention_mask[i])
)
ret = {
"bbox": torch.tensor(bbox),
"attention_mask": torch.tensor(attention_mask),
"labels": torch.tensor(labels),
"input_ids": torch.tensor(input_ids),
}
# set label > MAX_LEN to -100, because original labels may be > MAX_LEN
ret["labels"][ret["labels"] > MAX_LEN] = -100
# set label > 0 to label-1, because original labels are 1-indexed
ret["labels"][ret["labels"] > 0] -= 1
return ret
def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
attention_mask = [1] + [1] * len(boxes) + [1]
return {
"bbox": torch.tensor([bbox]),
"attention_mask": torch.tensor([attention_mask]),
"input_ids": torch.tensor([input_ids]),
}
def prepare_inputs(
inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
) -> Dict[str, torch.Tensor]:
ret = {}
for k, v in inputs.items():
v = v.to(model.device)
if torch.is_floating_point(v):
v = v.to(model.dtype)
ret[k] = v
return ret
def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
"""
parse logits to orders
:param logits: logits from model
:param length: input length
:return: orders
"""
logits = logits[1 : length + 1, :length]
orders = logits.argsort(descending=False).tolist()
ret = [o.pop() for o in orders]
while True:
order_to_idxes = defaultdict(list)
for idx, order in enumerate(ret):
order_to_idxes[order].append(idx)
# filter idxes len > 1
order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
if not order_to_idxes:
break
# filter
for order, idxes in order_to_idxes.items():
# find original logits of idxes
idxes_to_logit = {}
for idx in idxes:
idxes_to_logit[idx] = logits[idx, order]
idxes_to_logit = sorted(
idxes_to_logit.items(), key=lambda x: x[1], reverse=True
)
# keep the highest logit as order, set others to next candidate
for idx, _ in idxes_to_logit[1:]:
ret[idx] = orders[idx].pop()
return ret
def check_duplicate(a: List[int]) -> bool:
return len(a) != len(set(a))
import copy
from loguru import logger
from magic_pdf.libs.Constants import LINES_DELETED, CROSS_PAGE
from magic_pdf.libs.ocr_content_type import BlockType, ContentType
LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';')
LIST_END_FLAG = ('.', '。', ';', ';')
class ListLineTag:
IS_LIST_START_LINE = "is_list_start_line"
IS_LIST_END_LINE = "is_list_end_line"
def __process_blocks(blocks):
# 对所有block预处理
# 1.通过title和interline_equation将block分组
# 2.bbox边界根据line信息重置
result = []
current_group = []
for i in range(len(blocks)):
current_block = blocks[i]
# 如果当前块是 text 类型
if current_block['type'] == 'text':
current_block["bbox_fs"] = copy.deepcopy(current_block["bbox"])
if 'lines' in current_block and len(current_block["lines"]) > 0:
current_block['bbox_fs'] = [min([line['bbox'][0] for line in current_block['lines']]),
min([line['bbox'][1] for line in current_block['lines']]),
max([line['bbox'][2] for line in current_block['lines']]),
max([line['bbox'][3] for line in current_block['lines']])]
current_group.append(current_block)
# 检查下一个块是否存在
if i + 1 < len(blocks):
next_block = blocks[i + 1]
# 如果下一个块不是 text 类型且是 title 或 interline_equation 类型
if next_block['type'] in ['title', 'interline_equation']:
result.append(current_group)
current_group = []
# 处理最后一个 group
if current_group:
result.append(current_group)
return result
def __is_list_or_index_block(block):
# 一个block如果是list block 应该同时满足以下特征
# 1.block内有多个line 2.block 内有多个line左侧顶格写 3.block内有多个line 右侧不顶格(狗牙状)
# 1.block内有多个line 2.block 内有多个line左侧顶格写 3.多个line以endflag结尾
# 1.block内有多个line 2.block 内有多个line左侧顶格写 3.block内有多个line 左侧不顶格
# index block 是一种特殊的list block
# 一个block如果是index block 应该同时满足以下特征
# 1.block内有多个line 2.block 内有多个line两侧均顶格写 3.line的开头或者结尾均为数字
if len(block['lines']) >= 2:
first_line = block['lines'][0]
line_height = first_line['bbox'][3] - first_line['bbox'][1]
block_weight = block['bbox_fs'][2] - block['bbox_fs'][0]
left_close_num = 0
left_not_close_num = 0
right_not_close_num = 0
right_close_num = 0
lines_text_list = []
multiple_para_flag = False
last_line = block['lines'][-1]
# 如果首行左边不顶格而右边顶格,末行左边顶格而右边不顶格 (第一行可能可以右边不顶格)
if (first_line['bbox'][0] - block['bbox_fs'][0] > line_height / 2 and
# block['bbox_fs'][2] - first_line['bbox'][2] < line_height and
abs(last_line['bbox'][0] - block['bbox_fs'][0]) < line_height / 2 and
block['bbox_fs'][2] - last_line['bbox'][2] > line_height
):
multiple_para_flag = True
for line in block['lines']:
line_text = ""
for span in line['spans']:
span_type = span['type']
if span_type == ContentType.Text:
line_text += span['content'].strip()
lines_text_list.append(line_text)
# 计算line左侧顶格数量是否大于2,是否顶格用abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height/2 来判断
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
left_close_num += 1
elif line['bbox'][0] - block['bbox_fs'][0] > line_height:
# logger.info(f"{line_text}, {block['bbox_fs']}, {line['bbox']}")
left_not_close_num += 1
# 计算右侧是否顶格
if abs(block['bbox_fs'][2] - line['bbox'][2]) < line_height:
right_close_num += 1
else:
# 右侧不顶格情况下是否有一段距离,拍脑袋用0.3block宽度做阈值
closed_area = 0.3 * block_weight
# closed_area = 5 * line_height
if block['bbox_fs'][2] - line['bbox'][2] > closed_area:
right_not_close_num += 1
# 判断lines_text_list中的元素是否有超过80%都以LIST_END_FLAG结尾
line_end_flag = False
# 判断lines_text_list中的元素是否有超过80%都以数字开头或都以数字结尾
line_num_flag = False
num_start_count = 0
num_end_count = 0
flag_end_count = 0
if len(lines_text_list) > 0:
for line_text in lines_text_list:
if len(line_text) > 0:
if line_text[-1] in LIST_END_FLAG:
flag_end_count += 1
if line_text[0].isdigit():
num_start_count += 1
if line_text[-1].isdigit():
num_end_count += 1
if flag_end_count / len(lines_text_list) >= 0.8:
line_end_flag = True
if num_start_count / len(lines_text_list) >= 0.8 or num_end_count / len(lines_text_list) >= 0.8:
line_num_flag = True
# 有的目录右侧不贴边, 目前认为左边或者右边有一边全贴边,且符合数字规则极为index
if ((left_close_num/len(block['lines']) >= 0.8 or right_close_num/len(block['lines']) >= 0.8)
and line_num_flag
):
for line in block['lines']:
line[ListLineTag.IS_LIST_START_LINE] = True
return BlockType.Index
elif left_close_num >= 2 and (
right_not_close_num >= 2 or line_end_flag or left_not_close_num >= 2) and not multiple_para_flag:
# 处理一种特殊的没有缩进的list,所有行都贴左边,通过右边的空隙判断是否是item尾
if left_close_num / len(block['lines']) > 0.9:
# 这种是每个item只有一行,且左边都贴边的短item list
if flag_end_count == 0 and right_close_num / len(block['lines']) < 0.5:
for line in block['lines']:
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
line[ListLineTag.IS_LIST_START_LINE] = True
# 这种是大部分line item 都有结束标识符的情况,按结束标识符区分不同item
elif line_end_flag:
for i, line in enumerate(block['lines']):
if lines_text_list[i][-1] in LIST_END_FLAG:
line[ListLineTag.IS_LIST_END_LINE] = True
if i + 1 < len(block['lines']):
block['lines'][i+1][ListLineTag.IS_LIST_START_LINE] = True
# line item基本没有结束标识符,而且也没有缩进,按右侧空隙判断哪些是item end
else:
line_start_flag = False
for i, line in enumerate(block['lines']):
if line_start_flag:
line[ListLineTag.IS_LIST_START_LINE] = True
line_start_flag = False
elif abs(block['bbox_fs'][2] - line['bbox'][2]) > line_height:
line[ListLineTag.IS_LIST_END_LINE] = True
line_start_flag = True
# 一种有缩进的特殊有序list,start line 左侧不贴边且以数字开头,end line 以 IS_LIST_END_LINE 结尾且数量和start line 一致
elif num_start_count >= 2 and num_start_count == flag_end_count: # 简单一点先不考虑左侧不贴边的情况
for i, line in enumerate(block['lines']):
if lines_text_list[i][0].isdigit():
line[ListLineTag.IS_LIST_START_LINE] = True
if lines_text_list[i][-1] in LIST_END_FLAG:
line[ListLineTag.IS_LIST_END_LINE] = True
else:
# 正常有缩进的list处理
for line in block['lines']:
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
line[ListLineTag.IS_LIST_START_LINE] = True
if abs(block['bbox_fs'][2] - line['bbox'][2]) > line_height:
line[ListLineTag.IS_LIST_END_LINE] = True
return BlockType.List
else:
return BlockType.Text
else:
return BlockType.Text
def __merge_2_text_blocks(block1, block2):
if len(block1['lines']) > 0:
first_line = block1['lines'][0]
line_height = first_line['bbox'][3] - first_line['bbox'][1]
block1_weight = block1['bbox'][2] - block1['bbox'][0]
block2_weight = block2['bbox'][2] - block2['bbox'][0]
min_block_weight = min(block1_weight, block2_weight)
if abs(block1['bbox_fs'][0] - first_line['bbox'][0]) < line_height / 2:
last_line = block2['lines'][-1]
if len(last_line['spans']) > 0:
last_span = last_line['spans'][-1]
line_height = last_line['bbox'][3] - last_line['bbox'][1]
if (abs(block2['bbox_fs'][2] - last_line['bbox'][2]) < line_height and
not last_span['content'].endswith(LINE_STOP_FLAG) and
# 两个block宽度差距超过2倍也不合并
abs(block1_weight - block2_weight) < min_block_weight
):
if block1['page_num'] != block2['page_num']:
for line in block1['lines']:
for span in line['spans']:
span[CROSS_PAGE] = True
block2['lines'].extend(block1['lines'])
block1['lines'] = []
block1[LINES_DELETED] = True
return block1, block2
def __merge_2_list_blocks(block1, block2):
if block1['page_num'] != block2['page_num']:
for line in block1['lines']:
for span in line['spans']:
span[CROSS_PAGE] = True
block2['lines'].extend(block1['lines'])
block1['lines'] = []
block1[LINES_DELETED] = True
return block1, block2
def __is_list_group(text_blocks_group):
# list group的特征是一个group内的所有block都满足以下条件
# 1.每个block都不超过3行 2. 每个block 的左边界都比较接近(逻辑简单点先不加这个规则)
for block in text_blocks_group:
if len(block['lines']) > 3:
return False
return True
def __para_merge_page(blocks):
page_text_blocks_groups = __process_blocks(blocks)
for text_blocks_group in page_text_blocks_groups:
if len(text_blocks_group) > 0:
# 需要先在合并前对所有block判断是否为list or index block
for block in text_blocks_group:
block_type = __is_list_or_index_block(block)
block['type'] = block_type
# logger.info(f"{block['type']}:{block}")
if len(text_blocks_group) > 1:
# 在合并前判断这个group 是否是一个 list group
is_list_group = __is_list_group(text_blocks_group)
# 倒序遍历
for i in range(len(text_blocks_group) - 1, -1, -1):
current_block = text_blocks_group[i]
# 检查是否有前一个块
if i - 1 >= 0:
prev_block = text_blocks_group[i - 1]
if current_block['type'] == 'text' and prev_block['type'] == 'text' and not is_list_group:
__merge_2_text_blocks(current_block, prev_block)
elif (
(current_block['type'] == BlockType.List and prev_block['type'] == BlockType.List) or
(current_block['type'] == BlockType.Index and prev_block['type'] == BlockType.Index)
):
__merge_2_list_blocks(current_block, prev_block)
else:
continue
def para_split(pdf_info_dict, debug_mode=False):
all_blocks = []
for page_num, page in pdf_info_dict.items():
blocks = copy.deepcopy(page['preproc_blocks'])
for block in blocks:
block['page_num'] = page_num
all_blocks.extend(blocks)
__para_merge_page(all_blocks)
for page_num, page in pdf_info_dict.items():
page['para_blocks'] = []
for block in all_blocks:
if block['page_num'] == page_num:
page['para_blocks'].append(block)
if __name__ == '__main__':
input_blocks = []
# 调用函数
groups = __process_blocks(input_blocks)
for group_index, group in enumerate(groups):
print(f"Group {group_index}: {group}")
from magic_pdf.pdf_parse_union_core import pdf_parse_union from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
def parse_pdf_by_ocr(pdf_bytes, def parse_pdf_by_ocr(pdf_bytes,
...@@ -8,10 +10,11 @@ def parse_pdf_by_ocr(pdf_bytes, ...@@ -8,10 +10,11 @@ def parse_pdf_by_ocr(pdf_bytes,
end_page_id=None, end_page_id=None,
debug_mode=False, debug_mode=False,
): ):
return pdf_parse_union(pdf_bytes, dataset = PymuDocDataset(pdf_bytes)
return pdf_parse_union(dataset,
model_list, model_list,
imageWriter, imageWriter,
"ocr", SupportedPdfParseMethod.OCR,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id, end_page_id=end_page_id,
debug_mode=debug_mode, debug_mode=debug_mode,
......
from magic_pdf.pdf_parse_union_core import pdf_parse_union from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
def parse_pdf_by_txt( def parse_pdf_by_txt(
...@@ -9,10 +11,11 @@ def parse_pdf_by_txt( ...@@ -9,10 +11,11 @@ def parse_pdf_by_txt(
end_page_id=None, end_page_id=None,
debug_mode=False, debug_mode=False,
): ):
return pdf_parse_union(pdf_bytes, dataset = PymuDocDataset(pdf_bytes)
return pdf_parse_union(dataset,
model_list, model_list,
imageWriter, imageWriter,
"txt", SupportedPdfParseMethod.TXT,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id, end_page_id=end_page_id,
debug_mode=debug_mode, debug_mode=debug_mode,
......
import copy
import os
import statistics
import time
from typing import List
import torch
from loguru import logger
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import Dataset, PageableData
from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.commons import fitz, get_delta_time
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.local_math import float_equal
from magic_pdf.libs.ocr_content_type import ContentType, BlockType
from magic_pdf.model.magic_model import MagicModel
from magic_pdf.para.para_split_v3 import para_split
from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
from magic_pdf.pre_proc.construct_page_dict import \
ocr_construct_page_component_v2
from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
from magic_pdf.pre_proc.equations_replace import (
combine_chars_to_pymudict, remove_chars_in_text_blocks,
replace_equations_in_textblock)
from magic_pdf.pre_proc.ocr_detect_all_bboxes import \
ocr_prepare_bboxes_for_layout_split_v2
from magic_pdf.pre_proc.ocr_dict_merge import (fill_spans_in_blocks,
fix_block_spans,
fix_discarded_block, fix_block_spans_v2)
from magic_pdf.pre_proc.ocr_span_list_modify import (
get_qa_need_list_v2, remove_overlaps_low_confidence_spans,
remove_overlaps_min_spans)
from magic_pdf.pre_proc.resolve_bbox_conflict import \
check_useful_block_horizontal_overlap
def remove_horizontal_overlap_block_which_smaller(all_bboxes):
useful_blocks = []
for bbox in all_bboxes:
useful_blocks.append({'bbox': bbox[:4]})
is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = (
check_useful_block_horizontal_overlap(useful_blocks)
)
if is_useful_block_horz_overlap:
logger.warning(
f'skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}'
) # noqa: E501
for bbox in all_bboxes.copy():
if smaller_bbox == bbox[:4]:
all_bboxes.remove(bbox)
return is_useful_block_horz_overlap, all_bboxes
def __replace_STX_ETX(text_str: str):
"""Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
Drawback: This issue is only observed in English text; it has not been found in Chinese text so far.
Args:
text_str (str): raw text
Returns:
_type_: replaced text
""" # noqa: E501
if text_str:
s = text_str.replace('\u0002', "'")
s = s.replace('\u0003', "'")
return s
return text_str
def txt_spans_extract(pdf_page, inline_equations, interline_equations):
text_raw_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
char_level_text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)[
'blocks'
]
text_blocks = combine_chars_to_pymudict(text_raw_blocks, char_level_text_blocks)
text_blocks = replace_equations_in_textblock(
text_blocks, inline_equations, interline_equations
)
text_blocks = remove_citation_marker(text_blocks)
text_blocks = remove_chars_in_text_blocks(text_blocks)
spans = []
for v in text_blocks:
for line in v['lines']:
for span in line['spans']:
bbox = span['bbox']
if float_equal(bbox[0], bbox[2]) or float_equal(bbox[1], bbox[3]):
continue
if span.get('type') not in (
ContentType.InlineEquation,
ContentType.InterlineEquation,
):
spans.append(
{
'bbox': list(span['bbox']),
'content': __replace_STX_ETX(span['text']),
'type': ContentType.Text,
'score': 1.0,
}
)
return spans
def replace_text_span(pymu_spans, ocr_spans):
return list(filter(lambda x: x['type'] != ContentType.Text, ocr_spans)) + pymu_spans
def model_init(model_name: str):
from transformers import LayoutLMv3ForTokenClassification
if torch.cuda.is_available():
device = torch.device('cuda')
if torch.cuda.is_bf16_supported():
supports_bfloat16 = True
else:
supports_bfloat16 = False
else:
device = torch.device('cpu')
supports_bfloat16 = False
if model_name == 'layoutreader':
# 检测modelscope的缓存目录是否存在
layoutreader_model_dir = get_local_layoutreader_model_dir()
if os.path.exists(layoutreader_model_dir):
model = LayoutLMv3ForTokenClassification.from_pretrained(
layoutreader_model_dir
)
else:
logger.warning(
'local layoutreader model not exists, use online model from huggingface'
)
model = LayoutLMv3ForTokenClassification.from_pretrained(
'hantian/layoutreader'
)
# 检查设备是否支持 bfloat16
if supports_bfloat16:
model.bfloat16()
model.to(device).eval()
else:
logger.error('model name not allow')
exit(1)
return model
class ModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(self, model_name: str):
if model_name not in self._models:
self._models[model_name] = model_init(model_name=model_name)
return self._models[model_name]
def do_predict(boxes: List[List[int]], model) -> List[int]:
from magic_pdf.model.v3.helpers import (boxes2inputs, parse_logits,
prepare_inputs)
inputs = boxes2inputs(boxes)
inputs = prepare_inputs(inputs, model)
logits = model(**inputs).logits.cpu().squeeze(0)
return parse_logits(logits, len(boxes))
def cal_block_index(fix_blocks, sorted_bboxes):
for block in fix_blocks:
line_index_list = []
if len(block['lines']) == 0:
block['index'] = sorted_bboxes.index(block['bbox'])
else:
for line in block['lines']:
line['index'] = sorted_bboxes.index(line['bbox'])
line_index_list.append(line['index'])
median_value = statistics.median(line_index_list)
block['index'] = median_value
# 删除图表body block中的虚拟line信息, 并用real_lines信息回填
if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
block['virtual_lines'] = copy.deepcopy(block['lines'])
block['lines'] = copy.deepcopy(block['real_lines'])
del block['real_lines']
return fix_blocks
def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
# block_bbox是一个元组(x0, y0, x1, y1),其中(x0, y0)是左下角坐标,(x1, y1)是右上角坐标
x0, y0, x1, y1 = block_bbox
block_height = y1 - y0
block_weight = x1 - x0
# 如果block高度小于n行正文,则直接返回block的bbox
if line_height * 3 < block_height:
if (
block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
): # 可能是双列结构,可以切细点
lines = int(block_height / line_height) + 1
else:
# 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细)
if block_weight > page_w * 0.4:
line_height = (y1 - y0) / 3
lines = 3
elif block_weight > page_w * 0.25: # (可能是三列结构,也切细点)
lines = int(block_height / line_height) + 1
else: # 判断长宽比
if block_height / block_weight > 1.2: # 细长的不分
return [[x0, y0, x1, y1]]
else: # 不细长的还是分成两行
line_height = (y1 - y0) / 2
lines = 2
# 确定从哪个y位置开始绘制线条
current_y = y0
# 用于存储线条的位置信息[(x0, y), ...]
lines_positions = []
for i in range(lines):
lines_positions.append([x0, current_y, x1, current_y + line_height])
current_y += line_height
return lines_positions
else:
return [[x0, y0, x1, y1]]
def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
page_line_list = []
for block in fix_blocks:
if block['type'] in [
BlockType.Text, BlockType.Title, BlockType.InterlineEquation,
BlockType.ImageCaption, BlockType.ImageFootnote,
BlockType.TableCaption, BlockType.TableFootnote
]:
if len(block['lines']) == 0:
bbox = block['bbox']
lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
for line in lines:
block['lines'].append({'bbox': line, 'spans': []})
page_line_list.extend(lines)
else:
for line in block['lines']:
bbox = line['bbox']
page_line_list.append(bbox)
elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
bbox = block['bbox']
block["real_lines"] = copy.deepcopy(block['lines'])
lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
block['lines'] = []
for line in lines:
block['lines'].append({'bbox': line, 'spans': []})
page_line_list.extend(lines)
# 使用layoutreader排序
x_scale = 1000.0 / page_w
y_scale = 1000.0 / page_h
boxes = []
# logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_line_list)}")
for left, top, right, bottom in page_line_list:
if left < 0:
logger.warning(
f'left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
left = 0
if right > page_w:
logger.warning(
f'right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
right = page_w
if top < 0:
logger.warning(
f'top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
top = 0
if bottom > page_h:
logger.warning(
f'bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}'
) # noqa: E501
bottom = page_h
left = round(left * x_scale)
top = round(top * y_scale)
right = round(right * x_scale)
bottom = round(bottom * y_scale)
assert (
1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
), f'Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}' # noqa: E126, E121
boxes.append([left, top, right, bottom])
model_manager = ModelSingleton()
model = model_manager.get_model('layoutreader')
with torch.no_grad():
orders = do_predict(boxes, model)
sorted_bboxes = [page_line_list[i] for i in orders]
return sorted_bboxes
def get_line_height(blocks):
page_line_height_list = []
for block in blocks:
if block['type'] in [
BlockType.Text, BlockType.Title,
BlockType.ImageCaption, BlockType.ImageFootnote,
BlockType.TableCaption, BlockType.TableFootnote
]:
for line in block['lines']:
bbox = line['bbox']
page_line_height_list.append(int(bbox[3] - bbox[1]))
if len(page_line_height_list) > 0:
return statistics.median(page_line_height_list)
else:
return 10
def process_groups(groups, body_key, caption_key, footnote_key):
body_blocks = []
caption_blocks = []
footnote_blocks = []
for i, group in enumerate(groups):
group[body_key]['group_id'] = i
body_blocks.append(group[body_key])
for caption_block in group[caption_key]:
caption_block['group_id'] = i
caption_blocks.append(caption_block)
for footnote_block in group[footnote_key]:
footnote_block['group_id'] = i
footnote_blocks.append(footnote_block)
return body_blocks, caption_blocks, footnote_blocks
def process_block_list(blocks, body_type, block_type):
indices = [block['index'] for block in blocks]
median_index = statistics.median(indices)
body_bbox = next((block['bbox'] for block in blocks if block.get('type') == body_type), [])
return {
'type': block_type,
'bbox': body_bbox,
'blocks': blocks,
'index': median_index,
}
def revert_group_blocks(blocks):
image_groups = {}
table_groups = {}
new_blocks = []
for block in blocks:
if block['type'] in [BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote]:
group_id = block['group_id']
if group_id not in image_groups:
image_groups[group_id] = []
image_groups[group_id].append(block)
elif block['type'] in [BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote]:
group_id = block['group_id']
if group_id not in table_groups:
table_groups[group_id] = []
table_groups[group_id].append(block)
else:
new_blocks.append(block)
for group_id, blocks in image_groups.items():
new_blocks.append(process_block_list(blocks, BlockType.ImageBody, BlockType.Image))
for group_id, blocks in table_groups.items():
new_blocks.append(process_block_list(blocks, BlockType.TableBody, BlockType.Table))
return new_blocks
def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
def get_block_bboxes(blocks, block_type_list):
return [block[0:4] for block in blocks if block[7] in block_type_list]
image_bboxes = get_block_bboxes(all_bboxes, [BlockType.ImageBody])
table_bboxes = get_block_bboxes(all_bboxes, [BlockType.TableBody])
other_block_type = []
for block_type in BlockType.__dict__.values():
if not isinstance(block_type, str):
continue
if block_type not in [BlockType.ImageBody, BlockType.TableBody]:
other_block_type.append(block_type)
other_block_bboxes = get_block_bboxes(all_bboxes, other_block_type)
discarded_block_bboxes = get_block_bboxes(all_discarded_blocks, [BlockType.Discarded])
new_spans = []
for span in spans:
span_bbox = span['bbox']
span_type = span['type']
if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.4 for block_bbox in
discarded_block_bboxes):
new_spans.append(span)
continue
if span_type == ContentType.Image:
if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
image_bboxes):
new_spans.append(span)
elif span_type == ContentType.Table:
if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
table_bboxes):
new_spans.append(span)
else:
if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
other_block_bboxes):
new_spans.append(span)
return new_spans
def parse_page_core(
page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
):
need_drop = False
drop_reason = []
"""从magic_model对象中获取后面会用到的区块信息"""
# img_blocks = magic_model.get_imgs(page_id)
# table_blocks = magic_model.get_tables(page_id)
img_groups = magic_model.get_imgs_v2(page_id)
table_groups = magic_model.get_tables_v2(page_id)
img_body_blocks, img_caption_blocks, img_footnote_blocks = process_groups(
img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
)
table_body_blocks, table_caption_blocks, table_footnote_blocks = process_groups(
table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
)
discarded_blocks = magic_model.get_discarded(page_id)
text_blocks = magic_model.get_text_blocks(page_id)
title_blocks = magic_model.get_title_blocks(page_id)
inline_equations, interline_equations, interline_equation_blocks = (
magic_model.get_equations(page_id)
)
page_w, page_h = magic_model.get_page_size(page_id)
"""将所有区块的bbox整理到一起"""
# interline_equation_blocks参数不够准,后面切换到interline_equations上
interline_equation_blocks = []
if len(interline_equation_blocks) > 0:
all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
img_body_blocks, img_caption_blocks, img_footnote_blocks,
table_body_blocks, table_caption_blocks, table_footnote_blocks,
discarded_blocks,
text_blocks,
title_blocks,
interline_equation_blocks,
page_w,
page_h,
)
else:
all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
img_body_blocks, img_caption_blocks, img_footnote_blocks,
table_body_blocks, table_caption_blocks, table_footnote_blocks,
discarded_blocks,
text_blocks,
title_blocks,
interline_equations,
page_w,
page_h,
)
spans = magic_model.get_all_spans(page_id)
"""根据parse_mode,构造spans"""
if parse_mode == SupportedPdfParseMethod.TXT:
"""ocr 中文本类的 span 用 pymu spans 替换!"""
pymu_spans = txt_spans_extract(page_doc, inline_equations, interline_equations)
spans = replace_text_span(pymu_spans, spans)
elif parse_mode == SupportedPdfParseMethod.OCR:
pass
else:
raise Exception('parse_mode must be txt or ocr')
"""在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span"""
"""顺便删除大水印并保留abandon的span"""
spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks)
"""删除重叠spans中置信度较低的那些"""
spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
"""删除重叠spans中较小的那些"""
spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
"""对image和table截图"""
spans = ocr_cut_image_and_table(
spans, page_doc, page_id, pdf_bytes_md5, imageWriter
)
"""先处理不需要排版的discarded_blocks"""
discarded_block_with_spans, spans = fill_spans_in_blocks(
all_discarded_blocks, spans, 0.4
)
fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
"""如果当前页面没有bbox则跳过"""
if len(all_bboxes) == 0:
logger.warning(f'skip this page, not found useful bbox, page_id: {page_id}')
return ocr_construct_page_component_v2(
[],
[],
page_id,
page_w,
page_h,
[],
[],
[],
interline_equations,
fix_discarded_blocks,
need_drop,
drop_reason,
)
"""将span填入blocks中"""
block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
"""对block进行fix操作"""
fix_blocks = fix_block_spans_v2(block_with_spans)
"""获取所有line并计算正文line的高度"""
line_height = get_line_height(fix_blocks)
"""获取所有line并对line排序"""
sorted_bboxes = sort_lines_by_model(fix_blocks, page_w, page_h, line_height)
"""根据line的中位数算block的序列关系"""
fix_blocks = cal_block_index(fix_blocks, sorted_bboxes)
"""将image和table的block还原回group形式参与后续流程"""
fix_blocks = revert_group_blocks(fix_blocks)
"""重排block"""
sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])
"""获取QA需要外置的list"""
images, tables, interline_equations = get_qa_need_list_v2(sorted_blocks)
"""构造pdf_info_dict"""
page_info = ocr_construct_page_component_v2(
sorted_blocks,
[],
page_id,
page_w,
page_h,
[],
images,
tables,
interline_equations,
fix_discarded_blocks,
need_drop,
drop_reason,
)
return page_info
def pdf_parse_union(
dataset: Dataset,
model_list,
imageWriter,
parse_mode,
start_page_id=0,
end_page_id=None,
debug_mode=False,
):
pdf_bytes_md5 = compute_md5(dataset.data_bits())
"""初始化空的pdf_info_dict"""
pdf_info_dict = {}
"""用model_list和docs对象初始化magic_model"""
magic_model = MagicModel(model_list, dataset)
"""根据输入的起始范围解析pdf"""
# end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else len(dataset) - 1
)
if end_page_id > len(dataset) - 1:
logger.warning('end_page_id is out of range, use pdf_docs length')
end_page_id = len(dataset) - 1
"""初始化启动时间"""
start_time = time.time()
for page_id, page in enumerate(dataset):
"""debug时输出每页解析的耗时."""
if debug_mode:
time_now = time.time()
logger.info(
f'page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}'
)
start_time = time_now
"""解析pdf中的每一页"""
if start_page_id <= page_id <= end_page_id:
page_info = parse_page_core(
page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
)
else:
page_info = page.get_page_info()
page_w = page_info.w
page_h = page_info.h
page_info = ocr_construct_page_component_v2(
[], [], page_id, page_w, page_h, [], [], [], [], [], True, 'skip page'
)
pdf_info_dict[f'page_{page_id}'] = page_info
"""分段"""
para_split(pdf_info_dict, debug_mode=debug_mode)
"""dict转list"""
pdf_info_list = dict_to_list(pdf_info_dict)
new_pdf_info_dict = {
'pdf_info': pdf_info_list,
}
clean_memory()
return new_pdf_info_dict
if __name__ == '__main__':
pass
...@@ -17,7 +17,7 @@ class AbsPipe(ABC): ...@@ -17,7 +17,7 @@ class AbsPipe(ABC):
PIP_TXT = "txt" PIP_TXT = "txt"
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False, def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None): start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
self.pdf_bytes = pdf_bytes self.pdf_bytes = pdf_bytes
self.model_list = model_list self.model_list = model_list
self.image_writer = image_writer self.image_writer = image_writer
...@@ -25,6 +25,10 @@ class AbsPipe(ABC): ...@@ -25,6 +25,10 @@ class AbsPipe(ABC):
self.is_debug = is_debug self.is_debug = is_debug
self.start_page_id = start_page_id self.start_page_id = start_page_id
self.end_page_id = end_page_id self.end_page_id = end_page_id
self.lang = lang
self.layout_model = layout_model
self.formula_enable = formula_enable
self.table_enable = table_enable
def get_compress_pdf_mid_data(self): def get_compress_pdf_mid_data(self):
return JsonCompressor.compress_json(self.pdf_mid_data) return JsonCompressor.compress_json(self.pdf_mid_data)
......
...@@ -10,19 +10,25 @@ from magic_pdf.user_api import parse_ocr_pdf ...@@ -10,19 +10,25 @@ from magic_pdf.user_api import parse_ocr_pdf
class OCRPipe(AbsPipe): class OCRPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False, def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None): start_page_id=0, end_page_id=None, lang=None,
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id) layout_model=None, formula_enable=None, table_enable=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
layout_model, formula_enable, table_enable)
def pipe_classify(self): def pipe_classify(self):
pass pass
def pipe_analyze(self): def pipe_analyze(self):
self.model_list = doc_analyze(self.pdf_bytes, ocr=True, self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
start_page_id=self.start_page_id, end_page_id=self.end_page_id) start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_parse(self): def pipe_parse(self):
self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug, self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
start_page_id=self.start_page_id, end_page_id=self.end_page_id) start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF): def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode) result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
......
...@@ -11,19 +11,25 @@ from magic_pdf.user_api import parse_txt_pdf ...@@ -11,19 +11,25 @@ from magic_pdf.user_api import parse_txt_pdf
class TXTPipe(AbsPipe): class TXTPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False, def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None): start_page_id=0, end_page_id=None, lang=None,
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id) layout_model=None, formula_enable=None, table_enable=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
layout_model, formula_enable, table_enable)
def pipe_classify(self): def pipe_classify(self):
pass pass
def pipe_analyze(self): def pipe_analyze(self):
self.model_list = doc_analyze(self.pdf_bytes, ocr=False, self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id) start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_parse(self): def pipe_parse(self):
self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug, self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
start_page_id=self.start_page_id, end_page_id=self.end_page_id) start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF): def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode) result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
......
...@@ -14,9 +14,11 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf ...@@ -14,9 +14,11 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
class UNIPipe(AbsPipe): class UNIPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False, def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None): start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
self.pdf_type = jso_useful_key["_pdf_type"] self.pdf_type = jso_useful_key["_pdf_type"]
super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id) super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id,
lang, layout_model, formula_enable, table_enable)
if len(self.model_list) == 0: if len(self.model_list) == 0:
self.input_model_is_empty = True self.input_model_is_empty = True
else: else:
...@@ -28,22 +30,29 @@ class UNIPipe(AbsPipe): ...@@ -28,22 +30,29 @@ class UNIPipe(AbsPipe):
def pipe_analyze(self): def pipe_analyze(self):
if self.pdf_type == self.PIP_TXT: if self.pdf_type == self.PIP_TXT:
self.model_list = doc_analyze(self.pdf_bytes, ocr=False, self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id) start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
elif self.pdf_type == self.PIP_OCR: elif self.pdf_type == self.PIP_OCR:
self.model_list = doc_analyze(self.pdf_bytes, ocr=True, self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
start_page_id=self.start_page_id, end_page_id=self.end_page_id) start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_parse(self): def pipe_parse(self):
if self.pdf_type == self.PIP_TXT: if self.pdf_type == self.PIP_TXT:
self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer, self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer,
is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty, is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty,
start_page_id=self.start_page_id, end_page_id=self.end_page_id) start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
elif self.pdf_type == self.PIP_OCR: elif self.pdf_type == self.PIP_OCR:
self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
is_debug=self.is_debug, is_debug=self.is_debug,
start_page_id=self.start_page_id, end_page_id=self.end_page_id) start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF): def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.NONE_WITH_REASON):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode) result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
logger.info("uni_pipe mk content list finished") logger.info("uni_pipe mk content list finished")
return result return result
......
from loguru import logger from loguru import logger
from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio, calculate_overlap_area_in_bbox1_area_ratio, \ from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio, calculate_overlap_area_in_bbox1_area_ratio, \
calculate_iou calculate_iou, calculate_vertical_projection_overlap_ratio
from magic_pdf.libs.drop_tag import DropTag from magic_pdf.libs.drop_tag import DropTag
from magic_pdf.libs.ocr_content_type import BlockType from magic_pdf.libs.ocr_content_type import BlockType
from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox_for_block from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox_for_block
...@@ -60,6 +60,88 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc ...@@ -60,6 +60,88 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
return all_bboxes, all_discarded_blocks, drop_reasons return all_bboxes, all_discarded_blocks, drop_reasons
def add_bboxes(blocks, block_type, bboxes):
for block in blocks:
x0, y0, x1, y1 = block['bbox']
if block_type in [
BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
]:
bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"], block["group_id"]])
else:
bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"]])
def ocr_prepare_bboxes_for_layout_split_v2(
img_body_blocks, img_caption_blocks, img_footnote_blocks,
table_body_blocks, table_caption_blocks, table_footnote_blocks,
discarded_blocks, text_blocks, title_blocks, interline_equation_blocks, page_w, page_h
):
all_bboxes = []
add_bboxes(img_body_blocks, BlockType.ImageBody, all_bboxes)
add_bboxes(img_caption_blocks, BlockType.ImageCaption, all_bboxes)
add_bboxes(img_footnote_blocks, BlockType.ImageFootnote, all_bboxes)
add_bboxes(table_body_blocks, BlockType.TableBody, all_bboxes)
add_bboxes(table_caption_blocks, BlockType.TableCaption, all_bboxes)
add_bboxes(table_footnote_blocks, BlockType.TableFootnote, all_bboxes)
add_bboxes(text_blocks, BlockType.Text, all_bboxes)
add_bboxes(title_blocks, BlockType.Title, all_bboxes)
add_bboxes(interline_equation_blocks, BlockType.InterlineEquation, all_bboxes)
'''block嵌套问题解决'''
'''文本框与标题框重叠,优先信任文本框'''
all_bboxes = fix_text_overlap_title_blocks(all_bboxes)
'''任何框体与舍弃框重叠,优先信任舍弃框'''
all_bboxes = remove_need_drop_blocks(all_bboxes, discarded_blocks)
# interline_equation 与title或text框冲突的情况,分两种情况处理
'''interline_equation框与文本类型框iou比较接近1的时候,信任行间公式框'''
all_bboxes = fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes)
'''interline_equation框被包含在文本类型框内,且interline_equation比文本区块小很多时信任文本框,这时需要舍弃公式框'''
# 通过后续大框套小框逻辑删除
'''discarded_blocks'''
all_discarded_blocks = []
add_bboxes(discarded_blocks, BlockType.Discarded, all_discarded_blocks)
'''footnote识别:宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的'''
footnote_blocks = []
for discarded in discarded_blocks:
x0, y0, x1, y1 = discarded['bbox']
if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
footnote_blocks.append([x0, y0, x1, y1])
'''移除在footnote下面的任何框'''
need_remove_blocks = find_blocks_under_footnote(all_bboxes, footnote_blocks)
if len(need_remove_blocks) > 0:
for block in need_remove_blocks:
all_bboxes.remove(block)
all_discarded_blocks.append(block)
'''经过以上处理后,还存在大框套小框的情况,则删除小框'''
all_bboxes = remove_overlaps_min_blocks(all_bboxes)
all_discarded_blocks = remove_overlaps_min_blocks(all_discarded_blocks)
'''将剩余的bbox做分离处理,防止后面分layout时出错'''
all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes)
return all_bboxes, all_discarded_blocks
def find_blocks_under_footnote(all_bboxes, footnote_blocks):
need_remove_blocks = []
for block in all_bboxes:
block_x0, block_y0, block_x1, block_y1 = block[:4]
for footnote_bbox in footnote_blocks:
footnote_x0, footnote_y0, footnote_x1, footnote_y1 = footnote_bbox
# 如果footnote的纵向投影覆盖了block的纵向投影的80%且block的y0大于等于footnote的y1
if block_y0 >= footnote_y1 and calculate_vertical_projection_overlap_ratio((block_x0, block_y0, block_x1, block_y1), footnote_bbox) >= 0.8:
if block not in need_remove_blocks:
need_remove_blocks.append(block)
break
return need_remove_blocks
def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes): def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes):
# 先提取所有text和interline block # 先提取所有text和interline block
text_blocks = [] text_blocks = []
......
...@@ -49,8 +49,7 @@ def merge_spans_to_line(spans): ...@@ -49,8 +49,7 @@ def merge_spans_to_line(spans):
continue continue
# 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行 # 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
if __is_overlaps_y_exceeds_threshold(span['bbox'], if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], 0.5):
current_line[-1]['bbox']):
current_line.append(span) current_line.append(span)
else: else:
# 否则,开始新行 # 否则,开始新行
...@@ -154,6 +153,11 @@ def fill_spans_in_blocks(blocks, spans, radio): ...@@ -154,6 +153,11 @@ def fill_spans_in_blocks(blocks, spans, radio):
'type': block_type, 'type': block_type,
'bbox': block_bbox, 'bbox': block_bbox,
} }
if block_type in [
BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
]:
block_dict["group_id"] = block[-1]
block_spans = [] block_spans = []
for span in spans: for span in spans:
span_bbox = span['bbox'] span_bbox = span['bbox']
...@@ -202,6 +206,27 @@ def fix_block_spans(block_with_spans, img_blocks, table_blocks): ...@@ -202,6 +206,27 @@ def fix_block_spans(block_with_spans, img_blocks, table_blocks):
return fix_blocks return fix_blocks
def fix_block_spans_v2(block_with_spans):
"""1、img_block和table_block因为包含caption和footnote的关系,存在block的嵌套关系
需要将caption和footnote的text_span放入相应img_block和table_block内的
caption_block和footnote_block中 2、同时需要删除block中的spans字段."""
fix_blocks = []
for block in block_with_spans:
block_type = block['type']
if block_type in [BlockType.Text, BlockType.Title,
BlockType.ImageCaption, BlockType.ImageFootnote,
BlockType.TableCaption, BlockType.TableFootnote
]:
block = fix_text_block(block)
elif block_type in [BlockType.InterlineEquation, BlockType.ImageBody, BlockType.TableBody]:
block = fix_interline_block(block)
else:
continue
fix_blocks.append(block)
return fix_blocks
def fix_discarded_block(discarded_block_with_spans): def fix_discarded_block(discarded_block_with_spans):
fix_discarded_blocks = [] fix_discarded_blocks = []
for block in discarded_block_with_spans: for block in discarded_block_with_spans:
......
...@@ -2,13 +2,13 @@ model: ...@@ -2,13 +2,13 @@ model:
arch: unimernet arch: unimernet
model_type: unimernet model_type: unimernet
model_config: model_config:
model_name: ./models model_name: ./models/unimernet_base
max_seq_len: 1024 max_seq_len: 1536
length_aware: False
load_pretrained: True load_pretrained: True
pretrained: ./models/pytorch_model.bin pretrained: './models/unimernet_base/pytorch_model.pth'
tokenizer_config: tokenizer_config:
path: ./models path: ./models/unimernet_base
datasets: datasets:
formula_rec_eval: formula_rec_eval:
...@@ -18,7 +18,7 @@ datasets: ...@@ -18,7 +18,7 @@ datasets:
image_size: image_size:
- 192 - 192
- 672 - 672
run: run:
runner: runner_iter runner: runner_iter
task: unimernet_train task: unimernet_train
...@@ -43,4 +43,4 @@ run: ...@@ -43,4 +43,4 @@ run:
distributed_type: ddp # or fsdp when train llm distributed_type: ddp # or fsdp when train llm
generate_cfg: generate_cfg:
temperature: 0.0 temperature: 0.0
\ No newline at end of file
config:
device: cpu
layout: True
formula: True
table_config:
model: TableMaster
is_table_recog_enable: False
max_time: 400
weights: weights:
layout: Layout/model_final.pth layoutlmv3: Layout/LayoutLMv3/model_final.pth
mfd: MFD/weights.pt doclayout_yolo: Layout/YOLO/doclayout_yolo_ft.pt
mfr: MFR/UniMERNet yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
unimernet_small: MFR/unimernet_small
struct_eqtable: TabRec/StructEqTable struct_eqtable: TabRec/StructEqTable
TableMaster: TabRec/TableMaster tablemaster: TabRec/TableMaster
\ No newline at end of file \ No newline at end of file
...@@ -44,6 +44,18 @@ auto: automatically choose the best method for parsing pdf from ocr and txt. ...@@ -44,6 +44,18 @@ auto: automatically choose the best method for parsing pdf from ocr and txt.
without method specified, auto will be used by default.""", without method specified, auto will be used by default.""",
default='auto', default='auto',
) )
@click.option(
'-l',
'--lang',
'lang',
type=str,
help="""
Input the languages in the pdf (if known) to improve OCR accuracy. Optional.
You should input "Abbreviation" with language form url:
https://paddlepaddle.github.io/PaddleOCR/latest/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
""",
default=None,
)
@click.option( @click.option(
'-d', '-d',
'--debug', '--debug',
...@@ -68,7 +80,7 @@ without method specified, auto will be used by default.""", ...@@ -68,7 +80,7 @@ without method specified, auto will be used by default.""",
help='The ending page for PDF parsing, beginning from 0.', help='The ending page for PDF parsing, beginning from 0.',
default=None, default=None,
) )
def cli(path, output_dir, method, debug_able, start_page_id, end_page_id): def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
model_config.__use_inside_model__ = True model_config.__use_inside_model__ = True
model_config.__model_mode__ = 'full' model_config.__model_mode__ = 'full'
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
...@@ -90,6 +102,7 @@ def cli(path, output_dir, method, debug_able, start_page_id, end_page_id): ...@@ -90,6 +102,7 @@ def cli(path, output_dir, method, debug_able, start_page_id, end_page_id):
debug_able, debug_able,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id, end_page_id=end_page_id,
lang=lang
) )
except Exception as e: except Exception as e:
......
...@@ -6,8 +6,8 @@ import click ...@@ -6,8 +6,8 @@ import click
from loguru import logger from loguru import logger
import magic_pdf.model as model_config import magic_pdf.model as model_config
from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_span_bbox, from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox,
drow_model_bbox) draw_model_bbox, draw_span_bbox)
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
from magic_pdf.pipe.OCRPipe import OCRPipe from magic_pdf.pipe.OCRPipe import OCRPipe
from magic_pdf.pipe.TXTPipe import TXTPipe from magic_pdf.pipe.TXTPipe import TXTPipe
...@@ -39,16 +39,21 @@ def do_parse( ...@@ -39,16 +39,21 @@ def do_parse(
f_dump_middle_json=True, f_dump_middle_json=True,
f_dump_model_json=True, f_dump_model_json=True,
f_dump_orig_pdf=True, f_dump_orig_pdf=True,
f_dump_content_list=False, f_dump_content_list=True,
f_make_md_mode=MakeMode.MM_MD, f_make_md_mode=MakeMode.MM_MD,
f_draw_model_bbox=False, f_draw_model_bbox=False,
f_draw_line_sort_bbox=False,
start_page_id=0, start_page_id=0,
end_page_id=None, end_page_id=None,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
): ):
if debug_able: if debug_able:
logger.warning('debug mode is on') logger.warning('debug mode is on')
f_dump_content_list = True
f_draw_model_bbox = True f_draw_model_bbox = True
f_draw_line_sort_bbox = True
orig_model_list = copy.deepcopy(model_list) orig_model_list = copy.deepcopy(model_list)
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name,
...@@ -61,13 +66,16 @@ def do_parse( ...@@ -61,13 +66,16 @@ def do_parse(
if parse_method == 'auto': if parse_method == 'auto':
jso_useful_key = {'_pdf_type': '', 'model_list': model_list} jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True, pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id) start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
elif parse_method == 'txt': elif parse_method == 'txt':
pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True, pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id) start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
elif parse_method == 'ocr': elif parse_method == 'ocr':
pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True, pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id) start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
else: else:
logger.error('unknown parse method') logger.error('unknown parse method')
exit(1) exit(1)
...@@ -89,7 +97,9 @@ def do_parse( ...@@ -89,7 +97,9 @@ def do_parse(
if f_draw_span_bbox: if f_draw_span_bbox:
draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name) draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
if f_draw_model_bbox: if f_draw_model_bbox:
drow_model_bbox(copy.deepcopy(orig_model_list), pdf_bytes, local_md_dir, pdf_file_name) draw_model_bbox(copy.deepcopy(orig_model_list), pdf_bytes, local_md_dir, pdf_file_name)
if f_draw_line_sort_bbox:
draw_line_sort_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
md_content = pipe.pipe_mk_markdown(image_dir, md_content = pipe.pipe_mk_markdown(image_dir,
drop_mode=DropMode.NONE, drop_mode=DropMode.NONE,
......
...@@ -26,7 +26,7 @@ PARSE_TYPE_OCR = "ocr" ...@@ -26,7 +26,7 @@ PARSE_TYPE_OCR = "ocr"
def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False, def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
start_page_id=0, end_page_id=None, start_page_id=0, end_page_id=None, lang=None,
*args, **kwargs): *args, **kwargs):
""" """
解析文本类pdf 解析文本类pdf
...@@ -44,11 +44,14 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit ...@@ -44,11 +44,14 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
pdf_info_dict["_version_name"] = __version__ pdf_info_dict["_version_name"] = __version__
if lang is not None:
pdf_info_dict["_lang"] = lang
return pdf_info_dict return pdf_info_dict
def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False, def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
start_page_id=0, end_page_id=None, start_page_id=0, end_page_id=None, lang=None,
*args, **kwargs): *args, **kwargs):
""" """
解析ocr类pdf 解析ocr类pdf
...@@ -66,12 +69,15 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit ...@@ -66,12 +69,15 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
pdf_info_dict["_version_name"] = __version__ pdf_info_dict["_version_name"] = __version__
if lang is not None:
pdf_info_dict["_lang"] = lang
return pdf_info_dict return pdf_info_dict
def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False, def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
input_model_is_empty: bool = False, input_model_is_empty: bool = False,
start_page_id=0, end_page_id=None, start_page_id=0, end_page_id=None, lang=None,
*args, **kwargs): *args, **kwargs):
""" """
ocr和文本混合的pdf,全部解析出来 ocr和文本混合的pdf,全部解析出来
...@@ -95,9 +101,19 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr ...@@ -95,9 +101,19 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False): if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False):
logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr") logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
if input_model_is_empty: if input_model_is_empty:
pdf_models = doc_analyze(pdf_bytes, ocr=True, layout_model = kwargs.get("layout_model", None)
start_page_id=start_page_id, formula_enable = kwargs.get("formula_enable", None)
end_page_id=end_page_id) table_enable = kwargs.get("table_enable", None)
pdf_models = doc_analyze(
pdf_bytes,
ocr=True,
start_page_id=start_page_id,
end_page_id=end_page_id,
lang=lang,
layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
pdf_info_dict = parse_pdf(parse_pdf_by_ocr) pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
if pdf_info_dict is None: if pdf_info_dict is None:
raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.") raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")
...@@ -108,4 +124,7 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr ...@@ -108,4 +124,7 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
pdf_info_dict["_version_name"] = __version__ pdf_info_dict["_version_name"] = __version__
if lang is not None:
pdf_info_dict["_lang"] = lang
return pdf_info_dict return pdf_info_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment