Commit d2de6d80 authored by myhloli's avatar myhloli
Browse files

refactor: update text span extraction to use new version and improve character handling

parent 1ed61cb5
......@@ -9,7 +9,7 @@ from mineru.utils.model_utils import clean_memory
from mineru.utils.pipeline_magic_model import MagicModel
from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans
from mineru.utils.span_pre_proc import remove_outside_spans, remove_overlaps_low_confidence_spans, \
remove_overlaps_min_spans, txt_spans_extract_v2
remove_overlaps_min_spans, txt_spans_extract_v3
from mineru.version import __version__
from mineru.utils.hash_utils import str_md5
......@@ -79,7 +79,7 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
pass
else:
"""使用新版本的混合ocr方案."""
spans = txt_spans_extract_v2(page, spans, page_pil_img, scale)
spans = txt_spans_extract_v3(page, spans, page_pil_img, scale, all_bboxes, all_discarded_blocks)
"""先处理不需要排版的discarded_blocks"""
discarded_block_with_spans, spans = fill_spans_in_blocks(
......
......@@ -215,8 +215,8 @@ def do_parse(
if __name__ == "__main__":
pdf_path = "../../demo/pdfs/计算机学报-单词中间有换行符-span不准确.pdf"
# pdf_path = "../../demo/pdfs/demo1.pdf"
# pdf_path = "../../demo/pdfs/计算机学报-单词中间有换行符-span不准确.pdf"
pdf_path = "../../demo/pdfs/demo1.pdf"
with open(pdf_path, "rb") as f:
try:
do_parse("./output", [Path(pdf_path).stem], [f.read()],["ch"], end_page_id=20,)
......
# Copyright (c) Opendatalab. All rights reserved.
import re
import statistics
import cv2
import numpy as np
from loguru import logger
......@@ -116,6 +118,7 @@ def __replace_unicode(text: str):
return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
"""textpage.get_text_bounded方案"""
def txt_spans_extract_v1(pdf_page, spans, pil_img, scale):
textpage = pdf_page.get_textpage()
......@@ -162,6 +165,7 @@ def txt_spans_extract_v1(pdf_page, spans, pil_img, scale):
return spans
"""pdf_text dict方案 span级别"""
def txt_spans_extract_v2(pdf_page, spans, pil_img, scale):
page_dict = get_page(pdf_page)
......@@ -224,6 +228,243 @@ def txt_spans_extract_v2(pdf_page, spans, pil_img, scale):
return spans
"""pdf_text dict方案 char级别"""
def txt_spans_extract_v3(pdf_page, spans, pil_img, scale, all_bboxes, all_discarded_blocks):
page_dict = get_page(pdf_page)
page_all_chars = []
page_all_lines = []
for block in page_dict['blocks']:
for line in block['lines']:
if 0 < abs(line['rotation']) < 90:
# 旋转角度在0-90度之间的行,直接跳过
continue
page_all_lines.append(line)
for span in line['spans']:
for char in span['chars']:
page_all_chars.append(char)
# 计算所有sapn的高度的中位数
span_height_list = []
for span in spans:
if span['type'] in [ContentType.TEXT]:
span_height = span['bbox'][3] - span['bbox'][1]
span['height'] = span_height
span['width'] = span['bbox'][2] - span['bbox'][0]
span_height_list.append(span_height)
if len(span_height_list) == 0:
return spans
else:
median_span_height = statistics.median(span_height_list)
useful_spans = []
unuseful_spans = []
# 纵向span的两个特征:1. 高度超过多个line 2. 高宽比超过某个值
vertical_spans = []
for span in spans:
if span['type'] in [ContentType.TEXT]:
for block in all_bboxes + all_discarded_blocks:
if block[7] in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY, BlockType.INTERLINE_EQUATION]:
continue
if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block[0:4]) > 0.5:
if span['height'] > median_span_height * 3 and span['height'] > span['width'] * 3:
vertical_spans.append(span)
elif block in all_bboxes:
useful_spans.append(span)
else:
unuseful_spans.append(span)
break
"""垂直的span框直接用line进行填充"""
if len(vertical_spans) > 0:
for pdfium_line in page_all_lines:
for span in vertical_spans:
if calculate_overlap_area_in_bbox1_area_ratio(pdfium_line['bbox'].bbox, span['bbox']) > 0.5:
for pdfium_span in pdfium_line['spans']:
span['content'] += pdfium_span['text']
break
for span in vertical_spans:
if len(span['content']) == 0:
spans.remove(span)
"""水平的span框先用char填充,再用ocr填充空的span框"""
new_spans = []
for span in useful_spans + unuseful_spans:
if span['type'] in [ContentType.TEXT]:
span['chars'] = []
new_spans.append(span)
need_ocr_spans = fill_char_in_spans(new_spans, page_all_chars)
"""对未填充的span进行ocr"""
if len(need_ocr_spans) > 0:
for span in need_ocr_spans:
# 对span的bbox截图再ocr
span_pil_img = get_crop_img(span['bbox'], pil_img, scale)
span_img = cv2.cvtColor(np.array(span_pil_img), cv2.COLOR_RGB2BGR)
# 计算span的对比度,低于0.20的span不进行ocr
if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
spans.remove(span)
continue
span['content'] = ''
span['score'] = 1.0
span['np_img'] = span_img
return spans
def fill_char_in_spans(spans, all_chars):
# 简单从上到下排一下序
spans = sorted(spans, key=lambda x: x['bbox'][1])
for char in all_chars:
for span in spans:
if calculate_char_in_span(char['bbox'], span['bbox'], char['char']):
span['chars'].append(char)
break
need_ocr_spans = []
for span in spans:
chars_to_content(span)
# 有的span中虽然没有字但有一两个空的占位符,用宽高和content长度过滤
if len(span['content']) * span['height'] < span['width'] * 0.5:
# logger.info(f"maybe empty span: {len(span['content'])}, {span['height']}, {span['width']}")
need_ocr_spans.append(span)
del span['height'], span['width']
return need_ocr_spans
LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';', ']', '】', '}', '}', '>', '》', '、', ',', ',', '-', '—', '–',)
LINE_START_FLAG = ('(', '(', '"', '“', '【', '{', '《', '<', '「', '『', '【', '[',)
def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
char_center_x = (char_bbox[0] + char_bbox[2]) / 2
char_center_y = (char_bbox[1] + char_bbox[3]) / 2
span_center_y = (span_bbox[1] + span_bbox[3]) / 2
span_height = span_bbox[3] - span_bbox[1]
if (
span_bbox[0] < char_center_x < span_bbox[2]
and span_bbox[1] < char_center_y < span_bbox[3]
and abs(char_center_y - span_center_y) < span_height * span_height_radio # 字符的中轴和span的中轴高度差不能超过1/4span高度
):
return True
else:
# 如果char是LINE_STOP_FLAG,就不用中心点判定,换一种方案(左边界在span区域内,高度判定和之前逻辑一致)
# 主要是给结尾符号一个进入span的机会,这个char还应该离span右边界较近
if char in LINE_STOP_FLAG:
if (
(span_bbox[2] - span_height) < char_bbox[0] < span_bbox[2]
and char_center_x > span_bbox[0]
and span_bbox[1] < char_center_y < span_bbox[3]
and abs(char_center_y - span_center_y) < span_height * span_height_radio
):
return True
elif char in LINE_START_FLAG:
if (
span_bbox[0] < char_bbox[2] < (span_bbox[0] + span_height)
and char_center_x < span_bbox[2]
and span_bbox[1] < char_center_y < span_bbox[3]
and abs(char_center_y - span_center_y) < span_height * span_height_radio
):
return True
else:
return False
def chars_to_content(span):
# 检查span中的char是否为空
if len(span['chars']) == 0:
pass
else:
# 先给chars按char['bbox']的中心点的x坐标排序
span['chars'] = sorted(span['chars'], key=lambda x: (x['bbox'][0] + x['bbox'][2]) / 2)
# Calculate the width of each character
char_widths = [char['bbox'][2] - char['bbox'][0] for char in span['chars']]
# Calculate the median width
median_width = statistics.median(char_widths)
# 通过x轴重叠比率移除一部分char
span = remove_x_overlapping_chars(span, median_width)
content = ''
for char in span['chars']:
# 如果下一个char的x0和上一个char的x1距离超过0.25个字符宽度,则需要在中间插入一个空格
char1 = char
char2 = span['chars'][span['chars'].index(char) + 1] if span['chars'].index(char) + 1 < len(span['chars']) else None
if char2 and char2['bbox'][0] - char1['bbox'][2] > median_width * 0.25 and char['char'] != ' ' and char2['char'] != ' ':
content += f"{char['char']} "
else:
content += char['char']
content = __replace_unicode(content)
content = __replace_ligatures(content)
content = __replace_ligatures(content)
span['content'] = content.strip()
del span['chars']
def remove_x_overlapping_chars(span, median_width):
"""
Remove characters from a span that overlap significantly on the x-axis.
Args:
median_width:
span (dict): A span containing a list of chars, each with bbox coordinates
in the format [x0, y0, x1, y1]
Returns:
dict: The span with overlapping characters removed
"""
if 'chars' not in span or len(span['chars']) < 2:
return span
overlap_threshold = median_width * 0.3
i = 0
while i < len(span['chars']) - 1:
char1 = span['chars'][i]
char2 = span['chars'][i + 1]
# Calculate overlap width
x_left = max(char1['bbox'][0], char2['bbox'][0])
x_right = min(char1['bbox'][2], char2['bbox'][2])
if x_right > x_left: # There is overlap
overlap_width = x_right - x_left
if overlap_width > overlap_threshold:
if char1['char'] == char2['char'] or char1['char'] == ' ' or char2['char'] == ' ':
# Determine which character to remove
width1 = char1['bbox'][2] - char1['bbox'][0]
width2 = char2['bbox'][2] - char2['bbox'][0]
if width1 < width2:
# Remove the narrower character
span['chars'].pop(i)
else:
span['chars'].pop(i + 1)
else:
i += 1
# Don't increment i since we need to check the new pair
else:
i += 1
else:
i += 1
return span
def calculate_contrast(img, img_mode) -> float:
"""
计算给定图像的对比度。
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment