Unverified Commit 6a3cdb8d authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #1970 from myhloli/dev

feat(pre_proc): add function to remove x-overlapping characters in spans
parents 07eaa2d7 a2808f3a
...@@ -30,7 +30,6 @@ class UnimernetModel(object): ...@@ -30,7 +30,6 @@ class UnimernetModel(object):
self.model = self.model.to(dtype=torch.float16) self.model = self.model.to(dtype=torch.float16)
self.model.eval() self.model.eval()
def predict(self, mfd_res, image): def predict(self, mfd_res, image):
formula_list = [] formula_list = []
mf_image_list = [] mf_image_list = []
......
...@@ -34,7 +34,7 @@ from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table ...@@ -34,7 +34,7 @@ from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split_v2 from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split_v2
from magic_pdf.pre_proc.ocr_dict_merge import fill_spans_in_blocks, fix_block_spans_v2, fix_discarded_block from magic_pdf.pre_proc.ocr_dict_merge import fill_spans_in_blocks, fix_block_spans_v2, fix_discarded_block
from magic_pdf.pre_proc.ocr_span_list_modify import get_qa_need_list_v2, remove_overlaps_low_confidence_spans, \ from magic_pdf.pre_proc.ocr_span_list_modify import get_qa_need_list_v2, remove_overlaps_low_confidence_spans, \
remove_overlaps_min_spans, check_chars_is_overlap_in_span remove_overlaps_min_spans, remove_x_overlapping_chars
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
...@@ -56,14 +56,6 @@ def __replace_STX_ETX(text_str: str): ...@@ -56,14 +56,6 @@ def __replace_STX_ETX(text_str: str):
return text_str return text_str
def __replace_0xfffd(text_str: str):
"""Replace \ufffd, as these characters become garbled when extracted using pymupdf."""
if text_str:
s = text_str.replace('\ufffd', " ")
return s
return text_str
# 连写字符拆分 # 连写字符拆分
def __replace_ligatures(text: str): def __replace_ligatures(text: str):
ligatures = { ligatures = {
...@@ -76,16 +68,17 @@ def chars_to_content(span): ...@@ -76,16 +68,17 @@ def chars_to_content(span):
# 检查span中的char是否为空 # 检查span中的char是否为空
if len(span['chars']) == 0: if len(span['chars']) == 0:
pass pass
# span['content'] = ''
elif check_chars_is_overlap_in_span(span['chars']):
pass
else: else:
# 先给chars按char['bbox']的中心点的x坐标排序 # 先给chars按char['bbox']的中心点的x坐标排序
span['chars'] = sorted(span['chars'], key=lambda x: (x['bbox'][0] + x['bbox'][2]) / 2) span['chars'] = sorted(span['chars'], key=lambda x: (x['bbox'][0] + x['bbox'][2]) / 2)
# 求char的平均宽度 # Calculate the width of each character
char_width_sum = sum([char['bbox'][2] - char['bbox'][0] for char in span['chars']]) char_widths = [char['bbox'][2] - char['bbox'][0] for char in span['chars']]
char_avg_width = char_width_sum / len(span['chars']) # Calculate the median width
median_width = statistics.median(char_widths)
# 通过x轴重叠比率移除一部分char
span = remove_x_overlapping_chars(span, median_width)
content = '' content = ''
for char in span['chars']: for char in span['chars']:
...@@ -93,13 +86,12 @@ def chars_to_content(span): ...@@ -93,13 +86,12 @@ def chars_to_content(span):
# 如果下一个char的x0和上一个char的x1距离超过0.25个字符宽度,则需要在中间插入一个空格 # 如果下一个char的x0和上一个char的x1距离超过0.25个字符宽度,则需要在中间插入一个空格
char1 = char char1 = char
char2 = span['chars'][span['chars'].index(char) + 1] if span['chars'].index(char) + 1 < len(span['chars']) else None char2 = span['chars'][span['chars'].index(char) + 1] if span['chars'].index(char) + 1 < len(span['chars']) else None
if char2 and char2['bbox'][0] - char1['bbox'][2] > char_avg_width * 0.25 and char['c'] != ' ' and char2['c'] != ' ': if char2 and char2['bbox'][0] - char1['bbox'][2] > median_width * 0.25 and char['c'] != ' ' and char2['c'] != ' ':
content += f"{char['c']} " content += f"{char['c']} "
else: else:
content += char['c'] content += char['c']
content = __replace_ligatures(content) span['content'] = __replace_ligatures(content)
span['content'] = __replace_0xfffd(content)
del span['chars'] del span['chars']
...@@ -114,10 +106,6 @@ def fill_char_in_spans(spans, all_chars): ...@@ -114,10 +106,6 @@ def fill_char_in_spans(spans, all_chars):
spans = sorted(spans, key=lambda x: x['bbox'][1]) spans = sorted(spans, key=lambda x: x['bbox'][1])
for char in all_chars: for char in all_chars:
# 跳过非法bbox的char
# x1, y1, x2, y2 = char['bbox']
# if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01:
# continue
for span in spans: for span in spans:
if calculate_char_in_span(char['bbox'], span['bbox'], char['c']): if calculate_char_in_span(char['bbox'], span['bbox'], char['c']):
......
...@@ -41,6 +41,55 @@ def check_chars_is_overlap_in_span(chars): ...@@ -41,6 +41,55 @@ def check_chars_is_overlap_in_span(chars):
return False return False
def remove_x_overlapping_chars(span, median_width):
"""
Remove characters from a span that overlap significantly on the x-axis.
Args:
median_width:
span (dict): A span containing a list of chars, each with bbox coordinates
in the format [x0, y0, x1, y1]
Returns:
dict: The span with overlapping characters removed
"""
if 'chars' not in span or len(span['chars']) < 2:
return span
overlap_threshold = median_width * 0.3
i = 0
while i < len(span['chars']) - 1:
char1 = span['chars'][i]
char2 = span['chars'][i + 1]
# Calculate overlap width
x_left = max(char1['bbox'][0], char2['bbox'][0])
x_right = min(char1['bbox'][2], char2['bbox'][2])
if x_right > x_left: # There is overlap
overlap_width = x_right - x_left
if overlap_width > overlap_threshold:
# Determine which character to remove
width1 = char1['bbox'][2] - char1['bbox'][0]
width2 = char2['bbox'][2] - char2['bbox'][0]
if width1 < width2:
# Remove the narrower character
span['chars'].pop(i)
else:
span['chars'].pop(i + 1)
# Don't increment i since we need to check the new pair
else:
i += 1
else:
i += 1
return span
def remove_overlaps_min_spans(spans): def remove_overlaps_min_spans(spans):
dropped_spans = [] dropped_spans = []
# 删除重叠spans中较小的那些 # 删除重叠spans中较小的那些
......
...@@ -4,12 +4,12 @@ from huggingface_hub import snapshot_download ...@@ -4,12 +4,12 @@ from huggingface_hub import snapshot_download
if __name__ == "__main__": if __name__ == "__main__":
mineru_patterns = [ mineru_patterns = [
"models/Layout/LayoutLMv3/*", # "models/Layout/LayoutLMv3/*",
"models/Layout/YOLO/*", "models/Layout/YOLO/*",
"models/MFD/YOLO/*", "models/MFD/YOLO/*",
"models/MFR/unimernet_small_2501/*", "models/MFR/unimernet_hf_small_2503/*",
"models/TabRec/TableMaster/*", # "models/TabRec/TableMaster/*",
"models/TabRec/StructEqTable/*", # "models/TabRec/StructEqTable/*",
] ]
model_dir = snapshot_download( model_dir = snapshot_download(
"opendatalab/PDF-Extract-Kit-1.0", "opendatalab/PDF-Extract-Kit-1.0",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment