Unverified Commit 85a4750d authored by Xiaomeng Zhao, committed by GitHub

Merge pull request #3026 from Sidney233/dev

Dev
parents 206ed770 a7e75dc0
import os
import pytest
from magic_pdf.data.data_reader_writer import MultiBucketS3DataReader
from magic_pdf.data.read_api import (read_jsonl, read_local_images,
read_local_pdfs)
from magic_pdf.data.schemas import S3Config
def test_read_local_pdfs():
datasets = read_local_pdfs('tests/unittest/test_data/assets/pdfs')
assert len(datasets) == 2
assert len(datasets[0]) > 0
assert len(datasets[1]) > 0
assert datasets[0].get_page(0).get_page_info().w > 0
assert datasets[0].get_page(0).get_page_info().h > 0
def test_read_local_images():
datasets = read_local_images('tests/unittest/test_data/assets/pngs', suffixes=['.png'])
assert len(datasets) == 2
assert len(datasets[0]) == 1
assert len(datasets[1]) == 1
assert datasets[0].get_page(0).get_page_info().w > 0
assert datasets[0].get_page(0).get_page_info().h > 0
@pytest.mark.skipif(
os.getenv('S3_ACCESS_KEY_2', None) is None, reason='need s3 config!'
)
def test_read_json():
"""test multi bucket s3 reader writer must config s3 config in the
environment export S3_BUCKET=xxx export S3_ACCESS_KEY=xxx export
S3_SECRET_KEY=xxx export S3_ENDPOINT=xxx.
export S3_BUCKET_2=xxx export S3_ACCESS_KEY_2=xxx export S3_SECRET_KEY_2=xxx export S3_ENDPOINT_2=xxx
"""
bucket = os.getenv('S3_BUCKET', '')
ak = os.getenv('S3_ACCESS_KEY', '')
sk = os.getenv('S3_SECRET_KEY', '')
endpoint_url = os.getenv('S3_ENDPOINT', '')
bucket_2 = os.getenv('S3_BUCKET_2', '')
ak_2 = os.getenv('S3_ACCESS_KEY_2', '')
sk_2 = os.getenv('S3_SECRET_KEY_2', '')
endpoint_url_2 = os.getenv('S3_ENDPOINT_2', '')
s3configs = [
S3Config(
bucket_name=bucket, access_key=ak, secret_key=sk, endpoint_url=endpoint_url
),
S3Config(
bucket_name=bucket_2,
access_key=ak_2,
secret_key=sk_2,
endpoint_url=endpoint_url_2,
),
]
reader = MultiBucketS3DataReader(bucket, s3configs)
datasets = read_jsonl(
f's3://{bucket}/meta-index/scihub/v001/scihub/part-66210c190659-000026.jsonl',
reader,
)
assert len(datasets) > 0
assert len(datasets[0]) == 10
datasets = read_jsonl('tests/unittest/test_data/assets/jsonl/test_01.jsonl', reader)
assert len(datasets) == 1
assert len(datasets[0]) == 10
datasets = read_jsonl('tests/unittest/test_data/assets/jsonl/test_02.jsonl')
assert len(datasets) == 1
assert len(datasets[0]) == 1
# Copyright (c) Opendatalab. All rights reserved.
import copy
import json
import os
from pathlib import Path
from loguru import logger
from bs4 import BeautifulSoup
from fuzzywuzzy import fuzz
from mineru.cli.common import (
convert_pdf_bytes_to_bytes_by_pypdfium2,
prepare_env,
read_fn,
)
from mineru.data.data_reader_writer import FileBasedDataWriter
from mineru.utils.enum_class import MakeMode
from mineru.backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze
from mineru.backend.pipeline.pipeline_middle_json_mkcontent import (
union_make as pipeline_union_make,
)
from mineru.backend.pipeline.model_json_to_middle_json import (
result_to_middle_json as pipeline_result_to_middle_json,
)
from mineru.backend.vlm.vlm_middle_json_mkcontent import union_make as vlm_union_make
def test_pipeline_with_two_config():
__dir__ = os.path.dirname(os.path.abspath(__file__))
pdf_files_dir = os.path.join(__dir__, "pdfs")
output_dir = os.path.join(__dir__, "output")
pdf_suffixes = [".pdf"]
image_suffixes = [".png", ".jpeg", ".jpg"]
doc_path_list = []
for doc_path in Path(pdf_files_dir).glob("*"):
if doc_path.suffix in pdf_suffixes + image_suffixes:
doc_path_list.append(doc_path)
os.environ["MINERU_MODEL_SOURCE"] = "modelscope"
pdf_file_names = []
pdf_bytes_list = []
p_lang_list = []
for path in doc_path_list:
file_name = str(Path(path).stem)
pdf_bytes = read_fn(path)
pdf_file_names.append(file_name)
pdf_bytes_list.append(pdf_bytes)
p_lang_list.append("en")
for idx, pdf_bytes in enumerate(pdf_bytes_list):
new_pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes)
pdf_bytes_list[idx] = new_pdf_bytes
# Get the pipeline analysis results and test both the txt and ocr parse methods
infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list = (
pipeline_doc_analyze(
pdf_bytes_list,
p_lang_list,
parse_method="txt",
)
)
write_infer_result(
infer_results,
all_image_lists,
all_pdf_docs,
lang_list,
ocr_enabled_list,
pdf_file_names,
output_dir,
parse_method="txt",
)
assert_content("tests/unittest/output/test/txt/test_content_list.json")
infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list = (
pipeline_doc_analyze(
pdf_bytes_list,
p_lang_list,
parse_method="ocr",
)
)
write_infer_result(
infer_results,
all_image_lists,
all_pdf_docs,
lang_list,
ocr_enabled_list,
pdf_file_names,
output_dir,
parse_method="ocr",
)
assert_content("tests/unittest/output/test/ocr/test_content_list.json")
def test_vlm_transformers_with_default_config():
__dir__ = os.path.dirname(os.path.abspath(__file__))
pdf_files_dir = os.path.join(__dir__, "pdfs")
output_dir = os.path.join(__dir__, "output")
pdf_suffixes = [".pdf"]
image_suffixes = [".png", ".jpeg", ".jpg"]
doc_path_list = []
for doc_path in Path(pdf_files_dir).glob("*"):
if doc_path.suffix in pdf_suffixes + image_suffixes:
doc_path_list.append(doc_path)
os.environ["MINERU_MODEL_SOURCE"] = "modelscope"
pdf_file_names = []
pdf_bytes_list = []
p_lang_list = []
for path in doc_path_list:
file_name = str(Path(path).stem)
pdf_bytes = read_fn(path)
pdf_file_names.append(file_name)
pdf_bytes_list.append(pdf_bytes)
p_lang_list.append("en")
for idx, pdf_bytes in enumerate(pdf_bytes_list):
pdf_file_name = pdf_file_names[idx]
pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes)
local_image_dir, local_md_dir = prepare_env(
output_dir, pdf_file_name, parse_method="vlm"
)
image_writer, md_writer = FileBasedDataWriter(
local_image_dir
), FileBasedDataWriter(local_md_dir)
middle_json, infer_result = vlm_doc_analyze(
pdf_bytes, image_writer=image_writer, backend="transformers"
)
pdf_info = middle_json["pdf_info"]
image_dir = str(os.path.basename(local_image_dir))
md_content_str = vlm_union_make(pdf_info, MakeMode.MM_MD, image_dir)
md_writer.write_string(
f"{pdf_file_name}.md",
md_content_str,
)
content_list = vlm_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir)
md_writer.write_string(
f"{pdf_file_name}_content_list.json",
json.dumps(content_list, ensure_ascii=False, indent=4),
)
md_writer.write_string(
f"{pdf_file_name}_middle.json",
json.dumps(middle_json, ensure_ascii=False, indent=4),
)
model_output = ("\n" + "-" * 50 + "\n").join(infer_result)
md_writer.write_string(
f"{pdf_file_name}_model_output.txt",
model_output,
)
logger.info(f"local output dir is {local_md_dir}")
assert_content("tests/unittest/output/test/vlm/test_content_list.json")
def write_infer_result(
infer_results,
all_image_lists,
all_pdf_docs,
lang_list,
ocr_enabled_list,
pdf_file_names,
output_dir,
parse_method,
):
for idx, model_list in enumerate(infer_results):
model_json = copy.deepcopy(model_list)
pdf_file_name = pdf_file_names[idx]
local_image_dir, local_md_dir = prepare_env(
output_dir, pdf_file_name, parse_method
)
image_writer, md_writer = FileBasedDataWriter(
local_image_dir
), FileBasedDataWriter(local_md_dir)
images_list = all_image_lists[idx]
pdf_doc = all_pdf_docs[idx]
_lang = lang_list[idx]
_ocr_enable = ocr_enabled_list[idx]
middle_json = pipeline_result_to_middle_json(
model_list,
images_list,
pdf_doc,
image_writer,
_lang,
_ocr_enable,
True,
)
pdf_info = middle_json["pdf_info"]
image_dir = str(os.path.basename(local_image_dir))
# Write the markdown file
md_content_str = pipeline_union_make(pdf_info, MakeMode.MM_MD, image_dir)
md_writer.write_string(
f"{pdf_file_name}.md",
md_content_str,
)
content_list = pipeline_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir)
md_writer.write_string(
f"{pdf_file_name}_content_list.json",
json.dumps(content_list, ensure_ascii=False, indent=4),
)
md_writer.write_string(
f"{pdf_file_name}_middle.json",
json.dumps(middle_json, ensure_ascii=False, indent=4),
)
md_writer.write_string(
f"{pdf_file_name}_model.json",
json.dumps(model_json, ensure_ascii=False, indent=4),
)
logger.info(f"local output dir is {local_md_dir}")
def validate_html(html_content):
try:
soup = BeautifulSoup(html_content, "html.parser")
return True
except Exception as e:
return False
def assert_content(content_path):
content_list = []
with open(content_path, "r", encoding="utf-8") as file:
content_list = json.load(file)
type_set = set()
for content_dict in content_list:
match content_dict["type"]:
# Image check: only the caption is verified
case "image":
type_set.add("image")
assert (
content_dict["image_caption"][0].strip().lower()
== "Figure 1: Figure Caption".lower()
)
# Table check: verify the caption, the table structure, and the table content
case "table":
type_set.add("table")
assert (
content_dict["table_caption"][0].strip().lower()
== "Table 1: Table Caption".lower()
)
assert validate_html(content_dict["table_body"])
target_str_list = [
"Linear Regression",
"0.98740",
"1321.2",
"2-order Polynomial",
"0.99906",
"26.4",
"3-order Polynomial",
"0.99913",
"101.2",
"4-order Polynomial",
"0.99914",
"94.1",
"Gray Prediction",
"0.00617",
"687",
]
correct_count = 0
for target_str in target_str_list:
if target_str in content_dict["table_body"]:
correct_count += 1
assert correct_count > 0.9 * len(target_str_list)
# Equation check: verify that formula elements are present
case "equation":
type_set.add("equation")
target_str_list = ["$$", "lambda", "frac", "bar"]
for target_str in target_str_list:
assert target_str in content_dict["text"]
# Text check: text similarity must exceed 90
case "text":
type_set.add("text")
assert (
fuzz.ratio(
content_dict["text"],
"Trump graduated from the Wharton School of the University of Pennsylvania with a bachelor's degree in 1968. He became president of his father's real estate business in 1971 and renamed it The Trump Organization.",
)
> 90
)
assert len(type_set) >= 4
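# Note on the text check above: fuzz.ratio returns an integer similarity score in the
# range [0, 100], so "> 90" tolerates small extraction differences such as missing
# punctuation. A rough, illustrative sketch (strings are made up, not from the assets):
#     from fuzzywuzzy import fuzz
#     fuzz.ratio("renamed it The Trump Organization.",
#                "renamed it The Trump Organization")   # near 100
#     fuzz.ratio("renamed it The Trump Organization.",
#                "a completely unrelated sentence")      # far below 90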
import json
import os
import shutil
import tempfile
from magic_pdf.integrations.rag.api import DataReader, RagDocumentReader
from magic_pdf.integrations.rag.type import CategoryType
from magic_pdf.integrations.rag.utils import \
convert_middle_json_to_layout_elements
def test_rag_document_reader():
# setup
unitest_dir = '/tmp/magic_pdf/unittest/integrations/rag'
os.makedirs(unitest_dir, exist_ok=True)
temp_output_dir = tempfile.mkdtemp(dir=unitest_dir)
os.makedirs(temp_output_dir, exist_ok=True)
# test
with open('tests/unittest/test_integrations/test_rag/assets/middle.json') as f:
json_data = json.load(f)
res = convert_middle_json_to_layout_elements(json_data, temp_output_dir)
doc = RagDocumentReader(res)
assert len(list(iter(doc))) == 1
page = list(iter(doc))[0]
assert len(list(iter(page))) >= 10
assert len(page.get_rel_map()) >= 3
item = list(iter(page))[0]
assert item.category_type == CategoryType.text
# teardown
shutil.rmtree(temp_output_dir)
def test_data_reader():
# setup
unitest_dir = '/tmp/magic_pdf/unittest/integrations/rag'
os.makedirs(unitest_dir, exist_ok=True)
temp_output_dir = tempfile.mkdtemp(dir=unitest_dir)
os.makedirs(temp_output_dir, exist_ok=True)
# test
data_reader = DataReader('tests/unittest/test_integrations/test_rag/assets', 'ocr',
temp_output_dir)
assert data_reader.get_documents_count() == 2
for idx in range(data_reader.get_documents_count()):
document = data_reader.get_document_result(idx)
assert document is not None
# teardown
shutil.rmtree(temp_output_dir)
import json
import os
import shutil
import tempfile
from magic_pdf.integrations.rag.type import CategoryType
from magic_pdf.integrations.rag.utils import (
convert_middle_json_to_layout_elements, inference)
def test_convert_middle_json_to_layout_elements():
# setup
unitest_dir = '/tmp/magic_pdf/unittest/integrations/rag'
os.makedirs(unitest_dir, exist_ok=True)
temp_output_dir = tempfile.mkdtemp(dir=unitest_dir)
os.makedirs(temp_output_dir, exist_ok=True)
# test
with open('tests/unittest/test_integrations/test_rag/assets/middle.json') as f:
json_data = json.load(f)
res = convert_middle_json_to_layout_elements(json_data, temp_output_dir)
assert len(res) == 1
assert len(res[0].layout_dets) > 0
assert res[0].layout_dets[0].anno_id == 0
assert res[0].layout_dets[0].category_type == CategoryType.text
assert len(res[0].extra.element_relation) >= 2
# teardown
shutil.rmtree(temp_output_dir)
def test_inference():
asset_dir = 'tests/unittest/test_integrations/test_rag/assets'
# setup
unitest_dir = '/tmp/magic_pdf/unittest/integrations/rag'
os.makedirs(unitest_dir, exist_ok=True)
temp_output_dir = tempfile.mkdtemp(dir=unitest_dir)
os.makedirs(temp_output_dir, exist_ok=True)
# test
res = inference(
asset_dir + '/one_page_with_table_image.pdf',
temp_output_dir,
'ocr',
)
assert res is not None
assert len(res) == 1
assert len(res[0].layout_dets) > 0
assert res[0].layout_dets[0].anno_id == 0
assert res[0].layout_dets[0].category_type == CategoryType.text
assert len(res[0].extra.element_relation) >= 2
# teardown
shutil.rmtree(temp_output_dir)
import os
import pytest
from magic_pdf.filter.pdf_classify_by_type import classify_by_area, classify_by_text_len, classify_by_avg_words, \
classify_by_img_num, classify_by_text_layout, classify_by_img_narrow_strips
from magic_pdf.filter.pdf_meta_scan import get_pdf_page_size_pts, get_pdf_textlen_per_page, get_imgs_per_page
from test_commons import get_docs_from_test_pdf, get_test_json_data
# Get the current directory
current_directory = os.path.dirname(os.path.abspath(__file__))
'''
Classify whether a PDF is a scanned copy based on the ratio of image area to page area.
'''
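# Rough intuition for the rule above (illustrative numbers only; the actual thresholds
# live in classify_by_area): a 612 x 792 pt page whose single image is about 600 x 780 pt
# is almost entirely covered by that image, the signature of a scanned page, while a page
# whose images cover only a small fraction of its area looks like a text page with figures.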
@pytest.mark.parametrize("book_name, expected_bool_classify_by_area",
[
("the_eye/the_eye_cdn_00391653", True), # 特殊文字版1.每页存储所有图片,特点是图片占页面比例不大,每页展示可能为0也可能不止1张
("scihub/scihub_08400000/libgen.scimag08489000-08489999.zip_10.1016/0370-1573(90)90070-i", False), # 特殊扫描版2,每页存储的扫描页图片数量递增,特点是图占比大,每页展示1张
("zlib/zlib_17216416", False), # 特殊扫描版3,有的页面是一整张大图,有的页面是通过一条条小图拼起来的,检测图片占比之前需要先按规则把小图拼成大图
("the_eye/the_eye_wtl_00023799", False), # 特殊扫描版4,每一页都是一张张小图拼出来的,检测图片占比之前需要先按规则把小图拼成大图
("the_eye/the_eye_cdn_00328381", False), # 特殊扫描版5,每一页都是一张张小图拼出来的,存在多个小图多次重复使用情况,检测图片占比之前需要先按规则把小图拼成大图
("scihub/scihub_25800000/libgen.scimag25889000-25889999.zip_10.2307/4153991", False), # 特殊扫描版6,只有三页,其中两页是扫描版
("scanned_detection/llm-raw-scihub-o.O-0584-8539%2891%2980165-f", False), # 特殊扫描版7,只有一页且由小图拼成大图
("scanned_detection/llm-raw-scihub-o.O-bf01427123", False), # 特殊扫描版8,只有3页且全是大图扫描版
("scihub/scihub_41200000/libgen.scimag41253000-41253999.zip_10.1080/00222938709460256", False), # 特殊扫描版12,头两页文字版且有一页没图片,后面扫描版11页
("scihub/scihub_37000000/libgen.scimag37068000-37068999.zip_10.1080/0015587X.1936.9718622", False) # 特殊扫描版13,头两页文字版且有一页没图片,后面扫描版3页
])
def test_classify_by_area(book_name, expected_bool_classify_by_area):
test_data = get_test_json_data(current_directory, "test_metascan_classify_data.json")
docs = get_docs_from_test_pdf(book_name)
median_width, median_height = get_pdf_page_size_pts(docs)
page_width = int(median_width)
page_height = int(median_height)
img_sz_list = test_data[book_name]["expected_image_info"]
total_page = len(docs)
text_len_list = get_pdf_textlen_per_page(docs)
bool_classify_by_area = classify_by_area(total_page, page_width, page_height, img_sz_list, text_len_list)
# assert bool_classify_by_area == expected_bool_classify_by_area
'''
Broad text-PDF detection: if any page has more than 100 characters, the PDF is treated as a text PDF.
'''
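# Illustrative example of the rule above (made-up values): with per-page character counts
#     text_len_list = [0, 0, 150, 20]
# at least one page exceeds 100 characters, so the PDF counts as a text PDF in the broad
# sense; an all-image brochure would yield only zeros and fail this check.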
@pytest.mark.parametrize("book_name, expected_bool_classify_by_text_len",
[
("scihub/scihub_67200000/libgen.scimag67237000-67237999.zip_10.1515/crpm-2017-0020", True), # 文字版,少于50页
("scihub/scihub_83300000/libgen.scimag83306000-83306999.zip_10.1007/978-3-658-30153-8", True), # 文字版,多于50页
("zhongwenzaixian/zhongwenzaixian_65771414", False), # 完全无字的宣传册
])
def test_classify_by_text_len(book_name, expected_bool_classify_by_text_len):
docs = get_docs_from_test_pdf(book_name)
text_len_list = get_pdf_textlen_per_page(docs)
total_page = len(docs)
bool_classify_by_text_len = classify_by_text_len(text_len_list, total_page)
# assert bool_classify_by_text_len == expected_bool_classify_by_text_len
'''
Narrow text-PDF detection: the average number of characters per page must exceed 200.
'''
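# Illustrative arithmetic for the rule above (made-up values): per-page character counts
# of [250, 300, 180] average to (250 + 300 + 180) / 3 ≈ 243 > 200, so the PDF passes the
# narrow-sense text check, whereas a scanned book whose only text is one 300-character
# outline page averages far below 200 and fails.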
@pytest.mark.parametrize("book_name, expected_bool_classify_by_avg_words",
[
("zlib/zlib_21207669", False), # 扫描版,书末尾几页有大纲文字
("zlib/zlib_19012845", False), # 扫描版,好几本扫描书的集合,每本书末尾有一页文字页
("scihub/scihub_67200000/libgen.scimag67237000-67237999.zip_10.1515/crpm-2017-0020", True),# 正常文字版
("zhongwenzaixian/zhongwenzaixian_65771414", False), # 宣传册
("zhongwenzaixian/zhongwenzaixian_351879", False), # 图解书/无字or少字
("zhongwenzaixian/zhongwenzaixian_61357496_pdfvector", False), # 书法集
("zhongwenzaixian/zhongwenzaixian_63684541", False), # 设计图
("zhongwenzaixian/zhongwenzaixian_61525978", False), # 绘本
("zhongwenzaixian/zhongwenzaixian_63679729", False), # 摄影集
])
def test_classify_by_avg_words(book_name, expected_bool_classify_by_avg_words):
docs = get_docs_from_test_pdf(book_name)
text_len_list = get_pdf_textlen_per_page(docs)
bool_classify_by_avg_words = classify_by_avg_words(text_len_list)
# assert bool_classify_by_avg_words == expected_bool_classify_by_avg_words
'''
This rule only targets special scanned PDF 1: its image info is discarded because of the junk_list, so classification has to rely on the image count alone.
'''
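# Intuition for the rule above (descriptive only; the exact criterion lives in
# classify_by_img_num): for special scanned PDF 1 the per-page image sizes were discarded
# via the junk list, but every page still reports an unusually large number of embedded
# images, so a high and uniform per-page image count is used as the scanned signal.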
@pytest.mark.parametrize("book_name, expected_bool_classify_by_img_num",
[
("zlib/zlib_21370453", False), # 特殊扫描版1,每页都有所有扫描页图片,特点是图占比大,每页展示1至n张
("zlib/zlib_22115997", False), # 特殊扫描版2,类似特1,但是每页数量不完全相等
("zlib/zlib_21814957", False), # 特殊扫描版3,类似特1,但是每页数量不完全相等
("zlib/zlib_21814955", False), # 特殊扫描版4,类似特1,但是每页数量不完全相等
])
def test_classify_by_img_num(book_name, expected_bool_classify_by_img_num):
test_data = get_test_json_data(current_directory, "test_metascan_classify_data.json")
docs = get_docs_from_test_pdf(book_name)
img_num_list = get_imgs_per_page(docs)
img_sz_list = test_data[book_name]["expected_image_info"]
bool_classify_by_img_num = classify_by_img_num(img_sz_list, img_num_list)
# assert bool_classify_by_img_num == expected_bool_classify_by_img_num
'''
Filter out vertically laid-out PDFs.
'''
@pytest.mark.parametrize("book_name, expected_bool_classify_by_text_layout",
[
("vertical_detection/三国演义_繁体竖排版", False), # 竖排版本1
("vertical_detection/净空法师_大乘无量寿", False), # 竖排版本2
("vertical_detection/om3006239", True), # 横排版本1
("vertical_detection/isit.2006.261791", True), # 横排版本2
])
def test_classify_by_text_layout(book_name, expected_bool_classify_by_text_layout):
test_data = get_test_json_data(current_directory, "test_metascan_classify_data.json")
text_layout_per_page = test_data[book_name]["expected_text_layout"]
bool_classify_by_text_layout = classify_by_text_layout(text_layout_per_page)
# assert bool_classify_by_text_layout == expected_bool_classify_by_text_layout
'''
Filter out special scanned PDFs by detecting whether pages are composed of many narrow image strips.
This rule only recognizes PDFs built from narrow strips; it does not flag regular full-page-image scans.
'''
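# Intuition for the rule above (descriptive only; the exact width/height ratios and counts
# are defined in classify_by_img_narrow_strips): a "narrow strip" is an image spanning most
# of the page in one dimension while being very thin in the other, and a page assembled
# from many such strips indicates a scan stored as stitched slices rather than one
# full-page image.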
@pytest.mark.parametrize("book_name, expected_bool_classify_by_img_narrow_strips",
[
("scihub/scihub_25900000/libgen.scimag25991000-25991999.zip_10.2307/40066695", False), # 特殊扫描版
("the_eye/the_eye_wtl_00023799", False), # 特殊扫描版4,每一页都是一张张小图拼出来的,检测图片占比之前需要先按规则把小图拼成大图
("the_eye/the_eye_cdn_00328381", False), # 特殊扫描版5,每一页都是一张张小图拼出来的,存在多个小图多次重复使用情况,检测图片占比之前需要先按规则把小图拼成大图
("scanned_detection/llm-raw-scihub-o.O-0584-8539%2891%2980165-f", False), # 特殊扫描版7,只有一页且由小图拼成大图
("scihub/scihub_25800000/libgen.scimag25889000-25889999.zip_10.2307/4153991", True), # 特殊扫描版6,只有三页,其中两页是扫描版
("scanned_detection/llm-raw-scihub-o.O-bf01427123", True), # 特殊扫描版8,只有3页且全是大图扫描版
("scihub/scihub_53700000/libgen.scimag53724000-53724999.zip_10.1097/00129191-200509000-00018", True), # 特殊文本版,有一长条,但是只有一条
])
def test_classify_by_img_narrow_strips(book_name, expected_bool_classify_by_img_narrow_strips):
test_data = get_test_json_data(current_directory, "test_metascan_classify_data.json")
img_sz_list = test_data[book_name]["expected_image_info"]
docs = get_docs_from_test_pdf(book_name)
median_width, median_height = get_pdf_page_size_pts(docs)
page_width = int(median_width)
page_height = int(median_height)
bool_classify_by_img_narrow_strips = classify_by_img_narrow_strips(page_width, page_height, img_sz_list)
# assert bool_classify_by_img_narrow_strips == expected_bool_classify_by_img_narrow_strips
import io
import json
import os
import fitz
import boto3
from botocore.config import Config
from magic_pdf.libs.config_reader import get_s3_config_dict
from magic_pdf.libs.commons import join_path, json_dump_path, read_file, parse_bucket_key
from loguru import logger
test_pdf_dir_path = "s3://llm-pdf-text/unittest/pdf/"
def get_test_pdf_json(book_name):
json_path = join_path(json_dump_path, book_name + ".json")
s3_config = get_s3_config_dict(json_path)
file_content = read_file(json_path, s3_config)
json_str = file_content.decode('utf-8')
json_object = json.loads(json_str)
return json_object
def read_test_file(book_name):
test_pdf_path = join_path(test_pdf_dir_path, book_name + ".pdf")
s3_config = get_s3_config_dict(test_pdf_path)
try:
file_content = read_file(test_pdf_path, s3_config)
return file_content
except Exception as e:
if "NoSuchKey" in str(e):
logger.warning("File not found in test_pdf_path. Downloading from orig_s3_pdf_path.")
try:
json_object = get_test_pdf_json(book_name)
orig_s3_pdf_path = json_object.get('file_location')
s3_config = get_s3_config_dict(orig_s3_pdf_path)
file_content = read_file(orig_s3_pdf_path, s3_config)
s3_client = get_s3_client(test_pdf_path)
bucket_name, bucket_key = parse_bucket_key(test_pdf_path)
file_obj = io.BytesIO(file_content)
s3_client.upload_fileobj(file_obj, bucket_name, bucket_key)
return file_content
except Exception as e:
logger.exception(e)
else:
logger.exception(e)
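# read_test_file falls back to the original location recorded in the book's JSON metadata
# when the cached copy under test_pdf_dir_path is missing, then uploads the fetched bytes
# back to the cache path so later runs hit the cache; parse_bucket_key presumably splits
# an "s3://bucket/key" style path into its (bucket, key) parts for that upload.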
def get_docs_from_test_pdf(book_name):
file_content = read_test_file(book_name)
return fitz.open("pdf", file_content)
def get_test_json_data(directory_path, json_file_name):
with open(os.path.join(directory_path, json_file_name), "r", encoding='utf-8') as f:
test_data = json.load(f)
return test_data
def get_s3_client(path):
s3_config = get_s3_config_dict(path)
try:
return boto3.client(
"s3",
aws_access_key_id=s3_config["ak"],
aws_secret_access_key=s3_config["sk"],
endpoint_url=s3_config["endpoint"],
config=Config(s3={"addressing_style": "path"}, retries={"max_attempts": 8, "mode": "standard"}),
)
except:
# older boto3 do not support retries.mode param.
return boto3.client(
"s3",
aws_access_key_id=s3_config["ak"],
aws_secret_access_key=s3_config["sk"],
endpoint_url=s3_config["endpoint"],
config=Config(s3={"addressing_style": "path"}, retries={"max_attempts": 8}),
)
import os
import pytest
from magic_pdf.filter.pdf_meta_scan import get_pdf_page_size_pts, get_image_info, get_pdf_text_layout_per_page, get_language
from test_commons import get_docs_from_test_pdf, get_test_json_data
# Get the current directory
current_directory = os.path.dirname(os.path.abspath(__file__))
'''
Get the PDF's width and height: collect widths and heights in separate lists and take the median of each.
'''
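# Illustrative example of the rule above (made-up values): sampled page widths of
# [595, 595, 612] and heights of [842, 842, 792] give a median size of 595 x 842 pt, so a
# single oversized or rotated page does not skew the reported page size.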
@pytest.mark.parametrize("book_name, expected_width, expected_height",
[
("zlib/zlib_17058115", 795, 1002), # pdf中最大页与最小页差异极大个例
("the_eye/the_eye_wtl_00023799", 616, 785) # 采样的前50页存在中位数大小页面横竖旋转情况
])
def test_get_pdf_page_size_pts(book_name, expected_width, expected_height):
docs = get_docs_from_test_pdf(book_name)
median_width, median_height = get_pdf_page_size_pts(docs)
# assert int(median_width) == expected_width
# assert int(median_height) == expected_height
'''
Get image info for the first 50 pages of the PDF. For speed, special scanned PDF 1 is filtered out; all other cases collect image info normally.
'''
@pytest.mark.parametrize("book_name",
[
"zlib/zlib_21370453", # 特殊扫描版1,每页都有所有扫描页图片,特点是图占比大,每页展示1至n张
"the_eye/the_eye_cdn_00391653", # 特殊文字版1.每页存储所有图片,特点是图片占页面比例不大,每页展示可能为0也可能不止1张,这种pdf需要拿前10页抽样检测img大小和个数,如果符合需要清空junklist
"scihub/scihub_08400000/libgen.scimag08489000-08489999.zip_10.1016/0370-1573(90)90070-i", # 扫描版2,每页存储的扫描页图片数量递增,特点是图占比大,每页展示1张,需要清空junklist跑前50页图片信息用于分类判断
"zlib/zlib_17216416", # 特殊扫描版3,有的页面是一整张大图,有的页面是通过一条条小图拼起来的
"the_eye/the_eye_wtl_00023799", # 特殊扫描版4,每一页都是一张张小图拼出来的
"the_eye/the_eye_cdn_00328381", # 特殊扫描版5,每一页都是一张张小图拼出来的,但是存在多个小图多次重复使用情况
"scihub/scihub_25800000/libgen.scimag25889000-25889999.zip_10.2307/4153991", # 特殊扫描版6,只有3页且其中两页是扫描页
"scanned_detection/llm-raw-scihub-o.O-0584-8539%2891%2980165-f", # 特殊扫描版7,只有一页,且是一张张小图拼出来的
"scanned_detection/llm-raw-scihub-o.O-bf01427123", # 特殊扫描版8,只有3页且全是大图扫描版
"zlib/zlib_22115997", # 特殊扫描版9,类似特1,但是每页数量不完全相等
"zlib/zlib_21814957", # 特殊扫描版10,类似特1,但是每页数量不完全相等
"zlib/zlib_21814955", # 特殊扫描版11,类似特1,但是每页数量不完全相等
"scihub/scihub_41200000/libgen.scimag41253000-41253999.zip_10.1080/00222938709460256", # 特殊扫描版12,头两页文字版且有一页没图片,后面扫描版11页
"scihub/scihub_37000000/libgen.scimag37068000-37068999.zip_10.1080/0015587X.1936.9718622" # 特殊扫描版13,头两页文字版且有一页没图片,后面扫描版3页
])
def test_get_image_info(book_name):
test_data = get_test_json_data(current_directory, "test_metascan_classify_data.json")
docs = get_docs_from_test_pdf(book_name)
page_width_pts, page_height_pts = get_pdf_page_size_pts(docs)
image_info, junk_img_bojids = get_image_info(docs, page_width_pts, page_height_pts)
# assert image_info == test_data[book_name]["expected_image_info"]
# assert junk_img_bojids == test_data[book_name]["expected_junk_img_bojids"]
'''
Get text-layout info for the first 50 pages of the PDF; the output is a list with one entry per page indicating horizontal or vertical layout.
'''
@pytest.mark.parametrize("book_name",
[
"vertical_detection/三国演义_繁体竖排版", # 竖排版本1
"vertical_detection/净空法师_大乘无量寿", # 竖排版本2
"vertical_detection/om3006239", # 横排版本1
"vertical_detection/isit.2006.261791" # 横排版本2
])
def test_get_text_layout_info(book_name):
test_data = get_test_json_data(current_directory, "test_metascan_classify_data.json")
docs = get_docs_from_test_pdf(book_name)
text_layout_info = get_pdf_text_layout_per_page(docs)
# assert text_layout_info == test_data[book_name]["expected_text_layout"]
'''
Get the PDF's language.
'''
@pytest.mark.parametrize("book_name, expected_language",
[
("scihub/scihub_05000000/libgen.scimag05023000-05023999.zip_10.1034/j.1601-0825.2003.02933.x", "en"), # 英文论文
])
def test_get_text_language_info(book_name, expected_language):
docs = get_docs_from_test_pdf(book_name)
text_language = get_language(docs)
# assert text_language == expected_language
import json
from magic_pdf.data.read_api import read_local_pdfs
from magic_pdf.model.magic_model import MagicModel
def test_magic_model_image_v2():
datasets = read_local_pdfs('tests/unittest/test_model/assets/test_01.pdf')
with open('tests/unittest/test_model/assets/test_01.model.json') as f:
model_json = json.load(f)
magic_model = MagicModel(model_json, datasets[0])
imgs = magic_model.get_imgs_v2(0)
print(imgs)
tables = magic_model.get_tables_v2(0)
print(tables)
def test_magic_model_table_v2():
datasets = read_local_pdfs('tests/unittest/test_model/assets/test_02.pdf')
with open('tests/unittest/test_model/assets/test_02.model.json') as f:
model_json = json.load(f)
magic_model = MagicModel(model_json, datasets[0])
tables = magic_model.get_tables_v2(5)
print(tables)
tables = magic_model.get_tables_v2(8)
print(tables)
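# Both tests above only print the extracted blocks: get_imgs_v2 / get_tables_v2 take a
# page index and are assumed to return the image / table regions detected on that page,
# so these serve as smoke tests that pass as long as no exception is raised.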
import unittest
import os
from PIL import Image
from lxml import etree
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
class TestppTableModel(unittest.TestCase):
def test_image2html(self):
img = Image.open(os.path.join(os.path.dirname(__file__), "assets/table.jpg"))
atom_model_manager = AtomModelSingleton()
ocr_engine = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
lang='ch'
)
table_model = RapidTableModel(ocr_engine, 'slanet_plus')
html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(img)
# Verify that the generated HTML matches expectations
parser = etree.HTMLParser()
tree = etree.fromstring(html_code, parser)
# Check the HTML structure
assert tree.find('.//table') is not None, "HTML should contain a <table> element"
assert tree.find('.//tr') is not None, "HTML should contain a <tr> element"
assert tree.find('.//td') is not None, "HTML should contain a <td> element"
# Check the specific table contents
headers = tree.xpath('//table/tr[1]/td')
assert len(headers) == 5, "Thead should have 5 columns"
assert headers[0].text and headers[0].text.strip() == "Methods", "First header should be 'Methods'"
assert headers[1].text and headers[1].text.strip() == "R", "Second header should be 'R'"
assert headers[2].text and headers[2].text.strip() == "P", "Third header should be 'P'"
assert headers[3].text and headers[3].text.strip() == "F", "Fourth header should be 'F'"
assert headers[4].text and headers[4].text.strip() == "FPS", "Fifth header should be 'FPS'"
# Check the first data row
first_row = tree.xpath('//table/tr[2]/td')
assert len(first_row) == 5, "First row should have 5 cells"
assert first_row[0].text and 'SegLink' in first_row[0].text.strip(), "First cell should be 'SegLink [26]'"
assert first_row[1].text and first_row[1].text.strip() == "70.0", "Second cell should be '70.0'"
assert first_row[2].text and first_row[2].text.strip() == "86.0", "Third cell should be '86.0'"
assert first_row[3].text and first_row[3].text.strip() == "77.0", "Fourth cell should be '77.0'"
assert first_row[4].text and first_row[4].text.strip() == "8.9", "Fifth cell should be '8.9'"
# Check the second-to-last data row
second_last_row = tree.xpath('//table/tr[position()=last()-1]/td')
assert len(second_last_row) == 5, "second_last_row should have 5 cells"
assert second_last_row[0].text and second_last_row[0].text.strip() == "Ours (SynText)", "First cell should be 'Ours (SynText)'"
assert second_last_row[1].text and second_last_row[1].text.strip() == "80.68", "Second cell should be '80.68'"
assert second_last_row[2].text and second_last_row[2].text.strip() == "85.40", "Third cell should be '85.40'"
# assert second_last_row[3].text and second_last_row[3].text.strip() == "82.97", "Fourth cell should be '82.97'"
# assert second_last_row[3].text and second_last_row[4].text.strip() == "12.68", "Fifth cell should be '12.68'"
if __name__ == "__main__":
unittest.main()