"tests/mixed.yaml" did not exist on "ba97c9f719fc22967fc09a81c2e10c9823cf42f2"
pdf_classify.py 7.89 KB
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
# Copyright (c) Opendatalab. All rights reserved.
import re
from io import BytesIO
import numpy as np
import pypdfium2 as pdfium
from loguru import logger
from pdfminer.high_level import extract_text
from pdfminer.pdfparser import PDFParser
from pdfminer.pdfdocument import PDFDocument
from pdfminer.pdfpage import PDFPage
from pdfminer.pdfinterp import PDFResourceManager
from pdfminer.pdfinterp import PDFPageInterpreter
from pdfminer.layout import LAParams, LTImage, LTFigure
from pdfminer.converter import PDFPageAggregator


def classify(pdf_bytes):
    """
    Decide whether a PDF allows direct text extraction or needs OCR.

    Args:
        pdf_bytes: the PDF file as bytes

    Returns:
        str: 'txt' if text can be extracted directly, 'ocr' if OCR is needed
    """

    # Sample up to 10 random pages and load them as a new PDF
    sample_pdf_bytes = extract_pages(pdf_bytes)
    if not sample_pdf_bytes:
        # Sampling failed or the source PDF is empty; fall back to OCR
        return 'ocr'
    pdf = pdfium.PdfDocument(sample_pdf_bytes)
    try:
        # Number of pages in the sampled PDF
        page_count = len(pdf)

        # If the PDF has no pages, fall back to OCR
        if page_count == 0:
            return 'ocr'

        # Number of pages to inspect (at most 10)
        pages_to_check = min(page_count, 10)

        # Threshold: fewer than 50 cleaned characters per page on average
        # is treated as a document that needs OCR
        chars_threshold = 50

        # Check the average character count and look for garbled characters
        if (get_avg_cleaned_chars_per_page(pdf, pages_to_check) < chars_threshold) or detect_invalid_chars(sample_pdf_bytes):
            return 'ocr'

        # Check the image coverage ratio
        if get_high_image_coverage_ratio(sample_pdf_bytes, pages_to_check) >= 0.8:
            return 'ocr'

        return 'txt'

    except Exception as e:
        logger.error(f"Error while classifying PDF: {e}")
        # Default to OCR on error
        return 'ocr'

    finally:
        # Make sure the PDF is closed on every code path
        pdf.close()
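

# A minimal usage sketch (not part of the original module): route a folder of
# PDFs to a text-extraction or an OCR pipeline based on classify(). The
# pdf_dir path and the handle_txt/handle_ocr callables are hypothetical
# placeholders supplied by the caller.
def route_pdfs(pdf_dir, handle_txt, handle_ocr):
    from pathlib import Path
    for pdf_path in Path(pdf_dir).glob('*.pdf'):
        pdf_bytes = pdf_path.read_bytes()
        if classify(pdf_bytes) == 'txt':
            handle_txt(pdf_path, pdf_bytes)
        else:
            handle_ocr(pdf_path, pdf_bytes)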


def get_avg_cleaned_chars_per_page(pdf_doc, pages_to_check):
    """Return the average number of non-whitespace characters per inspected page."""
    # Total characters after stripping whitespace
    cleaned_total_chars = 0

    # Inspect the text of the first pages_to_check pages
    for i in range(pages_to_check):
        page = pdf_doc[i]
        text_page = page.get_textpage()
        text = text_page.get_text_bounded()

        # Clean the extracted text by removing all whitespace characters
        cleaned_text = re.sub(r'\s+', '', text)
        cleaned_total_chars += len(cleaned_text)

    # Average character count per page
    avg_cleaned_chars_per_page = cleaned_total_chars / pages_to_check

    # logger.debug(f"PDF analysis: {avg_cleaned_chars_per_page:.1f} cleaned chars per page on average")

    return avg_cleaned_chars_per_page
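
# Worked example for the threshold used in classify(): sampling 5 pages that
# yield 10, 0, 600, 0 and 40 non-whitespace characters averages
# (10 + 0 + 600 + 0 + 40) / 5 = 130 characters per page, which clears the
# 50-character threshold even though most pages are near-empty.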


def get_high_image_coverage_ratio(sample_pdf_bytes, pages_to_check):
    """Return the fraction of inspected pages whose image coverage is >= 80%."""
    # In-memory file object
    pdf_stream = BytesIO(sample_pdf_bytes)

    # PDF parser
    parser = PDFParser(pdf_stream)

    # PDF document object
    document = PDFDocument(parser)

    # Check whether the document permits text extraction
    if not document.is_extractable:
        # logger.warning("PDF does not allow content extraction")
        return 1.0  # Treat as fully covered, since the content cannot be extracted

    # Resource manager and layout parameters
    rsrcmgr = PDFResourceManager()
    laparams = LAParams(
        line_overlap=0.5,
        char_margin=2.0,
        line_margin=0.5,
        word_margin=0.1,
        boxes_flow=None,
        detect_vertical=False,
        all_texts=False,
    )

    # Page aggregator
    device = PDFPageAggregator(rsrcmgr, laparams=laparams)

    # Page interpreter
    interpreter = PDFPageInterpreter(rsrcmgr, device)

    # Number of pages with high image coverage
    high_image_coverage_pages = 0
    page_count = 0

    # Iterate over the pages
    for page in PDFPage.create_pages(document):
        # Limit the number of pages inspected
        if page_count >= pages_to_check:
            break

        # Process the page
        interpreter.process_page(page)
        layout = device.get_result()

        # Page dimensions
        page_width = layout.width
        page_height = layout.height
        page_area = page_width * page_height

        # Total area covered by images
        image_area = 0

        # Iterate over the page elements
        for element in layout:
            # Count image and figure elements
            if isinstance(element, (LTImage, LTFigure)):
                # Add the area of the element's bounding box
                image_area += element.width * element.height

        # Coverage ratio for this page
        coverage_ratio = min(image_area / page_area, 1.0) if page_area > 0 else 0
        # logger.debug(f"PDF analysis: page {page_count + 1} image coverage: {coverage_ratio:.2f}")

        # A page counts as highly covered at 80% or more
        if coverage_ratio >= 0.8:
            high_image_coverage_pages += 1

        page_count += 1

    # Release resources
    pdf_stream.close()

    # If no pages were processed, return 0
    if page_count == 0:
        return 0.0

    # Fraction of pages with high image coverage
    high_coverage_ratio = high_image_coverage_pages / page_count
    # logger.debug(f"PDF analysis: ratio of highly image-covered pages: {high_coverage_ratio:.2f}")

    return high_coverage_ratio
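
# Worked example for the per-page 0.8 threshold: on an A4 page (595 x 842 pt,
# area ≈ 501,000 pt^2), a full-page scan image of 595 x 842 pt gives a
# coverage ratio of 1.0 and counts as highly covered, while a half-page
# figure of 595 x 421 pt gives ≈ 0.5 and does not.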


def extract_pages(src_pdf_bytes: bytes) -> bytes:
    """
    Randomly extract up to 10 pages from the PDF bytes and return them as a new PDF.

    Args:
        src_pdf_bytes: the source PDF as bytes

    Returns:
        bytes: the sampled PDF as bytes, or b'' on failure
    """

    # Load the PDF from bytes
    pdf = pdfium.PdfDocument(src_pdf_bytes)

    # Number of pages in the PDF
    total_page = len(pdf)
    if total_page == 0:
        # The PDF has no pages; return an empty document
        logger.warning("PDF is empty, return empty document")
        pdf.close()
        return b''

    # Select at most 10 pages
    select_page_cnt = min(10, total_page)

    # Randomly choose pages from the full page range
    page_indices = np.random.choice(total_page, select_page_cnt, replace=False).tolist()

    # Create a new PDF document
    sample_docs = pdfium.PdfDocument.new()

    try:
        # Import the selected pages into the new document
        sample_docs.import_pages(pdf, page_indices)

        # Save the new PDF into an in-memory buffer
        output_buffer = BytesIO()
        sample_docs.save(output_buffer)

        # Return the bytes
        return output_buffer.getvalue()
    except Exception as e:
        logger.exception(e)
        return b''  # Return empty bytes on error
    finally:
        # Close both documents on every code path
        sample_docs.close()
        pdf.close()
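
# Note: extract_pages draws pages with np.random.choice, so classify() can
# return different results across runs on borderline documents. A minimal
# sketch for reproducible sampling (assumes seeding NumPy's global RNG is
# acceptable for the caller):
#
#     np.random.seed(42)  # fix the page sample before calling classify()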


def detect_invalid_chars(sample_pdf_bytes: bytes) -> bool:
    """
    Detect whether the PDF contains garbled (unmappable) characters.

    pdfminer is slow, so this expects a pre-sampled PDF of roughly 10 pages.
    """
    sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
    laparams = LAParams(
        line_overlap=0.5,
        char_margin=2.0,
        line_margin=0.5,
        word_margin=0.1,
        boxes_flow=None,
        detect_vertical=False,
        all_texts=False,
    )
    text = extract_text(pdf_file=sample_pdf_file_like_object, laparams=laparams)
    text = text.replace("\n", "")
    # logger.info(text)
    # Garbled text extracted by pdfminer shows up as (cid:xxx) tokens
    cid_pattern = re.compile(r'\(cid:\d+\)')
    matches = cid_pattern.findall(text)
    cid_count = len(matches)
    cid_len = sum(len(match) for match in matches)
    text_len = len(text)
    if text_len == 0:
        cid_chars_ratio = 0
    else:
        cid_chars_ratio = cid_count / (cid_count + text_len - cid_len)
    # logger.debug(f"cid_count: {cid_count}, text_len: {text_len}, cid_chars_ratio: {cid_chars_ratio}")
    # If more than 5% of the text is garbled, treat the document as garbled
    return cid_chars_ratio > 0.05  # True: garbled document, False: normal document
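
# Worked example for the 5% cutoff: pdfminer renders each unmapped glyph as a
# token such as "(cid:123)" (9 characters). With 1000 extracted characters
# containing 20 such tokens (180 characters in total), the ratio is
# 20 / (20 + 1000 - 180) = 20 / 840 ≈ 0.024, below 0.05, so the document is
# treated as readable.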


if __name__ == '__main__':
    with open('/Users/myhloli/pdf/luanma2x10.pdf', 'rb') as f:
        p_bytes = f.read()
        logger.info(f"PDF classification result: {classify(p_bytes)}")