vlm_analyze.py 3.41 KB
Newer Older
Jin Zhen Jiang's avatar
Jin Zhen Jiang committed
1
2
3
4
5
6
# Copyright (c) Opendatalab. All rights reserved.
import time

from loguru import logger

from ...data.data_reader_writer import DataWriter
7
from mineru.utils.pdf_image_tools import load_images_from_pdf
Jin Zhen Jiang's avatar
Jin Zhen Jiang committed
8
9
10
from .base_predictor import BasePredictor
from .predictor import get_predictor
from .token_to_middle_json import result_to_middle_json
11
from ...utils.models_download_utils import auto_download_and_get_model_root_path
Jin Zhen Jiang's avatar
Jin Zhen Jiang committed
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27


class ModelSingleton:
    """Shared holder for predictor objects, one per distinct configuration.

    Constructing this class always yields the same instance, and predictors
    are kept in a class-level dict, so every caller sees the same cache.
    """

    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        # Build the shared instance on first use; later calls reuse it.
        shared = cls._instance
        if shared is None:
            shared = super().__new__(cls)
            cls._instance = shared
        return shared

    def get_model(
        self,
        backend: str,
        model_path: str | None,
        server_url: str | None,
        **kwargs,
    ) -> BasePredictor:
        """Return the cached predictor for this configuration, creating it if needed.

        Args:
            backend: Backend name forwarded to ``get_predictor``.
            model_path: Model location. When it is falsy and ``backend`` is
                ``'transformers'`` or ``'sglang-engine'``, a path is obtained
                from ``auto_download_and_get_model_root_path`` instead.
            server_url: Server address forwarded to ``get_predictor``.
            **kwargs: Extra options forwarded to ``get_predictor``. They are
                not part of the cache key, so they only take effect on the
                call that actually creates the predictor.

        Returns:
            The predictor stored under ``(backend, model_path, server_url)``.
        """
        # The key uses model_path exactly as the caller passed it, i.e. before
        # any automatic resolution below.
        cache_key = (backend, model_path, server_url)
        try:
            return self._models[cache_key]
        except KeyError:
            pass

        runs_locally = backend in ('transformers', 'sglang-engine')
        if runs_locally and not model_path:
            model_path = auto_download_and_get_model_root_path("/", "vlm")

        created = get_predictor(
            backend=backend,
            model_path=model_path,
            server_url=server_url,
            **kwargs,
        )
        self._models[cache_key] = created
        return created


def doc_analyze(
    pdf_bytes,
    image_writer: DataWriter | None,
    predictor: BasePredictor | None = None,
    backend="transformers",
    model_path: str | None = None,
    server_url: str | None = None,
    **kwargs,
):
    """Run the predictor over every page image of a PDF and build the middle JSON.

    Args:
        pdf_bytes: PDF content handed to ``load_images_from_pdf``.
        image_writer: Writer forwarded to ``result_to_middle_json``; may be None.
        predictor: Predictor to use. When None, one is fetched from
            ``ModelSingleton`` using ``backend``, ``model_path``,
            ``server_url`` and ``**kwargs``.
        backend: Backend name, used only when ``predictor`` is None.
        model_path: Model location, used only when ``predictor`` is None.
        server_url: Server address, used only when ``predictor`` is None.
        **kwargs: Extra predictor options, used only when ``predictor`` is None.

    Returns:
        A ``(middle_json, results)`` tuple: the output of
        ``result_to_middle_json`` and the raw ``batch_predict`` results.
    """
    engine = (
        ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
        if predictor is None
        else predictor
    )

    page_images, pdf_document = load_images_from_pdf(pdf_bytes)
    # The predictor consumes only the base64 payload of each page entry.
    encoded_pages = [page["img_base64"] for page in page_images]

    results = engine.batch_predict(images=encoded_pages)

    return (
        result_to_middle_json(results, page_images, pdf_document, image_writer),
        results,
    )


async def aio_doc_analyze(
    pdf_bytes,
    image_writer: DataWriter | None,
    predictor: BasePredictor | None = None,
    backend="transformers",
    model_path: str | None = None,
    server_url: str | None = None,
    **kwargs,
):
    """Coroutine that runs the predictor over a PDF's page images and builds the middle JSON.

    Follows the same steps as the synchronous entry point, except that
    inference is awaited through ``predictor.aio_batch_predict``.

    Args:
        pdf_bytes: PDF content handed to ``load_images_from_pdf``.
        image_writer: Writer forwarded to ``result_to_middle_json``; may be None.
        predictor: Predictor to use. When None, one is fetched from
            ``ModelSingleton`` using ``backend``, ``model_path``,
            ``server_url`` and ``**kwargs``.
        backend: Backend name, used only when ``predictor`` is None.
        model_path: Model location, used only when ``predictor`` is None.
        server_url: Server address, used only when ``predictor`` is None.
        **kwargs: Extra predictor options, used only when ``predictor`` is None.

    Returns:
        A ``(middle_json, results)`` tuple: the output of
        ``result_to_middle_json`` and the raw ``aio_batch_predict`` results.
    """
    engine = (
        ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
        if predictor is None
        else predictor
    )

    page_images, pdf_document = load_images_from_pdf(pdf_bytes)
    # The predictor consumes only the base64 payload of each page entry.
    encoded_pages = [page["img_base64"] for page in page_images]

    results = await engine.aio_batch_predict(images=encoded_pages)

    return (
        result_to_middle_json(results, page_images, pdf_document, image_writer),
        results,
    )