# Copyright (c) Opendatalab. All rights reserved.
import copy
import json
import os
from pathlib import Path
from loguru import logger
from bs4 import BeautifulSoup
from fuzzywuzzy import fuzz
from mineru.cli.common import (
    convert_pdf_bytes_to_bytes_by_pypdfium2,
    prepare_env,
    read_fn,
)
from mineru.data.data_reader_writer import FileBasedDataWriter
from mineru.utils.enum_class import MakeMode
from mineru.backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze
from mineru.backend.pipeline.pipeline_middle_json_mkcontent import (
    union_make as pipeline_union_make,
)
from mineru.backend.pipeline.model_json_to_middle_json import (
    result_to_middle_json as pipeline_result_to_middle_json,
)
from mineru.backend.vlm.vlm_middle_json_mkcontent import union_make as vlm_union_make


def test_pipeline_with_two_config():
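    """Parse every PDF/image in the pdfs directory with the pipeline
    backend, once per parse method (txt and ocr), write the outputs, and
    validate the resulting content lists.
    """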
    __dir__ = os.path.dirname(os.path.abspath(__file__))
    pdf_files_dir = os.path.join(__dir__, "pdfs")
    output_dir = os.path.join(__dir__, "output")
    pdf_suffixes = [".pdf"]
    image_suffixes = [".png", ".jpeg", ".jpg"]

    doc_path_list = []
    for doc_path in Path(pdf_files_dir).glob("*"):
        if doc_path.suffix in pdf_suffixes + image_suffixes:
            doc_path_list.append(doc_path)

    # Uncomment to download models from ModelScope instead of the default source.
    # os.environ["MINERU_MODEL_SOURCE"] = "modelscope"

    pdf_file_names = []
    pdf_bytes_list = []
    p_lang_list = []
    for path in doc_path_list:
        file_name = str(Path(path).stem)
        pdf_bytes = read_fn(path)
        pdf_file_names.append(file_name)
        pdf_bytes_list.append(pdf_bytes)
        p_lang_list.append("en")
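
    # Re-encode each document's bytes through pypdfium2 before analysis.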
    for idx, pdf_bytes in enumerate(pdf_bytes_list):
        new_pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes)
        pdf_bytes_list[idx] = new_pdf_bytes

    # Run the pipeline analysis and check both the txt and ocr parse methods.
    infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list = (
        pipeline_doc_analyze(
            pdf_bytes_list,
            p_lang_list,
            parse_method="txt",
        )
    )
    write_infer_result(
        infer_results,
        all_image_lists,
        all_pdf_docs,
        lang_list,
        ocr_enabled_list,
        pdf_file_names,
        output_dir,
        parse_method="txt",
    )
    assert_content("tests/unittest/output/test/txt/test_content_list.json")
    infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list = (
        pipeline_doc_analyze(
            pdf_bytes_list,
            p_lang_list,
            parse_method="ocr",
        )
    )
    write_infer_result(
        infer_results,
        all_image_lists,
        all_pdf_docs,
        lang_list,
        ocr_enabled_list,
        pdf_file_names,
        output_dir,
        parse_method="ocr",
    )
    assert_content("tests/unittest/output/test/ocr/test_content_list.json")


def test_vlm_transformers_with_default_config():
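    """Parse every PDF/image in the pdfs directory with the VLM backend
    (transformers), write the outputs, and validate the resulting content
    list.
    """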
    __dir__ = os.path.dirname(os.path.abspath(__file__))
    pdf_files_dir = os.path.join(__dir__, "pdfs")
    output_dir = os.path.join(__dir__, "output")
    pdf_suffixes = [".pdf"]
    image_suffixes = [".png", ".jpeg", ".jpg"]

    doc_path_list = []
    for doc_path in Path(pdf_files_dir).glob("*"):
        if doc_path.suffix in pdf_suffixes + image_suffixes:
            doc_path_list.append(doc_path)

    # Uncomment to download models from ModelScope instead of the default source.
    # os.environ["MINERU_MODEL_SOURCE"] = "modelscope"

    pdf_file_names = []
    pdf_bytes_list = []
    p_lang_list = []
    for path in doc_path_list:
        file_name = str(Path(path).stem)
        pdf_bytes = read_fn(path)
        pdf_file_names.append(file_name)
        pdf_bytes_list.append(pdf_bytes)
        p_lang_list.append("en")

    for idx, pdf_bytes in enumerate(pdf_bytes_list):
        pdf_file_name = pdf_file_names[idx]
        pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes)
        local_image_dir, local_md_dir = prepare_env(
            output_dir, pdf_file_name, parse_method="vlm"
        )
        image_writer, md_writer = FileBasedDataWriter(
            local_image_dir
        ), FileBasedDataWriter(local_md_dir)
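        # Run the VLM backend through the local transformers loader; it
        # returns the middle JSON plus the raw model inference results.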
        middle_json, infer_result = vlm_doc_analyze(
            pdf_bytes, image_writer=image_writer, backend="transformers"
        )

        pdf_info = middle_json["pdf_info"]

        image_dir = str(os.path.basename(local_image_dir))

        md_content_str = vlm_union_make(pdf_info, MakeMode.MM_MD, image_dir)
        md_writer.write_string(
            f"{pdf_file_name}.md",
            md_content_str,
        )

        content_list = vlm_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir)
        md_writer.write_string(
            f"{pdf_file_name}_content_list.json",
            json.dumps(content_list, ensure_ascii=False, indent=4),
        )

        md_writer.write_string(
            f"{pdf_file_name}_middle.json",
            json.dumps(middle_json, ensure_ascii=False, indent=4),
        )

        model_output = ("\n" + "-" * 50 + "\n").join(infer_result)
        md_writer.write_string(
            f"{pdf_file_name}_model_output.txt",
            model_output,
        )

        logger.info(f"local output dir is {local_md_dir}")
        assert_content("tests/unittest/output/test/vlm/test_content_list.json")


def write_infer_result(
    infer_results,
    all_image_lists,
    all_pdf_docs,
    lang_list,
    ocr_enabled_list,
    pdf_file_names,
    output_dir,
    parse_method,
):
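    """Persist one set of pipeline outputs per document: the markdown,
    the content list, the middle JSON, and the raw model JSON, alongside
    the images extracted during conversion to middle JSON.
    """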
    for idx, model_list in enumerate(infer_results):
        model_json = copy.deepcopy(model_list)
        pdf_file_name = pdf_file_names[idx]
        local_image_dir, local_md_dir = prepare_env(
            output_dir, pdf_file_name, parse_method
        )
        image_writer, md_writer = FileBasedDataWriter(
            local_image_dir
        ), FileBasedDataWriter(local_md_dir)

        images_list = all_image_lists[idx]
        pdf_doc = all_pdf_docs[idx]
        _lang = lang_list[idx]
        _ocr_enable = ocr_enabled_list[idx]
        middle_json = pipeline_result_to_middle_json(
            model_list,
            images_list,
            pdf_doc,
            image_writer,
            _lang,
            _ocr_enable,
            True,
        )

        pdf_info = middle_json["pdf_info"]

        image_dir = str(os.path.basename(local_image_dir))
        # Write the markdown file.
        md_content_str = pipeline_union_make(pdf_info, MakeMode.MM_MD, image_dir)
        md_writer.write_string(
            f"{pdf_file_name}.md",
            md_content_str,
        )

        content_list = pipeline_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir)
        md_writer.write_string(
            f"{pdf_file_name}_content_list.json",
            json.dumps(content_list, ensure_ascii=False, indent=4),
        )

        md_writer.write_string(
            f"{pdf_file_name}_middle.json",
            json.dumps(middle_json, ensure_ascii=False, indent=4),
        )

        md_writer.write_string(
            f"{pdf_file_name}_model.json",
            json.dumps(model_json, ensure_ascii=False, indent=4),
        )

        logger.info(f"local output dir is {local_md_dir}")


def validate_html(html_content):
    """Return True if BeautifulSoup can parse the given HTML.

    Note that html.parser is lenient, so this only rejects grossly
    malformed input rather than enforcing strict HTML validity.
    """
    try:
        BeautifulSoup(html_content, "html.parser")
        return True
    except Exception:
        return False


def assert_content(content_path):
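    """Validate a generated content list against the known test document:
    image/table captions, table HTML structure and cell values, formula
    markers in equations, and fuzzy text similarity. All four block types
    must be present.
    """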
    with open(content_path, "r", encoding="utf-8") as file:
        content_list = json.load(file)
    type_set = set()
    for content_dict in content_list:
        match content_dict["type"]:
            # Image check: only the caption is validated.
            case "image":
                type_set.add("image")
                assert (
                    content_dict["image_caption"][0].strip().lower()
                    == "Figure 1: Figure Caption".lower()
                )
            # Table check: validate the caption, the HTML structure, and the cell contents.
            case "table":
                type_set.add("table")
                assert (
                    content_dict["table_caption"][0].strip().lower()
                    == "Table 1: Table Caption".lower()
                )
                assert validate_html(content_dict["table_body"])
                target_str_list = [
                    "Linear Regression",
                    "0.98740",
                    "1321.2",
                    "2-order Polynomial",
                    "0.99906",
                    "26.4",
                    "3-order Polynomial",
                    "0.99913",
                    "101.2",
                    "4-order Polynomial",
                    "0.99914",
                    "94.1",
                    "Gray Prediction",
                    "0.00617",
                    "687",
                ]
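                # Tolerate minor recognition noise: at least 90% of the
                # expected cell values must appear verbatim in the table body.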
                correct_count = 0
                for target_str in target_str_list:
                    if target_str in content_dict["table_body"]:
                        correct_count += 1

                assert correct_count > 0.9 * len(target_str_list)
            # Equation check: the text must contain formula elements.
            case "equation":
                type_set.add("equation")
                target_str_list = ["$$", "lambda", "frac", "bar"]
                for target_str in target_str_list:
                    assert target_str in content_dict["text"]
            # Text check: fuzzy similarity must exceed 90.
            case "text":
                type_set.add("text")
                assert (
                    fuzz.ratio(
                        content_dict["text"],
                        "Trump graduated from the Wharton School of the University of Pennsylvania with a bachelor's degree in 1968. He became president of his father's real estate business in 1971 and renamed it The Trump Organization.",
                    )
                    > 90
                )
    assert len(type_set) >= 4
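

# Hypothetical direct-run helper: the suite is normally collected by
# pytest (assumed from the test_ naming convention).
if __name__ == "__main__":
    import pytest

    pytest.main([__file__, "-v"])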