vlm_analyze.py 3.27 KB
Newer Older
Jin Zhen Jiang's avatar
Jin Zhen Jiang committed
1
2
3
4
5
6
# Copyright (c) Opendatalab. All rights reserved.
import time

from loguru import logger

from ...data.data_reader_writer import DataWriter
7
from mineru.utils.pdf_image_tools import load_images_from_pdf
Jin Zhen Jiang's avatar
Jin Zhen Jiang committed
8
9
10
from .base_predictor import BasePredictor
from .predictor import get_predictor
from .token_to_middle_json import result_to_middle_json
11
from ...utils.enum_class import ModelPath
12
from ...utils.models_download_utils import auto_download_and_get_model_root_path
Jin Zhen Jiang's avatar
Jin Zhen Jiang committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


class ModelSingleton:
    """Shared cache of VLM predictors.

    ``ModelSingleton()`` always returns the same instance, and predictors
    created through :meth:`get_model` are memoised in the class-level
    ``_models`` dict. Repeated calls with the same configuration therefore
    reuse one predictor instead of constructing a new one each time.
    """

    _instance = None
    # Maps (backend, model_path, server_url) -> predictor built by get_predictor.
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        backend: str,
        model_path: str | None,
        server_url: str | None,
    ) -> BasePredictor:
        """Return the cached predictor for this configuration, creating it on first use.

        Args:
            backend: Backend name forwarded to ``get_predictor``.
            model_path: Local model directory. When falsy, the model root is
                resolved via ``auto_download_and_get_model_root_path("/", "vlm")``.
            server_url: Server URL forwarded to ``get_predictor``.

        Returns:
            The predictor cached under ``(backend, model_path, server_url)``.
        """
        # Key on every argument that influences the predictor. Keying on the
        # backend alone meant a later call with a different model_path or
        # server_url silently received the predictor built for the first call.
        # The caller-supplied model_path (possibly None) is used, so repeated
        # "auto-download" calls share one cache entry.
        key = (backend, model_path, server_url)
        if key not in self._models:
            if not model_path:
                model_path = auto_download_and_get_model_root_path("/", "vlm")
            self._models[key] = get_predictor(
                backend=backend,
                model_path=model_path,
                server_url=server_url,
            )
        return self._models[key]


def doc_analyze(
    pdf_bytes,
    image_writer: DataWriter | None,
    predictor: BasePredictor | None = None,
    backend="transformers",
    model_path: str | None = None,
    server_url: str | None = None,
):
    """Run VLM analysis on a PDF synchronously.

    Args:
        pdf_bytes: Raw PDF content, passed to ``load_images_from_pdf``.
        image_writer: Writer handed to ``result_to_middle_json``; may be None.
        predictor: Predictor to use. When None, one is obtained from
            ``ModelSingleton`` using ``backend``, ``model_path`` and ``server_url``.
        backend: Backend name; only used when ``predictor`` is None.
        model_path: Model directory; only used when ``predictor`` is None.
        server_url: Server URL; only used when ``predictor`` is None.

    Returns:
        A ``(middle_json, results)`` tuple: the structure produced by
        ``result_to_middle_json`` and the raw output of
        ``predictor.batch_predict``.
    """
    if predictor is None:
        predictor = ModelSingleton().get_model(backend, model_path, server_url)

    # Each entry of images_list is a dict; only its base64 payload is sent to the model.
    images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
    images_base64_list = [image_dict["img_base64"] for image_dict in images_list]

    results = predictor.batch_predict(images=images_base64_list)

    middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
    return middle_json, results


async def aio_doc_analyze(
    pdf_bytes,
    image_writer: DataWriter | None,
    predictor: BasePredictor | None = None,
    backend="transformers",
    model_path: str | None = None,
    server_url: str | None = None,
):
    """Async variant of VLM PDF analysis that also logs load/inference timings.

    Args:
        pdf_bytes: Raw PDF content, passed to ``load_images_from_pdf``.
        image_writer: Writer handed to ``result_to_middle_json``; may be None.
        predictor: Predictor to use. When None, one is obtained from
            ``ModelSingleton`` using ``backend``, ``model_path`` and ``server_url``.
        backend: Backend name; only used when ``predictor`` is None.
        model_path: Model directory; only used when ``predictor`` is None.
        server_url: Server URL; only used when ``predictor`` is None.

    Returns:
        The structure produced by ``result_to_middle_json``. Unlike the
        synchronous ``doc_analyze``, the raw predictor results are not returned.
    """
    if predictor is None:
        predictor = ModelSingleton().get_model(backend, model_path, server_url)

    def _rate(count: int, elapsed: float) -> float:
        # Items per second for the log line. Fast steps can measure as 0 elapsed
        # time; dividing would raise ZeroDivisionError and abort the whole
        # analysis just to print a log message.
        return round(count / elapsed, 3) if elapsed > 0 else float("inf")

    # NOTE(review): load_images_from_pdf and result_to_middle_json are called
    # synchronously here, so they block the event loop while they run.
    load_images_start = time.perf_counter()
    images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
    images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
    load_images_elapsed = time.perf_counter() - load_images_start
    logger.info(
        f"load images cost: {round(load_images_elapsed, 2)}, "
        f"speed: {_rate(len(images_base64_list), load_images_elapsed)} images/s"
    )

    infer_start = time.perf_counter()
    results = await predictor.aio_batch_predict(images=images_base64_list)
    infer_elapsed = time.perf_counter() - infer_start
    logger.info(
        f"infer finished, cost: {round(infer_elapsed, 2)}, "
        f"speed: {_rate(len(results), infer_elapsed)} page/s"
    )

    middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
    return middle_json