"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "9026b86d8a4c5eacab3cc9464654da70a772d328"
Commit 8737ebb2 authored by myhloli's avatar myhloli
Browse files

feat: remove model path parameter from vlm_doc_analyze and streamline model loading

parent 9bfb3e9e
...@@ -114,8 +114,7 @@ def do_parse( ...@@ -114,8 +114,7 @@ def do_parse(
pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id) pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method) local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir) image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
model_path = auto_download_and_get_model_root_path('/', 'vlm') middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url)
middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, model_path=model_path, server_url=server_url)
pdf_info = middle_json["pdf_info"] pdf_info = middle_json["pdf_info"]
......
...@@ -9,6 +9,7 @@ from .base_predictor import BasePredictor ...@@ -9,6 +9,7 @@ from .base_predictor import BasePredictor
from .predictor import get_predictor from .predictor import get_predictor
from .token_to_middle_json import result_to_middle_json from .token_to_middle_json import result_to_middle_json
from ...utils.enum_class import ModelPath from ...utils.enum_class import ModelPath
from ...utils.models_download_utils import auto_download_and_get_model_root_path
class ModelSingleton: class ModelSingleton:
...@@ -28,6 +29,8 @@ class ModelSingleton: ...@@ -28,6 +29,8 @@ class ModelSingleton:
) -> BasePredictor: ) -> BasePredictor:
key = (backend,) key = (backend,)
if key not in self._models: if key not in self._models:
if not model_path:
model_path = auto_download_and_get_model_root_path("/","vlm")
self._models[key] = get_predictor( self._models[key] = get_predictor(
backend=backend, backend=backend,
model_path=model_path, model_path=model_path,
...@@ -41,7 +44,7 @@ def doc_analyze( ...@@ -41,7 +44,7 @@ def doc_analyze(
image_writer: DataWriter | None, image_writer: DataWriter | None,
predictor: BasePredictor | None = None, predictor: BasePredictor | None = None,
backend="transformers", backend="transformers",
model_path=ModelPath.vlm_root_hf, model_path: str | None = None,
server_url: str | None = None, server_url: str | None = None,
): ):
if predictor is None: if predictor is None:
...@@ -53,10 +56,10 @@ def doc_analyze( ...@@ -53,10 +56,10 @@ def doc_analyze(
# load_images_time = round(time.time() - load_images_start, 2) # load_images_time = round(time.time() - load_images_start, 2)
# logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s") # logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
infer_start = time.time() # infer_start = time.time()
results = predictor.batch_predict(images=images_base64_list) results = predictor.batch_predict(images=images_base64_list)
infer_time = round(time.time() - infer_start, 2) # infer_time = round(time.time() - infer_start, 2)
logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s") # logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer) middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
return middle_json, results return middle_json, results
...@@ -67,7 +70,7 @@ async def aio_doc_analyze( ...@@ -67,7 +70,7 @@ async def aio_doc_analyze(
image_writer: DataWriter | None, image_writer: DataWriter | None,
predictor: BasePredictor | None = None, predictor: BasePredictor | None = None,
backend="transformers", backend="transformers",
model_path=ModelPath.vlm_root_hf, model_path: str | None = None,
server_url: str | None = None, server_url: str | None = None,
): ):
if predictor is None: if predictor is None:
......
...@@ -16,7 +16,6 @@ from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc ...@@ -16,7 +16,6 @@ from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc
from mineru.data.data_reader_writer import FileBasedDataWriter from mineru.data.data_reader_writer import FileBasedDataWriter
from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox
from mineru.utils.enum_class import MakeMode from mineru.utils.enum_class import MakeMode
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes
pdf_suffixes = [".pdf"] pdf_suffixes = [".pdf"]
...@@ -173,8 +172,7 @@ def do_parse( ...@@ -173,8 +172,7 @@ def do_parse(
pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id) pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method) local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir) image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
model_path = auto_download_and_get_model_root_path('/', 'vlm') middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url)
middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, model_path=model_path, server_url=server_url)
pdf_info = middle_json["pdf_info"] pdf_info = middle_json["pdf_info"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment