Commit 284cec04 authored by myhloli
Browse files

refactor: replace get_file_from_repos with auto_download_and_get_model_root_path in multiple files

parent 57f44dd8
...@@ -10,7 +10,7 @@ from ...model.mfr.unimernet.Unimernet import UnimernetModel ...@@ -10,7 +10,7 @@ from ...model.mfr.unimernet.Unimernet import UnimernetModel
from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
from ...model.table.rapid_table import RapidTableModel from ...model.table.rapid_table import RapidTableModel
from ...utils.enum_class import ModelPath from ...utils.enum_class import ModelPath
from ...utils.models_download_utils import get_file_from_repos from ...utils.models_download_utils import auto_download_and_get_model_root_path
def table_model_init(lang=None): def table_model_init(lang=None):
...@@ -144,15 +144,13 @@ class MineruPipelineModel: ...@@ -144,15 +144,13 @@ class MineruPipelineModel:
self.mfd_model = atom_model_manager.get_atom_model( self.mfd_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFD, atom_model_name=AtomicModel.MFD,
mfd_weights=str( mfd_weights=str(
get_file_from_repos(ModelPath.yolo_v8_mfd) os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd), ModelPath.yolo_v8_mfd)
), ),
device=self.device, device=self.device,
) )
# 初始化公式解析模型 # 初始化公式解析模型
mfr_weight_dir = str( mfr_weight_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.unimernet_small), ModelPath.unimernet_small)
get_file_from_repos(ModelPath.unimernet_small)
)
self.mfr_model = atom_model_manager.get_atom_model( self.mfr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFR, atom_model_name=AtomicModel.MFR,
...@@ -164,7 +162,7 @@ class MineruPipelineModel: ...@@ -164,7 +162,7 @@ class MineruPipelineModel:
self.layout_model = atom_model_manager.get_atom_model( self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout, atom_model_name=AtomicModel.Layout,
doclayout_yolo_weights=str( doclayout_yolo_weights=str(
get_file_from_repos(ModelPath.doclayout_yolo) os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
), ),
device=self.device, device=self.device,
) )
......
...@@ -16,7 +16,7 @@ from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc ...@@ -16,7 +16,7 @@ from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc
from mineru.data.data_reader_writer import FileBasedDataWriter from mineru.data.data_reader_writer import FileBasedDataWriter
from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox
from mineru.utils.enum_class import MakeMode from mineru.utils.enum_class import MakeMode
from mineru.utils.models_download_utils import get_file_from_repos from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes
pdf_suffixes = [".pdf"] pdf_suffixes = [".pdf"]
...@@ -168,7 +168,7 @@ def do_parse( ...@@ -168,7 +168,7 @@ def do_parse(
pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id) pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method) local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir) image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
model_path = get_file_from_repos('/','vlm') model_path = auto_download_and_get_model_root_path('/', 'vlm')
middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, model_path=model_path, server_url=server_url) middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, model_path=model_path, server_url=server_url)
pdf_info = middle_json["pdf_info"] pdf_info = middle_json["pdf_info"]
...@@ -219,10 +219,14 @@ def do_parse( ...@@ -219,10 +219,14 @@ def do_parse(
if __name__ == "__main__": if __name__ == "__main__":
pdf_path = "../../demo/pdfs/demo2.pdf" # pdf_path = "../../demo/pdfs/demo3.pdf"
# pdf_path = "C:/Users/zhaoxiaomeng/Downloads/input_img_0.jpg" pdf_path = "C:/Users/zhaoxiaomeng/Downloads/4546d0e2-ba60-40a5-a17e-b68555cec741.pdf"
try: try:
do_parse("./output", [Path(pdf_path).stem], [read_fn(Path(pdf_path))],["ch"], end_page_id=1, backend='vlm-huggingface') do_parse("./output", [Path(pdf_path).stem], [read_fn(Path(pdf_path))],["ch"],
end_page_id=10,
backend='vlm-huggingface'
# backend = 'pipeline'
)
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
# Copyright (c) Opendatalab. All rights reserved. # Copyright (c) Opendatalab. All rights reserved.
import copy import copy
import os.path import os
import warnings import warnings
from pathlib import Path from pathlib import Path
...@@ -11,7 +11,7 @@ from loguru import logger ...@@ -11,7 +11,7 @@ from loguru import logger
from mineru.utils.config_reader import get_device from mineru.utils.config_reader import get_device
from mineru.utils.enum_class import ModelPath from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import get_file_from_repos from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
from ....utils.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image from ....utils.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
from .tools.infer.predict_system import TextSystem from .tools.infer.predict_system import TextSystem
from .tools.infer import pytorchocr_utility as utility from .tools.infer import pytorchocr_utility as utility
...@@ -77,8 +77,11 @@ class PytorchPaddleOCR(TextSystem): ...@@ -77,8 +77,11 @@ class PytorchPaddleOCR(TextSystem):
config = yaml.safe_load(file) config = yaml.safe_load(file)
det, rec, dict_file = get_model_params(self.lang, config) det, rec, dict_file = get_model_params(self.lang, config)
ocr_models_dir = ModelPath.pytorch_paddle ocr_models_dir = ModelPath.pytorch_paddle
det_model_path = get_file_from_repos(f"{ocr_models_dir}/{det}")
rec_model_path = get_file_from_repos(f"{ocr_models_dir}/{rec}") det_model_path = f"{ocr_models_dir}/{det}"
det_model_path = os.path.join(auto_download_and_get_model_root_path(det_model_path), det_model_path)
rec_model_path = f"{ocr_models_dir}/{rec}"
rec_model_path = os.path.join(auto_download_and_get_model_root_path(rec_model_path), rec_model_path)
kwargs['det_model_path'] = det_model_path kwargs['det_model_path'] = det_model_path
kwargs['rec_model_path'] = rec_model_path kwargs['rec_model_path'] = rec_model_path
kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file) kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
......
import os
import cv2 import cv2
import numpy as np import numpy as np
from loguru import logger from loguru import logger
from rapid_table import RapidTable, RapidTableInput from rapid_table import RapidTable, RapidTableInput
from mineru.utils.enum_class import ModelPath from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import get_file_from_repos from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
class RapidTableModel(object): class RapidTableModel(object):
def __init__(self, ocr_engine): def __init__(self, ocr_engine):
slanet_plus_model_path = get_file_from_repos(ModelPath.slanet_plus) slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path) input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
self.table_model = RapidTable(input_args) self.table_model = RapidTable(input_args)
self.ocr_engine = ocr_engine self.ocr_engine = ocr_engine
......
...@@ -9,7 +9,7 @@ from loguru import logger ...@@ -9,7 +9,7 @@ from loguru import logger
from mineru.utils.config_reader import get_device from mineru.utils.config_reader import get_device
from mineru.utils.enum_class import BlockType, ModelPath from mineru.utils.enum_class import BlockType, ModelPath
from mineru.utils.models_download_utils import get_file_from_repos from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
def sort_blocks_by_bbox(blocks, page_w, page_h, footnote_blocks): def sort_blocks_by_bbox(blocks, page_w, page_h, footnote_blocks):
...@@ -188,7 +188,7 @@ def model_init(model_name: str): ...@@ -188,7 +188,7 @@ def model_init(model_name: str):
device = torch.device(device_name) device = torch.device(device_name)
if model_name == 'layoutreader': if model_name == 'layoutreader':
# 检测modelscope的缓存目录是否存在 # 检测modelscope的缓存目录是否存在
layoutreader_model_dir = get_file_from_repos(ModelPath.layout_reader) layoutreader_model_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.layout_reader), ModelPath.layout_reader)
if os.path.exists(layoutreader_model_dir): if os.path.exists(layoutreader_model_dir):
model = LayoutLMv3ForTokenClassification.from_pretrained( model = LayoutLMv3ForTokenClassification.from_pretrained(
layoutreader_model_dir layoutreader_model_dir
......
...@@ -5,7 +5,7 @@ from modelscope import snapshot_download as ms_snapshot_download ...@@ -5,7 +5,7 @@ from modelscope import snapshot_download as ms_snapshot_download
from mineru.utils.config_reader import get_local_models_dir from mineru.utils.config_reader import get_local_models_dir
from mineru.utils.enum_class import ModelPath from mineru.utils.enum_class import ModelPath
def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str: def auto_download_and_get_model_root_path(relative_path: str, repo_mode='pipeline') -> str:
""" """
支持文件或目录的可靠下载。 支持文件或目录的可靠下载。
- 如果输入文件: 返回本地文件绝对路径 - 如果输入文件: 返回本地文件绝对路径
...@@ -14,7 +14,7 @@ def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str: ...@@ -14,7 +14,7 @@ def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str:
:param relative_path: 文件或目录相对路径 :param relative_path: 文件或目录相对路径
:return: 本地文件绝对路径或相对路径 :return: 本地文件绝对路径或相对路径
""" """
model_source = os.getenv('MINERU_MODEL_SOURCE', None) model_source = os.getenv('MINERU_MODEL_SOURCE', "huggingface")
if model_source == 'local': if model_source == 'local':
local_models_config = get_local_models_dir() local_models_config = get_local_models_dir()
...@@ -54,10 +54,10 @@ def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str: ...@@ -54,10 +54,10 @@ def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str:
relative_path = relative_path.strip('/') relative_path = relative_path.strip('/')
cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path+"/*"]) cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path+"/*"])
return cache_dir + "/" + relative_path return cache_dir
if __name__ == '__main__': if __name__ == '__main__':
path1 = get_file_from_repos("models/README.md") path1 = "models/README.md"
print("本地文件绝对路径:", path1) root = auto_download_and_get_model_root_path(path1)
path2 = get_file_from_repos("models/OCR/paddleocr_torch/") print("本地文件绝对路径:", os.path.join(root, path1))
print("本地文件绝对路径:", path2) \ No newline at end of file
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment