Commit 309be741 authored by myhloli
Browse files

refactor(txt_parse): improve text extraction accuracy with new algorithm

- Implement new text extraction method (txt_spans_extract_v2) to enhance accuracy
- Add character filling in spans for better text reconstruction
- Introduce empty span handling using OCR for missed text
- Optimize span filtering and overlap removal
parent e52bd023
from io import BytesIO
import cv2
import numpy as np
from PIL import Image
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.libs.commons import fitz, join_path
from magic_pdf.libs.hash_utils import compute_sha256
......@@ -29,3 +32,26 @@ def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWri
imageWriter.write(img_hash256_path, byte_data)
return img_hash256_path
def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
    """Render the ``bbox`` region of ``page`` and return it as an image.

    The region is rendered with a fixed 3x zoom matrix, encoded as PNG in
    memory and re-opened with Pillow.

    Args:
        bbox: Coordinates that are unpacked into ``fitz.Rect``.
        page: Page the region is rendered from.
        mode: ``"pillow"`` returns the Pillow image as opened; ``"cv2"``
            returns the result of converting it with ``cv2.COLOR_RGB2BGR``.

    Returns:
        The rendered crop in the representation selected by ``mode``.

    Raises:
        ValueError: If ``mode`` is neither ``"cv2"`` nor ``"pillow"``. The
            check happens after the page has been rendered.
    """
    # Clip rectangle and the 3x scaling applied while rasterising.
    clip_rect = fitz.Rect(*bbox)
    scale_matrix = fitz.Matrix(3, 3)
    pixmap = page.get_pixmap(clip=clip_rect, matrix=scale_matrix)

    # Round-trip through an in-memory PNG so Pillow can open the crop.
    png_stream = BytesIO(pixmap.tobytes(output='png'))
    rendered = Image.open(png_stream)

    if mode == "cv2":
        return cv2.cvtColor(np.asarray(rendered), cv2.COLOR_RGB2BGR)
    if mode == "pillow":
        return rendered
    raise ValueError(f"mode: {mode} is not supported.")
