Commit 284cec04 authored by myhloli's avatar myhloli
Browse files

refactor: replace get_file_from_repos with auto_download_and_get_model_root_path in multiple files

parent 57f44dd8
......@@ -10,7 +10,7 @@ from ...model.mfr.unimernet.Unimernet import UnimernetModel
from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
from ...model.table.rapid_table import RapidTableModel
from ...utils.enum_class import ModelPath
from ...utils.models_download_utils import get_file_from_repos
from ...utils.models_download_utils import auto_download_and_get_model_root_path
def table_model_init(lang=None):
......@@ -144,15 +144,13 @@ class MineruPipelineModel:
self.mfd_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFD,
mfd_weights=str(
get_file_from_repos(ModelPath.yolo_v8_mfd)
os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd), ModelPath.yolo_v8_mfd)
),
device=self.device,
)
# 初始化公式解析模型
mfr_weight_dir = str(
get_file_from_repos(ModelPath.unimernet_small)
)
mfr_weight_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.unimernet_small), ModelPath.unimernet_small)
self.mfr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFR,
......@@ -164,7 +162,7 @@ class MineruPipelineModel:
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
doclayout_yolo_weights=str(
get_file_from_repos(ModelPath.doclayout_yolo)
os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
),
device=self.device,
)
......
......@@ -16,7 +16,7 @@ from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc
from mineru.data.data_reader_writer import FileBasedDataWriter
from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox
from mineru.utils.enum_class import MakeMode
from mineru.utils.models_download_utils import get_file_from_repos
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes
pdf_suffixes = [".pdf"]
......@@ -168,7 +168,7 @@ def do_parse(
pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
model_path = get_file_from_repos('/','vlm')
model_path = auto_download_and_get_model_root_path('/', 'vlm')
middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, model_path=model_path, server_url=server_url)
pdf_info = middle_json["pdf_info"]
......@@ -219,10 +219,14 @@ def do_parse(
if __name__ == "__main__":
pdf_path = "../../demo/pdfs/demo2.pdf"
# pdf_path = "C:/Users/zhaoxiaomeng/Downloads/input_img_0.jpg"
# pdf_path = "../../demo/pdfs/demo3.pdf"
pdf_path = "C:/Users/zhaoxiaomeng/Downloads/4546d0e2-ba60-40a5-a17e-b68555cec741.pdf"
try:
do_parse("./output", [Path(pdf_path).stem], [read_fn(Path(pdf_path))],["ch"], end_page_id=1, backend='vlm-huggingface')
do_parse("./output", [Path(pdf_path).stem], [read_fn(Path(pdf_path))],["ch"],
end_page_id=10,
backend='vlm-huggingface'
# backend = 'pipeline'
)
except Exception as e:
logger.exception(e)
# Copyright (c) Opendatalab. All rights reserved.
import copy
import os.path
import os
import warnings
from pathlib import Path
......@@ -11,7 +11,7 @@ from loguru import logger
from mineru.utils.config_reader import get_device
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import get_file_from_repos
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
from ....utils.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
from .tools.infer.predict_system import TextSystem
from .tools.infer import pytorchocr_utility as utility
......@@ -77,8 +77,11 @@ class PytorchPaddleOCR(TextSystem):
config = yaml.safe_load(file)
det, rec, dict_file = get_model_params(self.lang, config)
ocr_models_dir = ModelPath.pytorch_paddle
det_model_path = get_file_from_repos(f"{ocr_models_dir}/{det}")
rec_model_path = get_file_from_repos(f"{ocr_models_dir}/{rec}")
det_model_path = f"{ocr_models_dir}/{det}"
det_model_path = os.path.join(auto_download_and_get_model_root_path(det_model_path), det_model_path)
rec_model_path = f"{ocr_models_dir}/{rec}"
rec_model_path = os.path.join(auto_download_and_get_model_root_path(rec_model_path), rec_model_path)
kwargs['det_model_path'] = det_model_path
kwargs['rec_model_path'] = rec_model_path
kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
......
import os
import cv2
import numpy as np
from loguru import logger
from rapid_table import RapidTable, RapidTableInput
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import get_file_from_repos
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
class RapidTableModel(object):
def __init__(self, ocr_engine):
slanet_plus_model_path = get_file_from_repos(ModelPath.slanet_plus)
slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
self.table_model = RapidTable(input_args)
self.ocr_engine = ocr_engine
......
......@@ -9,7 +9,7 @@ from loguru import logger
from mineru.utils.config_reader import get_device
from mineru.utils.enum_class import BlockType, ModelPath
from mineru.utils.models_download_utils import get_file_from_repos
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
def sort_blocks_by_bbox(blocks, page_w, page_h, footnote_blocks):
......@@ -188,7 +188,7 @@ def model_init(model_name: str):
device = torch.device(device_name)
if model_name == 'layoutreader':
# 检测modelscope的缓存目录是否存在
layoutreader_model_dir = get_file_from_repos(ModelPath.layout_reader)
layoutreader_model_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.layout_reader), ModelPath.layout_reader)
if os.path.exists(layoutreader_model_dir):
model = LayoutLMv3ForTokenClassification.from_pretrained(
layoutreader_model_dir
......
......@@ -5,7 +5,7 @@ from modelscope import snapshot_download as ms_snapshot_download
from mineru.utils.config_reader import get_local_models_dir
from mineru.utils.enum_class import ModelPath
def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str:
def auto_download_and_get_model_root_path(relative_path: str, repo_mode='pipeline') -> str:
"""
支持文件或目录的可靠下载。
- 如果输入文件: 返回本地文件绝对路径
......@@ -14,7 +14,7 @@ def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str:
:param relative_path: 文件或目录相对路径
:return: 本地文件绝对路径或相对路径
"""
model_source = os.getenv('MINERU_MODEL_SOURCE', None)
model_source = os.getenv('MINERU_MODEL_SOURCE', "huggingface")
if model_source == 'local':
local_models_config = get_local_models_dir()
......@@ -54,10 +54,10 @@ def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str:
relative_path = relative_path.strip('/')
cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path+"/*"])
return cache_dir + "/" + relative_path
return cache_dir
if __name__ == '__main__':
path1 = get_file_from_repos("models/README.md")
print("本地文件绝对路径:", path1)
path2 = get_file_from_repos("models/OCR/paddleocr_torch/")
print("本地文件绝对路径:", path2)
\ No newline at end of file
path1 = "models/README.md"
root = auto_download_and_get_model_root_path(path1)
print("本地文件绝对路径:", os.path.join(root, path1))
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment