model_init.py 4.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from loguru import logger

from magic_pdf.libs.Constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel

from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel


def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
    """Build the table model selected by ``table_model_type``.

    Args:
        table_model_type: One of ``MODEL_NAME.STRUCT_EQTABLE``,
            ``MODEL_NAME.TABLE_MASTER`` or ``MODEL_NAME.RAPID_TABLE``.
        model_path: Passed to ``StructTableModel`` as its first argument, or
            to ``TableMasterPaddleModel`` as ``model_dir``. Not used for
            ``RapidTableModel``.
        max_time: Forwarded to ``StructTableModel`` as ``max_time``; the
            other backends do not receive it.
        _device_: Forwarded to ``TableMasterPaddleModel`` as ``device``; the
            other backends do not receive it.

    Returns:
        The constructed table model instance.

    Raises:
        SystemExit: With exit code 1 when ``table_model_type`` matches none
            of the supported names.
    """
    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
        return StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)

    if table_model_type == MODEL_NAME.TABLE_MASTER:
        config = {
            "model_dir": model_path,
            "device": _device_,
        }
        return TableMasterPaddleModel(config)

    if table_model_type == MODEL_NAME.RAPID_TABLE:
        return RapidTableModel()

    logger.error("table model type not allow")
    # The bare ``exit`` builtin is installed by the ``site`` module for
    # interactive use and is absent under ``python -S`` and in some frozen
    # builds. Raising SystemExit directly gives callers the same exception
    # without relying on it.
    raise SystemExit(1)


def mfd_model_init(weight, device='cpu'):
    """Construct the MFD model.

    Args:
        weight: First positional argument given to ``YOLOv8MFDModel``.
        device: Second positional argument given to ``YOLOv8MFDModel``;
            defaults to ``'cpu'``.

    Returns:
        A new ``YOLOv8MFDModel`` instance.
    """
    return YOLOv8MFDModel(weight, device)


def mfr_model_init(weight_dir, cfg_path, device='cpu'):
    """Construct the MFR model.

    Args:
        weight_dir: First positional argument given to ``UnimernetModel``.
        cfg_path: Second positional argument given to ``UnimernetModel``.
        device: Third positional argument given to ``UnimernetModel``;
            defaults to ``'cpu'``.

    Returns:
        A new ``UnimernetModel`` instance.
    """
    return UnimernetModel(weight_dir, cfg_path, device)


def layout_model_init(weight, config_file, device):
    """Construct the LayoutLMv3 layout predictor.

    Args:
        weight: First positional argument given to ``Layoutlmv3_Predictor``.
        config_file: Second positional argument given to
            ``Layoutlmv3_Predictor``.
        device: Third positional argument given to ``Layoutlmv3_Predictor``.

    Returns:
        A new ``Layoutlmv3_Predictor`` instance.
    """
    return Layoutlmv3_Predictor(weight, config_file, device)


def doclayout_yolo_model_init(weight, device='cpu'):
    """Construct the DocLayout-YOLO layout model.

    Args:
        weight: First positional argument given to ``DocLayoutYOLOModel``.
        device: Second positional argument given to ``DocLayoutYOLOModel``;
            defaults to ``'cpu'``.

    Returns:
        A new ``DocLayoutYOLOModel`` instance.
    """
    return DocLayoutYOLOModel(weight, device)


def ocr_model_init(show_log: bool = False,
                   det_db_box_thresh=0.3,
                   lang=None,
                   use_dilation=True,
                   det_db_unclip_ratio=1.8,
                   ):
    """Construct a ``ModifiedPaddleOCR`` instance.

    Every argument is forwarded to ``ModifiedPaddleOCR`` as a keyword of the
    same name. ``lang`` is the exception: it is only forwarded when it is not
    ``None``, so that ``ModifiedPaddleOCR`` falls back to its own default
    language otherwise.

    Args:
        show_log: Forwarded as ``show_log``.
        det_db_box_thresh: Forwarded as ``det_db_box_thresh``.
        lang: Forwarded as ``lang`` unless it is ``None``.
        use_dilation: Forwarded as ``use_dilation``.
        det_db_unclip_ratio: Forwarded as ``det_db_unclip_ratio``.

    Returns:
        The constructed ``ModifiedPaddleOCR`` instance.
    """
    ocr_options = {
        "show_log": show_log,
        "det_db_box_thresh": det_db_box_thresh,
    }
    # Only name the language when the caller chose one.
    if lang is not None:
        ocr_options["lang"] = lang
    ocr_options["use_dilation"] = use_dilation
    ocr_options["det_db_unclip_ratio"] = det_db_unclip_ratio
    return ModifiedPaddleOCR(**ocr_options)