\ No newline at end of file
......@@ -63,7 +63,7 @@ def ocr_model_init(show_log: bool = False,
use_dilation=True,
det_db_unclip_ratio=1.8,
):
if lang is not None:
if lang is not None and lang != '':
model = ModifiedPaddleOCR(
show_log=show_log,
det_db_box_thresh=det_db_box_thresh,
......
......@@ -9,6 +9,7 @@ def parse_pdf_by_ocr(pdf_bytes,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
):
dataset = PymuDocDataset(pdf_bytes)
return pdf_parse_union(dataset,
......@@ -18,4 +19,5 @@ def parse_pdf_by_ocr(pdf_bytes,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=debug_mode,
lang=lang,
)
......@@ -10,6 +10,7 @@ def parse_pdf_by_txt(
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
):
dataset = PymuDocDataset(pdf_bytes)
return pdf_parse_union(dataset,
......@@ -19,4 +20,5 @@ def parse_pdf_by_txt(
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=debug_mode,
lang=lang,
)
......@@ -18,7 +18,21 @@ from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.local_math import float_equal
from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
from magic_pdf.model.magic_model import MagicModel
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates
os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger

try:
    import torchtext
except ImportError:
    # torchtext is optional; without it there is no warning to silence.
    pass
else:
    # Feature-detect the helper instead of comparing version strings:
    # ``"0.9.0" >= "0.18.0"`` is True under lexicographic string comparison,
    # so a string-based version gate lets old releases through and then fails
    # on the missing attribute.
    if hasattr(torchtext, 'disable_torchtext_deprecation_warning'):
        torchtext.disable_torchtext_deprecation_warning()
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.para.para_split_v3 import para_split
from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
from magic_pdf.pre_proc.construct_page_dict import \
......@@ -74,7 +88,150 @@ def __replace_STX_ETX(text_str: str):
return text_str
def txt_spans_extract(pdf_page, inline_equations, interline_equations):
def chars_to_content(span):
    """Collapse ``span['chars']`` into ``span['content']`` and drop the char list.

    The chars are ordered left to right by the horizontal centre of their
    bbox and concatenated. A single space is inserted between two neighbours
    whose horizontal gap exceeds the average char width of the span.

    Args:
        span: Dict with a ``'chars'`` list; every char is a dict with a
            ``'bbox'`` (x0, y0, x1, y1) and a ``'c'`` entry. The dict is
            modified in place: ``'content'`` is set and ``'chars'`` is removed.

    Returns:
        None.
    """
    # Sort by bbox centre rather than by left edge.
    chars = sorted(span['chars'], key=lambda ch: (ch['bbox'][0] + ch['bbox'][2]) / 2)
    span['chars'] = chars

    if not chars:
        span['content'] = ''
        del span['chars']
        return

    char_avg_width = sum(ch['bbox'][2] - ch['bbox'][0] for ch in chars) / len(chars)

    pieces = []
    # Right edge of the previous char. Tracking it directly avoids a
    # ``list.index`` lookup per char, which is quadratic, picks the wrong
    # neighbour when two char dicts compare equal, and wraps around to the
    # last char for the first one.
    prev_x1 = None
    for ch in chars:
        # A gap wider than one average char is treated as a word break.
        if prev_x1 is not None and ch['bbox'][0] - prev_x1 > char_avg_width:
            pieces.append(' ')
        pieces.append(ch['c'])
        prev_x1 = ch['bbox'][2]

    span['content'] = __replace_STX_ETX(''.join(pieces))
    del span['chars']
# Characters that get the relaxed "trailing punctuation" placement rule when
# chars are assigned to spans.
# NOTE(review): several entries appear twice; presumably one of each pair was
# a full-width variant in the original file - confirm against upstream.
LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';', ']', '】', '}', '}', '>', '》', '、', ',', ',')


def fill_char_in_spans(spans, all_chars):
    """Distribute page chars over spans, then build each span's text content.

    Every char is appended to the ``'chars'`` list of the first span (in list
    order) for which ``calculate_char_in_span`` succeeds; chars accepted by no
    span are left out. Afterwards ``chars_to_content`` is run on every span.

    Args:
        spans: Span dicts, each with a ``'bbox'`` and a ``'chars'`` list.
        all_chars: Char dicts, each with a ``'bbox'`` and a ``'c'`` entry.

    Returns:
        None; ``spans`` is modified in place.
    """
    for char in all_chars:
        # Whether the char is line-ending punctuation depends only on the
        # char, so decide it once instead of once per candidate span.
        char_is_line_stop_flag = char['c'] in LINE_STOP_FLAG
        for span in spans:
            if calculate_char_in_span(char['bbox'], span['bbox'], char_is_line_stop_flag):
                span['chars'].append(char)
                break  # a char belongs to at most one span

    for span in spans:
        chars_to_content(span)
# 使用鲁棒性更强的中心点坐标判断
def calculate_char_in_span(char_bbox, span_bbox, char_is_line_stop_flag):
    """Decide whether a char belongs to a span, judged by the char's centre.

    A char is accepted when its vertical centre lies inside the span and
    within a quarter of the span height of the span's own centre line, and
    additionally either

    * its horizontal centre lies strictly inside the span, or
    * it is line-ending punctuation whose left edge falls within one span
      height of the span's right edge. This gives trailing punctuation that
      sticks out of the span a chance to be attached to it.

    Args:
        char_bbox: (x0, y0, x1, y1) of the char.
        span_bbox: (x0, y0, x1, y1) of the span.
        char_is_line_stop_flag: Whether the char is line-ending punctuation.

    Returns:
        bool: always ``True`` or ``False``; no path falls through to ``None``.
    """
    char_center_x = (char_bbox[0] + char_bbox[2]) / 2
    char_center_y = (char_bbox[1] + char_bbox[3]) / 2
    span_center_y = (span_bbox[1] + span_bbox[3]) / 2
    span_height = span_bbox[3] - span_bbox[1]

    # Both acceptance rules share the same vertical requirement.
    vertically_aligned = (
        span_bbox[1] < char_center_y < span_bbox[3]
        and abs(char_center_y - span_center_y) < span_height / 4
    )
    if not vertically_aligned:
        return False

    # Regular rule: the char centre is horizontally inside the span.
    if span_bbox[0] < char_center_x < span_bbox[2]:
        return True

    # Relaxed rule for trailing punctuation: left edge close to the right
    # border of the span (within one span height).
    return bool(char_is_line_stop_flag) and (
        (span_bbox[2] - span_height) < char_bbox[0] < span_bbox[2]
    )
def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
    """Fill text spans with chars read from the PDF page, using OCR as fallback.

    Steps:
      1. Collect spans whose bbox overlaps (ratio > 0.5) a layout block that
         is not an image body, table body or interline equation, followed by
         spans overlapping a discarded block.
      2. For the text-type spans among them, assign the page's raw chars via
         ``fill_char_in_spans`` to build ``span['content']``.
      3. Spans that end up with empty content are removed from ``spans``,
         cropped from the page and run through the OCR model; they are
         appended back only when OCR returns non-empty text with a score
         above 0.5.

    Args:
        pdf_page: Page object providing ``get_text`` and used for cropping.
        spans: Span dicts with ``'bbox'`` and ``'type'``; modified in place.
        all_bboxes: Layout blocks; indices 0-3 are the bbox, index 7 the type.
        all_discarded_blocks: Discarded blocks; indices 0-3 are the bbox.
        lang: Language passed to the OCR model.

    Returns:
        The same ``spans`` list object that was passed in.
    """
    useful_spans = []
    unuseful_spans = []
    for span in spans:
        for block in all_bboxes:
            if block[7] in [BlockType.ImageBody, BlockType.TableBody, BlockType.InterlineEquation]:
                continue
            if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block[0:4]) > 0.5:
                useful_spans.append(span)
                break
        for block in all_discarded_blocks:
            if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block[0:4]) > 0.5:
                unuseful_spans.append(span)
                break

    # Text spans to fill, useful ones first. A span overlapping both a layout
    # block and a discarded block is in both lists; keep it once, because
    # chars_to_content deletes span['chars'] and would fail on a second visit.
    new_spans = []
    seen_span_ids = set()
    for span in useful_spans + unuseful_spans:
        if span['type'] not in [ContentType.Text] or id(span) in seen_span_ids:
            continue
        seen_span_ids.add(id(span))
        span['chars'] = []
        new_spans.append(span)

    # Nothing to fill: skip the page text extraction entirely.
    if not new_spans:
        return spans

    text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)['blocks']

    # TODO: drop chars with a large skew angle before using them.
    all_pymu_chars = [
        char
        for block in text_blocks
        for line in block['lines']
        for pymu_span in line['spans']
        for char in pymu_span['chars']
    ]

    fill_char_in_spans(new_spans, all_pymu_chars)

    empty_spans = [span for span in new_spans if len(span['content']) == 0]
    if not empty_spans:
        return spans

    # Initialise the OCR model only when at least one span needs it.
    atom_model_manager = AtomModelSingleton()
    ocr_model = atom_model_manager.get_atom_model(
        atom_model_name="ocr",
        ocr_show_log=False,
        det_db_box_thresh=0.3,
        lang=lang
    )

    for span in empty_spans:
        # The span only returns to the list if OCR yields usable text.
        spans.remove(span)
        # Crop the span's bbox from the page for recognition.
        span_img = cut_image_to_pil_image(span['bbox'], pdf_page, mode="cv2")
        ocr_res = ocr_model.ocr(span_img, det=False)
        # Truthiness also covers a None result / None first entry.
        if ocr_res and ocr_res[0]:
            ocr_text, ocr_score = ocr_res[0][0]
            if ocr_score > 0.5 and len(ocr_text) > 0:
                span['content'] = ocr_text
                spans.append(span)

    return spans
