Commit 68338825 authored by myhloli's avatar myhloli
Browse files

refactor: enhance language support and improve document parsing for multiple files

parent 0f21495a
...@@ -145,7 +145,7 @@ def doc_analyze( ...@@ -145,7 +145,7 @@ def doc_analyze(
f'Batch {index + 1}/{len(batch_images)}: ' f'Batch {index + 1}/{len(batch_images)}: '
f'{processed_images_count} pages/{len(images_with_extra_info)} pages' f'{processed_images_count} pages/{len(images_with_extra_info)} pages'
) )
batch_results = may_batch_image_analyze(batch_image, formula_enable, table_enable) batch_results = batch_image_analyze(batch_image, formula_enable, table_enable)
results.extend(batch_results) results.extend(batch_results)
# 构建返回结果 # 构建返回结果
...@@ -171,7 +171,7 @@ def doc_analyze( ...@@ -171,7 +171,7 @@ def doc_analyze(
return middle_json_list, infer_results return middle_json_list, infer_results
def may_batch_image_analyze( def batch_image_analyze(
images_with_extra_info: list[(np.ndarray, bool, str)], images_with_extra_info: list[(np.ndarray, bool, str)],
formula_enable=None, formula_enable=None,
table_enable=None): table_enable=None):
...@@ -192,7 +192,7 @@ def may_batch_image_analyze( ...@@ -192,7 +192,7 @@ def may_batch_image_analyze(
if str(device).startswith('npu') or str(device).startswith('cuda'): if str(device).startswith('npu') or str(device).startswith('cuda'):
vram = get_vram(device) vram = get_vram(device)
if vram is not None: if vram is not None:
gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(vram))) gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
if gpu_memory >= 16: if gpu_memory >= 16:
batch_ratio = 16 batch_ratio = 16
elif gpu_memory >= 12: elif gpu_memory >= 12:
......
...@@ -41,6 +41,17 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes ...@@ -41,6 +41,17 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
without method specified, huggingface will be used by default.""", without method specified, huggingface will be used by default.""",
default='pipeline', default='pipeline',
) )
@click.option(
'-l',
'--lang',
'lang',
type=click.Choice(['ch', 'ch_server', 'ch_lite', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka']),
help="""
Input the languages in the pdf (if known) to improve OCR accuracy. Optional.
Without languages specified, 'ch' will be used by default.
""",
default='ch',
)
@click.option( @click.option(
'-u', '-u',
'--url', '--url',
...@@ -68,24 +79,33 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes ...@@ -68,24 +79,33 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
default=None, default=None,
) )
def main(input_path, output_dir, backend, server_url, start_page_id, end_page_id): def main(input_path, output_dir, backend, lang, server_url, start_page_id, end_page_id):
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
def parse_doc(path: Path): def parse_doc(path_list: list[Path]):
try: try:
file_name_list = []
pdf_bytes_list = []
lang_list = []
for path in path_list:
file_name = str(Path(path).stem) file_name = str(Path(path).stem)
pdf_bits = read_fn(path) pdf_bytes = read_fn(path)
do_parse(output_dir, file_name, pdf_bits, backend, server_url, file_name_list.append(file_name)
pdf_bytes_list.append(pdf_bytes)
lang_list.append(lang)
do_parse(output_dir, file_name_list, pdf_bytes_list, lang_list, backend, server_url,
start_page_id=start_page_id, end_page_id=end_page_id) start_page_id=start_page_id, end_page_id=end_page_id)
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
if os.path.isdir(input_path): if os.path.isdir(input_path):
doc_path_list = []
for doc_path in Path(input_path).glob('*'): for doc_path in Path(input_path).glob('*'):
if doc_path.suffix in pdf_suffixes + image_suffixes: if doc_path.suffix in pdf_suffixes + image_suffixes:
parse_doc(Path(doc_path)) doc_path_list.append(doc_path)
parse_doc(doc_path_list)
else: else:
parse_doc(Path(input_path)) parse_doc([Path(input_path)])
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -9,7 +9,7 @@ import numpy as np ...@@ -9,7 +9,7 @@ import numpy as np
import yaml import yaml
from loguru import logger from loguru import logger
from magic_pdf.libs.config_reader import get_device, get_local_models_dir from mineru.backend.pipeline.config_reader import get_device, get_local_models_dir
from ....utils.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image from ....utils.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
from .tools.infer.predict_system import TextSystem from .tools.infer.predict_system import TextSystem
from .tools.infer import pytorchocr_utility as utility from .tools.infer import pytorchocr_utility as utility
......
from pathlib import Path from pathlib import Path
from setuptools import setup, find_packages from setuptools import setup, find_packages
from magic_pdf.libs.version import __version__ from mineru.version import __version__
def parse_requirements(filename): def parse_requirements(filename):
...@@ -24,13 +24,13 @@ if __name__ == '__main__': ...@@ -24,13 +24,13 @@ if __name__ == '__main__':
'README.md').open(encoding='utf-8') as file: 'README.md').open(encoding='utf-8') as file:
long_description = file.read() long_description = file.read()
setup( setup(
name="magic_pdf", # 项目名 name="mineru", # 项目名
version=__version__, # 自动从tag中获取版本号 version=__version__, # 自动从tag中获取版本号
license="AGPL-3.0", license="AGPL-3.0",
packages=find_packages() + ["magic_pdf.resources"] + ["magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorchocr.utils.resources"], # 包含所有的包 packages=find_packages() + ["mineru.resources"] + ["mineru.model.ocr.paddleocr2pytorch.pytorchocr.utils.resources"], # 包含所有的包
package_data={ package_data={
"magic_pdf.resources": ["**"], # 包含magic_pdf.resources目录下的所有文件 "mineru.resources": ["**"], # 包含magic_pdf.resources目录下的所有文件
"magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorchocr.utils.resources": ["**"], # pytorchocr.resources目录下的所有文件 "mineru.model.ocr.paddleocr2pytorch.pytorchocr.utils.resources": ["**"], # pytorchocr.resources目录下的所有文件
}, },
install_requires=parse_requirements('requirements.txt'), # 项目依赖的第三方库 install_requires=parse_requirements('requirements.txt'), # 项目依赖的第三方库
extras_require={ extras_require={
...@@ -84,8 +84,7 @@ if __name__ == '__main__': ...@@ -84,8 +84,7 @@ if __name__ == '__main__':
python_requires=">=3.10,<3.14", # 项目依赖的 Python 版本 python_requires=">=3.10,<3.14", # 项目依赖的 Python 版本
entry_points={ entry_points={
"console_scripts": [ "console_scripts": [
"magic-pdf = magic_pdf.tools.cli:cli", "mineru = mineru.cli:client.main", # 命令行入口点,mineru命令将调用mineru.cli.client.main函数
"magic-pdf-dev = magic_pdf.tools.cli_dev:cli"
], ],
}, # 项目提供的可执行命令 }, # 项目提供的可执行命令
include_package_data=True, # 是否包含非代码文件,如数据文件、配置文件等 include_package_data=True, # 是否包含非代码文件,如数据文件、配置文件等
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment