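"""Factory functions for the atomic models used by magic_pdf.

Each ``*_model_init`` function constructs one kind of model (table, MFD, MFR,
layout, OCR); ``atom_model_init`` dispatches on an ``AtomicModel`` name, and
``AtomModelSingleton`` caches the constructed models for reuse.
"""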
import sys

from loguru import logger

from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
    DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
    Layoutlmv3_Predictor
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
    ModifiedPaddleOCR
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
    RapidTableModel
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
    StructTableModel
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
    TableMasterPaddleModel


def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
    """Instantiate the table model selected by ``table_model_type``.

    Args:
        table_model_type: One of ``MODEL_NAME.STRUCT_EQTABLE``,
            ``MODEL_NAME.TABLE_MASTER`` or ``MODEL_NAME.RAPID_TABLE``.
        model_path: Passed to ``StructTableModel`` positionally and to
            ``TableMasterPaddleModel`` as ``config['model_dir']``; unused by
            ``RapidTableModel``.
        max_time: Forwarded to ``StructTableModel`` only.
        _device_: Forwarded to ``TableMasterPaddleModel`` only, as
            ``config['device']``.

    Returns:
        The constructed table model.

    Logs an error and exits the process with status 1 when
    ``table_model_type`` is not one of the supported names.
    """
    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
        return StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
    if table_model_type == MODEL_NAME.TABLE_MASTER:
        config = {
            'model_dir': model_path,
            'device': _device_,
        }
        return TableMasterPaddleModel(config)
    if table_model_type == MODEL_NAME.RAPID_TABLE:
        return RapidTableModel()

    logger.error('table model type not allow')
    # ``exit`` is injected by the ``site`` module for interactive use and is
    # not guaranteed to exist in programs; ``sys.exit`` is the reliable form.
    sys.exit(1)


def mfd_model_init(weight, device='cpu'):
    """Build the MFD model.

    Args:
        weight: Forwarded positionally to ``YOLOv8MFDModel``.
        device: Forwarded positionally to ``YOLOv8MFDModel``; defaults to
            ``'cpu'``.

    Returns:
        The new ``YOLOv8MFDModel`` instance.
    """
    return YOLOv8MFDModel(weight, device)


def mfr_model_init(weight_dir, cfg_path, device='cpu'):
    """Build the MFR model.

    Args:
        weight_dir: Forwarded positionally to ``UnimernetModel``.
        cfg_path: Forwarded positionally to ``UnimernetModel``.
        device: Forwarded positionally to ``UnimernetModel``; defaults to
            ``'cpu'``.

    Returns:
        The new ``UnimernetModel`` instance.
    """
    return UnimernetModel(weight_dir, cfg_path, device)


def layout_model_init(weight, config_file, device):
    """Build the LayoutLMv3 layout predictor.

    Args:
        weight: Forwarded positionally to ``Layoutlmv3_Predictor``.
        config_file: Forwarded positionally to ``Layoutlmv3_Predictor``.
        device: Forwarded positionally to ``Layoutlmv3_Predictor`` (no
            default here, unlike the other factories).

    Returns:
        The new ``Layoutlmv3_Predictor`` instance.
    """
    return Layoutlmv3_Predictor(weight, config_file, device)


def doclayout_yolo_model_init(weight, device='cpu'):
    """Build the DocLayout-YOLO layout model.

    Args:
        weight: Forwarded positionally to ``DocLayoutYOLOModel``.
        device: Forwarded positionally to ``DocLayoutYOLOModel``; defaults
            to ``'cpu'``.

    Returns:
        The new ``DocLayoutYOLOModel`` instance.
    """
    return DocLayoutYOLOModel(weight, device)


def ocr_model_init(show_log: bool = False,
                   det_db_box_thresh=0.3,
                   lang=None,
                   use_dilation=True,
                   det_db_unclip_ratio=1.8,
                   ):
    """Build a ``ModifiedPaddleOCR`` engine.

    Args:
        show_log: Forwarded as ``show_log``.
        det_db_box_thresh: Forwarded as ``det_db_box_thresh``.
        lang: Forwarded as ``lang`` only when it is neither ``None`` nor the
            empty string; otherwise the keyword is omitted entirely.
        use_dilation: Forwarded as ``use_dilation``.
        det_db_unclip_ratio: Forwarded as ``det_db_unclip_ratio``.

    Returns:
        The new ``ModifiedPaddleOCR`` instance.
    """
    ocr_kwargs = dict(
        show_log=show_log,
        det_db_box_thresh=det_db_box_thresh,
        use_dilation=use_dilation,
        det_db_unclip_ratio=det_db_unclip_ratio,
    )
    # Omitting ``lang`` presumably lets ModifiedPaddleOCR pick its own default
    # language — confirm against ppocr_273_mod.
    if lang is not None and lang != '':
        ocr_kwargs['lang'] = lang
    return ModifiedPaddleOCR(**ocr_kwargs)


class AtomModelSingleton:
    """Process-wide cache of constructed atomic models.

    ``AtomModelSingleton()`` always returns the same instance, and
    ``get_atom_model`` builds each distinct model configuration once through
    ``atom_model_init`` and returns the cached object on later calls.
    """

    _instance = None
    # Built models, keyed by the tuple (or bare name) computed in get_atom_model.
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the model for ``atom_model_name``, building it on first use.

        All ``kwargs`` are forwarded unchanged to ``atom_model_init`` when the
        model is not cached yet. The cache key contains the settings that
        produce a different model:

        * OCR: ``lang`` plus the detector settings ``det_db_box_thresh``,
          ``use_dilation`` and ``det_db_unclip_ratio``, so callers asking for
          different thresholds get separate instances instead of whichever
          configuration happened to be built first.
        * Layout: ``layout_model_name``.
        * Table: ``table_model_name``.
        * Anything else: the model name alone.
        """
        if atom_model_name in [AtomicModel.OCR]:
            key = (
                atom_model_name,
                kwargs.get('lang', None),
                kwargs.get('det_db_box_thresh', None),
                kwargs.get('use_dilation', None),
                kwargs.get('det_db_unclip_ratio', None),
            )
        elif atom_model_name in [AtomicModel.Layout]:
            key = (atom_model_name, kwargs.get('layout_model_name', None))
        elif atom_model_name in [AtomicModel.Table]:
            key = (atom_model_name, kwargs.get('table_model_name', None))
        else:
            key = atom_model_name

        if key not in self._models:
            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
        return self._models[key]

def atom_model_init(model_name: str, **kwargs):
    """Construct the atomic model named ``model_name`` from ``kwargs``.

    Keyword arguments read per model:

    * ``AtomicModel.Layout``: ``layout_model_name`` (``MODEL_NAME.LAYOUTLMv3``
      or ``MODEL_NAME.DocLayout_YOLO``), ``layout_weights``,
      ``layout_config_file``, ``doclayout_yolo_weights``, ``device``.
    * ``AtomicModel.MFD``: ``mfd_weights``, ``device``.
    * ``AtomicModel.MFR``: ``mfr_weight_dir``, ``mfr_cfg_path``, ``device``.
    * ``AtomicModel.OCR``: ``lang`` and the optional ``ocr_show_log``,
      ``det_db_box_thresh``, ``use_dilation``, ``det_db_unclip_ratio``;
      options that are omitted or ``None`` fall back to ``ocr_model_init``'s
      defaults.
    * ``AtomicModel.Table``: ``table_model_name``, ``table_model_path``,
      ``table_max_time``, ``device``.

    Returns:
        The constructed model.

    Logs an error and exits the process with status 1 when ``model_name`` is
    unknown, or when no model was built (e.g. an unrecognised
    ``layout_model_name``).
    """
    atom_model = None
    if model_name == AtomicModel.Layout:
        layout_model_name = kwargs.get('layout_model_name')
        if layout_model_name == MODEL_NAME.LAYOUTLMv3:
            atom_model = layout_model_init(
                kwargs.get('layout_weights'),
                kwargs.get('layout_config_file'),
                kwargs.get('device'),
            )
        elif layout_model_name == MODEL_NAME.DocLayout_YOLO:
            atom_model = doclayout_yolo_model_init(
                kwargs.get('doclayout_yolo_weights'),
                kwargs.get('device'),
            )
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get('mfd_weights'),
            kwargs.get('device'),
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get('mfr_weight_dir'),
            kwargs.get('mfr_cfg_path'),
            kwargs.get('device'),
        )
    elif model_name == AtomicModel.OCR:
        # Forward only the options the caller actually supplied. Passing
        # kwargs.get(...) unconditionally would hand None to ocr_model_init
        # and override its defaults (e.g. det_db_box_thresh=0.3), and
        # use_dilation / det_db_unclip_ratio must be forwarded to take effect.
        ocr_options = {
            param: kwargs[key]
            for key, param in (
                ('ocr_show_log', 'show_log'),
                ('det_db_box_thresh', 'det_db_box_thresh'),
                ('use_dilation', 'use_dilation'),
                ('det_db_unclip_ratio', 'det_db_unclip_ratio'),
            )
            if kwargs.get(key) is not None
        }
        atom_model = ocr_model_init(lang=kwargs.get('lang'), **ocr_options)
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get('table_model_name'),
            kwargs.get('table_model_path'),
            kwargs.get('table_max_time'),
            kwargs.get('device'),
        )
    else:
        logger.error('model name not allow')
        # ``exit`` comes from the ``site`` module and is meant for interactive
        # use; ``sys.exit`` is the dependable programmatic equivalent.
        sys.exit(1)

    if atom_model is None:
        logger.error('model init failed')
        sys.exit(1)
    return atom_model