def txt_spans_extract_v1(pdf_page, inline_equations, interline_equations):
text_raw_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
char_level_text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)[
'blocks'
......@@ -464,18 +621,16 @@ def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
def parse_page_core(
page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
):
need_drop = False
drop_reason = []
"""从magic_model对象中获取后面会用到的区块信息"""
# img_blocks = magic_model.get_imgs(page_id)
# table_blocks = magic_model.get_tables(page_id)
img_groups = magic_model.get_imgs_v2(page_id)
table_groups = magic_model.get_tables_v2(page_id)
"""对image和table的区块分组"""
img_body_blocks, img_caption_blocks, img_footnote_blocks = process_groups(
img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
)
......@@ -519,38 +674,20 @@ def parse_page_core(
page_h,
)
"""获取所有的spans信息"""
spans = magic_model.get_all_spans(page_id)
"""根据parse_mode,构造spans"""
if parse_mode == SupportedPdfParseMethod.TXT:
"""ocr 中文本类的 span 用 pymu spans 替换!"""
pymu_spans = txt_spans_extract(page_doc, inline_equations, interline_equations)
spans = replace_text_span(pymu_spans, spans)
elif parse_mode == SupportedPdfParseMethod.OCR:
pass
else:
raise Exception('parse_mode must be txt or ocr')
"""在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span"""
"""顺便删除大水印并保留abandon的span"""
spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks)
"""删除重叠spans中置信度较低的那些"""
spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
"""删除重叠spans中较小的那些"""
spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
"""对image和table截图"""
spans = ocr_cut_image_and_table(
spans, page_doc, page_id, pdf_bytes_md5, imageWriter
)
"""先处理不需要排版的discarded_blocks"""
discarded_block_with_spans, spans = fill_spans_in_blocks(
all_discarded_blocks, spans, 0.4
)
fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
"""如果当前页面没有bbox则跳过"""
"""如果当前页面没有有效的bbox则跳过"""
if len(all_bboxes) == 0:
logger.warning(f'skip this page, not found useful bbox, page_id: {page_id}')
return ocr_construct_page_component_v2(
......@@ -568,7 +705,32 @@ def parse_page_core(
drop_reason,
)
"""将span填入blocks中"""
"""删除重叠spans中置信度较低的那些"""
spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
"""删除重叠spans中较小的那些"""
spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
"""根据parse_mode,构造spans,主要是文本类的字符填充"""
if parse_mode == SupportedPdfParseMethod.TXT:
"""之前的公式替换方案"""
# pymu_spans = txt_spans_extract_v1(page_doc, inline_equations, interline_equations)
# spans = replace_text_span(pymu_spans, spans)
"""ocr 中文本类的 span 用 pymu spans 替换!"""
spans = txt_spans_extract_v2(page_doc, spans, all_bboxes, all_discarded_blocks, lang)
elif parse_mode == SupportedPdfParseMethod.OCR:
pass
else:
raise Exception('parse_mode must be txt or ocr')
"""对image和table截图"""
spans = ocr_cut_image_and_table(
spans, page_doc, page_id, pdf_bytes_md5, imageWriter
)
"""span填充进block"""
block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
"""对block进行fix操作"""
......@@ -618,6 +780,7 @@ def pdf_parse_union(
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
):
pdf_bytes_md5 = compute_md5(dataset.data_bits())
......@@ -654,7 +817,7 @@ def pdf_parse_union(
"""解析pdf中的每一页"""
if start_page_id <= page_id <= end_page_id:
page_info = parse_page_core(
page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
)
else:
page_info = page.get_page_info()
......
......@@ -30,6 +30,7 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, i
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=is_debug,
lang=lang,
)
pdf_info_dict['_parse_type'] = PARSE_TYPE_TXT
......@@ -53,6 +54,7 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, i
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=is_debug,
lang=lang,
)
pdf_info_dict['_parse_type'] = PARSE_TYPE_OCR
......@@ -80,6 +82,7 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=is_debug,
lang=lang,
)
except Exception as e:
logger.exception(e)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment