Commit 8737ebb2 authored by myhloli's avatar myhloli
Browse files

feat: remove model path parameter from vlm_doc_analyze and streamline model loading

parent 9bfb3e9e
......@@ -114,8 +114,7 @@ def do_parse(
pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
model_path = auto_download_and_get_model_root_path('/', 'vlm')
middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, model_path=model_path, server_url=server_url)
middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url)
pdf_info = middle_json["pdf_info"]
......
......@@ -9,6 +9,7 @@ from .base_predictor import BasePredictor
from .predictor import get_predictor
from .token_to_middle_json import result_to_middle_json
from ...utils.enum_class import ModelPath
from ...utils.models_download_utils import auto_download_and_get_model_root_path
class ModelSingleton:
......@@ -28,6 +29,8 @@ class ModelSingleton:
) -> BasePredictor:
key = (backend,)
if key not in self._models:
if not model_path:
model_path = auto_download_and_get_model_root_path("/","vlm")
self._models[key] = get_predictor(
backend=backend,
model_path=model_path,
......@@ -41,7 +44,7 @@ def doc_analyze(
image_writer: DataWriter | None,
predictor: BasePredictor | None = None,
backend="transformers",
model_path=ModelPath.vlm_root_hf,
model_path: str | None = None,
server_url: str | None = None,
):
if predictor is None:
......@@ -53,10 +56,10 @@ def doc_analyze(
# load_images_time = round(time.time() - load_images_start, 2)
# logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
infer_start = time.time()
# infer_start = time.time()
results = predictor.batch_predict(images=images_base64_list)
infer_time = round(time.time() - infer_start, 2)
logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
# infer_time = round(time.time() - infer_start, 2)
# logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
return middle_json, results
......@@ -67,7 +70,7 @@ async def aio_doc_analyze(
image_writer: DataWriter | None,
predictor: BasePredictor | None = None,
backend="transformers",
model_path=ModelPath.vlm_root_hf,
model_path: str | None = None,
server_url: str | None = None,
):
if predictor is None:
......
......@@ -16,7 +16,6 @@ from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc
from mineru.data.data_reader_writer import FileBasedDataWriter
from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox
from mineru.utils.enum_class import MakeMode
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes
pdf_suffixes = [".pdf"]
......@@ -173,8 +172,7 @@ def do_parse(
pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
model_path = auto_download_and_get_model_root_path('/', 'vlm')
middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, model_path=model_path, server_url=server_url)
middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url)
pdf_info = middle_json["pdf_info"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment