Commit d2de6d80 authored by myhloli's avatar myhloli
Browse files

refactor: update text span extraction to use new version and improve character handling

parent 1ed61cb5
...@@ -9,7 +9,7 @@ from mineru.utils.model_utils import clean_memory ...@@ -9,7 +9,7 @@ from mineru.utils.model_utils import clean_memory
from mineru.utils.pipeline_magic_model import MagicModel from mineru.utils.pipeline_magic_model import MagicModel
from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans
from mineru.utils.span_pre_proc import remove_outside_spans, remove_overlaps_low_confidence_spans, \ from mineru.utils.span_pre_proc import remove_outside_spans, remove_overlaps_low_confidence_spans, \
remove_overlaps_min_spans, txt_spans_extract_v2 remove_overlaps_min_spans, txt_spans_extract_v3
from mineru.version import __version__ from mineru.version import __version__
from mineru.utils.hash_utils import str_md5 from mineru.utils.hash_utils import str_md5
...@@ -79,7 +79,7 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer ...@@ -79,7 +79,7 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
pass pass
else: else:
"""使用新版本的混合ocr方案.""" """使用新版本的混合ocr方案."""
spans = txt_spans_extract_v2(page, spans, page_pil_img, scale) spans = txt_spans_extract_v3(page, spans, page_pil_img, scale, all_bboxes, all_discarded_blocks)
"""先处理不需要排版的discarded_blocks""" """先处理不需要排版的discarded_blocks"""
discarded_block_with_spans, spans = fill_spans_in_blocks( discarded_block_with_spans, spans = fill_spans_in_blocks(
......
...@@ -215,8 +215,8 @@ def do_parse( ...@@ -215,8 +215,8 @@ def do_parse(
if __name__ == "__main__": if __name__ == "__main__":
pdf_path = "../../demo/pdfs/计算机学报-单词中间有换行符-span不准确.pdf" # pdf_path = "../../demo/pdfs/计算机学报-单词中间有换行符-span不准确.pdf"
# pdf_path = "../../demo/pdfs/demo1.pdf" pdf_path = "../../demo/pdfs/demo1.pdf"
with open(pdf_path, "rb") as f: with open(pdf_path, "rb") as f:
try: try:
do_parse("./output", [Path(pdf_path).stem], [f.read()],["ch"], end_page_id=20,) do_parse("./output", [Path(pdf_path).stem], [f.read()],["ch"], end_page_id=20,)
......
# Copyright (c) Opendatalab. All rights reserved. # Copyright (c) Opendatalab. All rights reserved.
import re import re
import statistics
import cv2 import cv2
import numpy as np import numpy as np
from loguru import logger from loguru import logger
...@@ -116,6 +118,7 @@ def __replace_unicode(text: str): ...@@ -116,6 +118,7 @@ def __replace_unicode(text: str):
return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text) return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
"""textpage.get_text_bounded方案"""
def txt_spans_extract_v1(pdf_page, spans, pil_img, scale): def txt_spans_extract_v1(pdf_page, spans, pil_img, scale):
textpage = pdf_page.get_textpage() textpage = pdf_page.get_textpage()
...@@ -162,6 +165,7 @@ def txt_spans_extract_v1(pdf_page, spans, pil_img, scale): ...@@ -162,6 +165,7 @@ def txt_spans_extract_v1(pdf_page, spans, pil_img, scale):
return spans return spans
"""pdf_text dict方案 span级别"""
def txt_spans_extract_v2(pdf_page, spans, pil_img, scale): def txt_spans_extract_v2(pdf_page, spans, pil_img, scale):
page_dict = get_page(pdf_page) page_dict = get_page(pdf_page)
...@@ -224,6 +228,243 @@ def txt_spans_extract_v2(pdf_page, spans, pil_img, scale): ...@@ -224,6 +228,243 @@ def txt_spans_extract_v2(pdf_page, spans, pil_img, scale):
return spans return spans
"""pdf_text dict方案 char级别"""
def txt_spans_extract_v3(pdf_page, spans, pil_img, scale, all_bboxes, all_discarded_blocks):
page_dict = get_page(pdf_page)
page_all_chars = []
page_all_lines = []
for block in page_dict['blocks']:
for line in block['lines']:
if 0 < abs(line['rotation']) < 90:
# 旋转角度在0-90度之间的行,直接跳过
continue
page_all_lines.append(line)
for span in line['spans']:
for char in span['chars']:
page_all_chars.append(char)
# 计算所有sapn的高度的中位数
span_height_list = []
for span in spans:
if span['type'] in [ContentType.TEXT]:
span_height = span['bbox'][3] - span['bbox'][1]
span['height'] = span_height
span['width'] = span['bbox'][2] - span['bbox'][0]
span_height_list.append(span_height)
if len(span_height_list) == 0:
return spans
else:
median_span_height = statistics.median(span_height_list)
useful_spans = []
unuseful_spans = []
# 纵向span的两个特征:1. 高度超过多个line 2. 高宽比超过某个值
vertical_spans = []
for span in spans:
if span['type'] in [ContentType.TEXT]:
for block in all_bboxes + all_discarded_blocks:
if block[7] in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY, BlockType.INTERLINE_EQUATION]:
continue
if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block[0:4]) > 0.5:
if span['height'] > median_span_height * 3 and span['height'] > span['width'] * 3:
vertical_spans.append(span)
elif block in all_bboxes:
useful_spans.append(span)
else:
unuseful_spans.append(span)
break
"""垂直的span框直接用line进行填充"""
if len(vertical_spans) > 0:
for pdfium_line in page_all_lines:
for span in vertical_spans:
if calculate_overlap_area_in_bbox1_area_ratio(pdfium_line['bbox'].bbox, span['bbox']) > 0.5:
for pdfium_span in pdfium_line['spans']:
span['content'] += pdfium_span['text']
break
for span in vertical_spans:
if len(span['content']) == 0:
spans.remove(span)
"""水平的span框先用char填充,再用ocr填充空的span框"""
new_spans = []
for span in useful_spans + unuseful_spans:
if span['type'] in [ContentType.TEXT]:
span['chars'] = []
new_spans.append(span)
need_ocr_spans = fill_char_in_spans(new_spans, page_all_chars)
"""对未填充的span进行ocr"""
if len(need_ocr_spans) > 0:
for span in need_ocr_spans:
# 对span的bbox截图再ocr
span_pil_img = get_crop_img(span['bbox'], pil_img, scale)
span_img = cv2.cvtColor(np.array(span_pil_img), cv2.COLOR_RGB2BGR)
# 计算span的对比度,低于0.20的span不进行ocr
if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
spans.remove(span)
continue
span['content'] = ''
span['score'] = 1.0
span['np_img'] = span_img
return spans
def fill_char_in_spans(spans, all_chars):
# 简单从上到下排一下序
spans = sorted(spans, key=lambda x: x['bbox'][1])
for char in all_chars:
for span in spans:
if calculate_char_in_span(char['bbox'], span['bbox'], char['char']):
span['chars'].append(char)
break
need_ocr_spans = []
for span in spans:
chars_to_content(span)
# 有的span中虽然没有字但有一两个空的占位符,用宽高和content长度过滤
if len(span['content']) * span['height'] < span['width'] * 0.5:
# logger.info(f"maybe empty span: {len(span['content'])}, {span['height']}, {span['width']}")
need_ocr_spans.append(span)
del span['height'], span['width']
return need_ocr_spans
LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';', ']', '】', '}', '}', '>', '》', '、', ',', ',', '-', '—', '–',)
LINE_START_FLAG = ('(', '(', '"', '“', '【', '{', '《', '<', '「', '『', '【', '[',)
def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
char_center_x = (char_bbox[0] + char_bbox[2]) / 2
char_center_y = (char_bbox[1] + char_bbox[3]) / 2
span_center_y = (span_bbox[1] + span_bbox[3]) / 2
span_height = span_bbox[3] - span_bbox[1]
if (
span_bbox[0] < char_center_x < span_bbox[2]
and span_bbox[1] < char_center_y < span_bbox[3]
and abs(char_center_y - span_center_y) < span_height * span_height_radio # 字符的中轴和span的中轴高度差不能超过1/4span高度
):
return True
else:
# 如果char是LINE_STOP_FLAG,就不用中心点判定,换一种方案(左边界在span区域内,高度判定和之前逻辑一致)
# 主要是给结尾符号一个进入span的机会,这个char还应该离span右边界较近
if char in LINE_STOP_FLAG:
if (
(span_bbox[2] - span_height) < char_bbox[0] < span_bbox[2]
and char_center_x > span_bbox[0]
and span_bbox[1] < char_center_y < span_bbox[3]
and abs(char_center_y - span_center_y) < span_height * span_height_radio
):
return True
elif char in LINE_START_FLAG:
if (
span_bbox[0] < char_bbox[2] < (span_bbox[0] + span_height)
and char_center_x < span_bbox[2]
and span_bbox[1] < char_center_y < span_bbox[3]
and abs(char_center_y - span_center_y) < span_height * span_height_radio
):
return True
else:
return False
def chars_to_content(span):
# 检查span中的char是否为空
if len(span['chars']) == 0:
pass
else:
# 先给chars按char['bbox']的中心点的x坐标排序
span['chars'] = sorted(span['chars'], key=lambda x: (x['bbox'][0] + x['bbox'][2]) / 2)
# Calculate the width of each character
char_widths = [char['bbox'][2] - char['bbox'][0] for char in span['chars']]
# Calculate the median width
median_width = statistics.median(char_widths)
# 通过x轴重叠比率移除一部分char
span = remove_x_overlapping_chars(span, median_width)
content = ''
for char in span['chars']:
# 如果下一个char的x0和上一个char的x1距离超过0.25个字符宽度,则需要在中间插入一个空格
char1 = char
char2 = span['chars'][span['chars'].index(char) + 1] if span['chars'].index(char) + 1 < len(span['chars']) else None
if char2 and char2['bbox'][0] - char1['bbox'][2] > median_width * 0.25 and char['char'] != ' ' and char2['char'] != ' ':
content += f"{char['char']} "
else:
content += char['char']
content = __replace_unicode(content)
content = __replace_ligatures(content)
content = __replace_ligatures(content)
span['content'] = content.strip()
del span['chars']
def remove_x_overlapping_chars(span, median_width):
"""
Remove characters from a span that overlap significantly on the x-axis.
Args:
median_width:
span (dict): A span containing a list of chars, each with bbox coordinates
in the format [x0, y0, x1, y1]
Returns:
dict: The span with overlapping characters removed
"""
if 'chars' not in span or len(span['chars']) < 2:
return span
overlap_threshold = median_width * 0.3
i = 0
while i < len(span['chars']) - 1:
char1 = span['chars'][i]
char2 = span['chars'][i + 1]
# Calculate overlap width
x_left = max(char1['bbox'][0], char2['bbox'][0])
x_right = min(char1['bbox'][2], char2['bbox'][2])
if x_right > x_left: # There is overlap
overlap_width = x_right - x_left
if overlap_width > overlap_threshold:
if char1['char'] == char2['char'] or char1['char'] == ' ' or char2['char'] == ' ':
# Determine which character to remove
width1 = char1['bbox'][2] - char1['bbox'][0]
width2 = char2['bbox'][2] - char2['bbox'][0]
if width1 < width2:
# Remove the narrower character
span['chars'].pop(i)
else:
span['chars'].pop(i + 1)
else:
i += 1
# Don't increment i since we need to check the new pair
else:
i += 1
else:
i += 1
return span
def calculate_contrast(img, img_mode) -> float: def calculate_contrast(img, img_mode) -> float:
""" """
计算给定图像的对比度。 计算给定图像的对比度。
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment