Unverified commit 6e1fba93, authored by Xiaomeng Zhao, committed by GitHub

Merge pull request #1664 from myhloli/dev

feat(pdf_parse): improve OCR processing and contrast filtering 
parents 9bb2d581 5561ac95
@@ -23,7 +23,7 @@ def classify(pdf_bytes: bytes) -> SupportedPdfParseMethod:
         pdf_meta['image_info_per_page'],
         pdf_meta['text_len_per_page'],
         pdf_meta['imgs_per_page'],
-        pdf_meta['text_layout_per_page'],
+        # pdf_meta['text_layout_per_page'],
         pdf_meta['invalid_chars'],
     )
     if is_text_pdf:
...
@@ -305,7 +305,8 @@ def classify_by_img_narrow_strips(page_width, page_height, img_sz_list):
 def classify(total_page: int, page_width, page_height, img_sz_list: list, text_len_list: list, img_num_list: list,
-             text_layout_list: list, invalid_chars: bool):
+             # text_layout_list: list,
+             invalid_chars: bool):
     """
     Image and page dimensions here are in pts.
     :param total_page:
@@ -321,7 +322,7 @@ def classify(total_page: int, page_width, page_height, img_sz_list: list, text_l
         'by_text_len': classify_by_text_len(text_len_list, total_page),
         'by_avg_words': classify_by_avg_words(text_len_list),
         'by_img_num': classify_by_img_num(img_sz_list, img_num_list),
-        'by_text_layout': classify_by_text_layout(text_layout_list),
+        # 'by_text_layout': classify_by_text_layout(text_layout_list),
         'by_img_narrow_strips': classify_by_img_narrow_strips(page_width, page_height, img_sz_list),
         'by_invalid_chars': invalid_chars,
     }
@@ -332,9 +333,10 @@ def classify(total_page: int, page_width, page_height, img_sz_list: list, text_l
         return False, results
     else:
         logger.warning(
-            f"pdf is not classified by area and text_len, by_image_area: {results['by_image_area']},"
+            f"OCR needed based on classification result, by_image_area: {results['by_image_area']},"
             f" by_text: {results['by_text_len']}, by_avg_words: {results['by_avg_words']}, by_img_num: {results['by_img_num']},"
-            f" by_text_layout: {results['by_text_layout']}, by_img_narrow_strips: {results['by_img_narrow_strips']},"
+            # f" by_text_layout: {results['by_text_layout']},"
+            f" by_img_narrow_strips: {results['by_img_narrow_strips']},"
             f" by_invalid_chars: {results['by_invalid_chars']}",
             file=sys.stderr)  # this case helps quickly identify unusual PDFs so the classification algorithm can be tuned for them
         return False, results
...
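Review note: with the text-layout check retired, classify() now votes over the six remaining boolean signals in `results`. A minimal sketch of how those votes appear to combine, inferred from the two `return False, results` paths visible in this hunk; the exact all/any combination is an assumption, not shown verbatim in the diff:

```python
# Sketch only: mirrors the names in this diff; the combination logic is inferred.
def combine_results(results: dict) -> bool:
    votes = list(results.values())  # e.g. {'by_image_area': True, 'by_text_len': False, ...}
    if all(votes):
        return True   # unanimous: parse as a text PDF
    if not any(votes):
        return False  # unanimous: route to OCR
    # Mixed votes land in the logger.warning branch above, then default to OCR.
    return False
```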
@@ -356,9 +356,9 @@ def pdf_meta_scan(pdf_bytes: bytes):
     # logger.info(f"image_info_per_page: {image_info_per_page}, junk_img_bojids: {junk_img_bojids}")
     text_len_per_page = get_pdf_textlen_per_page(doc)
     # logger.info(f"text_len_per_page: {text_len_per_page}")
-    text_layout_per_page = get_pdf_text_layout_per_page(doc)
+    # text_layout_per_page = get_pdf_text_layout_per_page(doc)
     # logger.info(f"text_layout_per_page: {text_layout_per_page}")
-    text_language = get_language(doc)
+    # text_language = get_language(doc)
     # logger.info(f"text_language: {text_language}")
     invalid_chars = check_invalid_chars(pdf_bytes)
     # logger.info(f"invalid_chars: {invalid_chars}")
@@ -372,8 +372,8 @@ def pdf_meta_scan(pdf_bytes: bytes):
         'page_height_pts': int(page_height_pts),
         'image_info_per_page': image_info_per_page,
         'text_len_per_page': text_len_per_page,
-        'text_layout_per_page': text_layout_per_page,
-        'text_language': text_language,
+        # 'text_layout_per_page': text_layout_per_page,
+        # 'text_language': text_language,
         # "svgs_per_page": svgs_per_page,
         'imgs_per_page': imgs_per_page,  # added: per-page img count list
         'junk_img_bojids': junk_img_bojids,  # added: bojid list of junk images
...
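Review note: with text_layout_per_page and text_language dropped, the metadata dict returned by pdf_meta_scan() shrinks accordingly. For reference, a sketch of its shape reconstructed from the keys visible in this hunk and in the classify() call above; the values are invented for illustration:

```python
# Illustrative shape only; values are made up for the example.
pdf_meta = {
    'page_width_pts': 612,
    'page_height_pts': 792,
    'image_info_per_page': [[...], [...]],  # per-page image info
    'text_len_per_page': [1423, 0, 980],    # extracted text length per page
    'imgs_per_page': [1, 3, 0],             # per-page image counts
    'junk_img_bojids': [12, 47],            # object ids of junk images
    'invalid_chars': False,                 # from check_invalid_chars()
}
```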
@@ -4,6 +4,7 @@ from loguru import logger
 import re
 from io import BytesIO
 from pdfminer.high_level import extract_text
+from pdfminer.layout import LAParams


 def calculate_sample_count(total_page: int):
@@ -41,7 +42,16 @@ def detect_invalid_chars(src_pdf_bytes: bytes) -> bool:
     sample_docs = extract_pages(src_pdf_bytes)
     sample_pdf_bytes = sample_docs.tobytes()
     sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
-    text = extract_text(sample_pdf_file_like_object)
+    laparams = LAParams(
+        line_overlap=0.5,
+        char_margin=2.0,
+        line_margin=0.5,
+        word_margin=0.1,
+        boxes_flow=None,
+        detect_vertical=False,
+        all_texts=False,
+    )
+    text = extract_text(pdf_file=sample_pdf_file_like_object, laparams=laparams)
     text = text.replace("\n", "")
     # logger.info(text)
     '''Garbled text extracted by pdfminer shows the characteristic pattern (cid:xxx)'''
...
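Review note: the LAParams values passed above match the pdfminer.six defaults except `boxes_flow=None`, which disables advanced layout analysis and speeds up extraction. For context, a minimal sketch of the kind of `(cid:xxx)` check the extracted text presumably feeds into, per the docstring comment; the regex and ratio threshold here are illustrative assumptions, not taken from this diff:

```python
import re

def looks_garbled(text: str, threshold: float = 0.1) -> bool:
    """Sketch: pdfminer renders glyphs it cannot map to Unicode as '(cid:NNN)'.
    Treat the text as garbled when such tokens dominate; threshold is assumed."""
    cid_tokens = re.findall(r"\(cid:\d+\)", text)
    words = max(len(text.split()), 1)  # avoid division by zero on empty text
    return len(cid_tokens) / words > threshold
```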
 # Copyright (c) Opendatalab. All rights reserved.
+import time
 from collections import Counter
 from uuid import uuid4
@@ -102,9 +103,9 @@ class YOLOv11LangDetModel(object):
             temp_images = split_images(image)
             for temp_image in temp_images:
                 all_images.append(resize_images_to_224(temp_image))
-        images_lang_res = self.batch_predict(all_images, batch_size=8)
-        # logger.info(f"images_lang_res: {images_lang_res}")
+        # langdetect_start = time.time()
+        images_lang_res = self.batch_predict(all_images, batch_size=256)
+        # logger.info(f"image number of langdetect: {len(images_lang_res)}, langdetect time: {round(time.time() - langdetect_start, 2)}")
         if len(images_lang_res) > 0:
             count_dict = Counter(images_lang_res)
             language = max(count_dict, key=count_dict.get)
...
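Review note: raising batch_size from 8 to 256 cuts the number of forward passes on image-heavy documents by up to 32x; the final language is still chosen by majority vote over the per-crop predictions, as the Counter lines show. A self-contained sketch of that chunk-and-vote pattern; `batch_predict` below is a stand-in callable, not the model's actual method body:

```python
from collections import Counter
from typing import Callable, List

def majority_language(images: List, batch_predict: Callable[[List], List[str]],
                      batch_size: int = 256) -> str:
    """Predict in fixed-size batches, then return the most frequent label."""
    labels: List[str] = []
    for i in range(0, len(images), batch_size):
        labels.extend(batch_predict(images[i:i + batch_size]))
    count_dict = Counter(labels)
    return max(count_dict, key=count_dict.get)  # mirrors the vote in the diff
```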
@@ -6,8 +6,10 @@ import statistics
 import time
 from typing import List

+import cv2
 import fitz
 import torch
+import numpy as np
 from loguru import logger

 from magic_pdf.config.enums import SupportedPdfParseMethod
@@ -127,16 +129,15 @@ def fill_char_in_spans(spans, all_chars):
                 span['chars'].append(char)
                 break

-    empty_spans = []
+    need_ocr_spans = []
     for span in spans:
         chars_to_content(span)
         # some spans contain no characters, just one or two empty placeholders; filter them by width/height and content length
         if len(span['content']) * span['height'] < span['width'] * 0.5:
             # logger.info(f"maybe empty span: {len(span['content'])}, {span['height']}, {span['width']}")
-            empty_spans.append(span)
+            need_ocr_spans.append(span)
         del span['height'], span['width']
-    return empty_spans
+    return need_ocr_spans


 # use the more robust center-point coordinate check
@@ -190,6 +191,31 @@ def remove_tilted_line(text_blocks):
                 block['lines'].remove(line)


+def calculate_contrast(img, img_mode) -> float:
+    """
+    Compute the contrast of the given image.
+    :param img: the image, as a numpy.ndarray
+    :param img_mode: the image's color channel order, 'rgb' or 'bgr'
+    :return: the image's contrast value
+    """
+    if img_mode == 'rgb':
+        # convert the RGB image to grayscale
+        gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+    elif img_mode == 'bgr':
+        # convert the BGR image to grayscale
+        gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+    else:
+        raise ValueError("Invalid image mode. Please provide 'rgb' or 'bgr'.")
+    # compute the mean and standard deviation
+    mean_value = np.mean(gray_img)
+    std_dev = np.std(gray_img)
+    # contrast is defined as the standard deviation divided by the mean (plus a small constant to avoid division by zero)
+    contrast = std_dev / (mean_value + 1e-6)
+    # logger.info(f"contrast: {contrast}")
+    return round(contrast, 2)
+
+
 def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
     # cids are rendered as 0xfffd, and ligatures are split into separate characters
     # text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
@@ -274,9 +300,9 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
             span['chars'] = []
             new_spans.append(span)

-    empty_spans = fill_char_in_spans(new_spans, all_pymu_chars)
+    need_ocr_spans = fill_char_in_spans(new_spans, all_pymu_chars)

-    if len(empty_spans) > 0:
+    if len(need_ocr_spans) > 0:

         # initialize the OCR model
         atom_model_manager = AtomModelSingleton()
@@ -287,9 +313,15 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
                 lang=lang
             )

-        for span in empty_spans:
+        for span in need_ocr_spans:
             # crop the span's bbox and run OCR on it
             span_img = cut_image_to_pil_image(span['bbox'], pdf_page, mode='cv2')
+            # compute the span's contrast; spans with contrast below 0.20 are not OCR'd
+            if calculate_contrast(span_img, img_mode='bgr') <= 0.20:
+                spans.remove(span)
+                continue
             ocr_res = ocr_model.ocr(span_img, det=False)
             if ocr_res and len(ocr_res) > 0:
                 if len(ocr_res[0]) > 0:
...
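Review note: the contrast gate drops near-blank span crops before they reach the OCR model, saving inference time and avoiding hallucinated text on empty regions. Since std/mean is the coefficient of variation, the score is roughly invariant to overall brightness. A quick sanity check of calculate_contrast against the 0.20 threshold, using synthetic crops; the arrays and expected values are illustrative, not from the PR:

```python
import cv2  # used inside calculate_contrast
import numpy as np

# Nearly uniform light-gray crop: std/mean ~ 0, well below the 0.20 cutoff,
# so the span would be removed and OCR skipped.
flat = np.full((32, 128, 3), 200, dtype=np.uint8)

# White crop with a black text-like band: std/mean ~ 0.58, above the cutoff,
# so the span would be kept and sent to OCR.
textlike = np.full((32, 128, 3), 255, dtype=np.uint8)
textlike[12:20, :] = 0

print(calculate_contrast(flat, img_mode='bgr'))      # -> 0.0
print(calculate_contrast(textlike, img_mode='bgr'))  # -> ~0.58
```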