class AtomModelSingleton:
    """Singleton that caches initialized atomic models.

    Every ``AtomModelSingleton()`` call returns the same object, and the
    model cache is a class-level dict, so a model is built at most once per
    distinct cache key for the lifetime of the process.
    """

    _instance = None
    # Cache: (atom_model_name, layout_model_name, lang, table_model_name) -> model.
    # Deliberately a class attribute so the cache is shared.
    _models = {}

    def __new__(cls, *args, **kwargs):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the cached model for this request, building it on a miss.

        Args:
            atom_model_name: Model family, forwarded to ``atom_model_init``
                as ``model_name``.
            **kwargs: Forwarded unchanged to ``atom_model_init``. Only
                ``layout_model_name``, ``lang`` and ``table_model_name``
                take part in the cache key.

        Returns:
            The model instance stored under the computed cache key.
        """
        # ``table_model_name`` selects a different backend class in
        # ``atom_model_init``, so it must be part of the key; without it a
        # request for a second table backend would get the first one back.
        # NOTE: other kwargs (e.g. ``device``, weight paths) are not part of
        # the key, so the values from the first call for a given key win.
        key = (
            atom_model_name,
            kwargs.get("layout_model_name"),
            kwargs.get("lang"),
            kwargs.get("table_model_name"),
        )
        if key not in self._models:
            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
        return self._models[key]


def atom_model_init(model_name: str, **kwargs):
    """Build the atomic model identified by ``model_name``.

    Args:
        model_name: One of ``AtomicModel.Layout``, ``AtomicModel.MFD``,
            ``AtomicModel.MFR``, ``AtomicModel.OCR`` or ``AtomicModel.Table``.
        **kwargs: Per-model settings, read by key:

            * Layout: ``layout_model_name`` plus either ``layout_weights``,
              ``layout_config_file`` and ``device`` (LayoutLMv3) or
              ``doclayout_yolo_weights`` and ``device`` (DocLayout-YOLO).
            * MFD: ``mfd_weights``, ``device``.
            * MFR: ``mfr_weight_dir``, ``mfr_cfg_path``, ``device``.
            * OCR: ``ocr_show_log``, ``det_db_box_thresh``, ``lang``.
            * Table: ``table_model_name``, ``table_model_path``,
              ``table_max_time``, ``device``.

    Returns:
        The initialized model instance.

    Raises:
        SystemExit: With exit code 1 when ``model_name`` is not recognized,
            or when no model was produced (for example an unrecognized
            ``layout_model_name``).
    """
    atom_model = None
    if model_name == AtomicModel.Layout:
        layout_model_name = kwargs.get("layout_model_name")
        if layout_model_name == MODEL_NAME.LAYOUTLMv3:
            atom_model = layout_model_init(
                kwargs.get("layout_weights"),
                kwargs.get("layout_config_file"),
                kwargs.get("device"),
            )
        elif layout_model_name == MODEL_NAME.DocLayout_YOLO:
            atom_model = doclayout_yolo_model_init(
                kwargs.get("doclayout_yolo_weights"),
                kwargs.get("device"),
            )
        # Any other layout name leaves atom_model as None and is reported
        # by the "model init failed" check below.
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get("mfd_weights"),
            kwargs.get("device"),
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get("mfr_weight_dir"),
            kwargs.get("mfr_cfg_path"),
            kwargs.get("device"),
        )
    elif model_name == AtomicModel.OCR:
        # Forward only the options the caller actually supplied. Passing the
        # result of ``kwargs.get`` unconditionally would hand an explicit
        # ``None`` to ``ocr_model_init`` and override its own defaults for
        # ``show_log`` and ``det_db_box_thresh``.
        ocr_options = {
            "show_log": kwargs.get("ocr_show_log"),
            "det_db_box_thresh": kwargs.get("det_db_box_thresh"),
            "lang": kwargs.get("lang"),
        }
        atom_model = ocr_model_init(
            **{name: value for name, value in ocr_options.items() if value is not None}
        )
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get("table_model_name"),
            kwargs.get("table_model_path"),
            kwargs.get("table_max_time"),
            kwargs.get("device"),
        )
    else:
        logger.error("model name not allow")
        # The bare ``exit`` builtin comes from the ``site`` module and is not
        # guaranteed to exist; raise SystemExit directly instead.
        raise SystemExit(1)

    if atom_model is None:
        logger.error("model init failed")
        raise SystemExit(1)
    return atom_model