model_init.py 7.59 KB
Newer Older
icecraft's avatar
icecraft committed
1
2
import os
import sys

3
import torch
4
5
from loguru import logger

6
from magic_pdf.config.constants import MODEL_NAME
7
from magic_pdf.model.model_list import AtomicModel
icecraft's avatar
icecraft committed
8
9
10
11
12
13
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import \
    YOLOv11LangDetModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
    DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
    Layoutlmv3_Predictor
14
15
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
16
17

# Prefer the Ascend (NPU) builds of ModifiedPaddleOCR / RapidTableModel when the
# optional magic_pdf_ascend_plugin package imports cleanly and load_license()
# succeeds; otherwise fall back to the stock magic_pdf implementations.
#
# (exception type, log message) pairs for the plugin's license errors. The
# tuple is only populated after license_verifier has been imported, so the
# except-branch below never references names that failed to import.
_ascend_license_errors = ()
try:
    from magic_pdf_ascend_plugin.libs.license_verifier import (
        LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
        load_license)
    _ascend_license_errors = (
        (LicenseFormatError, 'Ascend Plugin: Invalid license format. Please check the license file.'),
        (LicenseSignatureError, 'Ascend Plugin: Invalid signature. The license may be tampered with.'),
        (LicenseExpiredError, 'Ascend Plugin: License has expired. Please renew your license.'),
    )
    from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import \
        ModifiedPaddleOCR
    from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import \
        RapidTableModel

    license_key = load_license()
    logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
                f' License expired at {license_key["payload"]["date"]["end_date"]}')
except Exception as e:
    # An ImportError means the plugin (or part of it) is not installed, which
    # is the normal case: fall back silently. Anything else is reported first.
    if not isinstance(e, ImportError):
        for error_type, message in (*_ascend_license_errors,
                                    (FileNotFoundError, 'Ascend Plugin: Not found License file.')):
            if isinstance(e, error_type):
                logger.error(message)
                break
        else:
            logger.error(f'Ascend Plugin: {e}')

    from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
    # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
    from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
44

icecraft's avatar
icecraft committed
45
46
47
48
49
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
    StructTableModel
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
    TableMasterPaddleModel

50

51
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
    """Construct the table model selected by ``table_model_type``.

    Args:
        table_model_type: MODEL_NAME.STRUCT_EQTABLE, MODEL_NAME.TABLE_MASTER
            or MODEL_NAME.RAPID_TABLE.
        model_path: passed to StructTableModel, and as ``model_dir`` in the
            TableMaster config; not used for RapidTable.
        max_time: forwarded as ``max_time`` to StructTableModel only.
        _device_: stored as ``device`` in the TableMaster config only.
        ocr_engine: handed to RapidTableModel only.
        table_sub_model_name: handed to RapidTableModel only.

    Returns:
        The constructed table model.

    Logs an error and exits the process with status 1 for any other type.
    """
    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
        return StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
    if table_model_type == MODEL_NAME.TABLE_MASTER:
        config = {
            'model_dir': model_path,
            'device': _device_,
        }
        return TableMasterPaddleModel(config)
    if table_model_type == MODEL_NAME.RAPID_TABLE:
        return RapidTableModel(ocr_engine, table_sub_model_name)

    logger.error('table model type not allow')
    # sys.exit instead of the site-provided exit(): the latter is intended for
    # the interactive shell only, is missing under `python -S`, and closes
    # sys.stdin before raising SystemExit.
    sys.exit(1)


def mfd_model_init(weight, device='cpu'):
    """Build the YOLOv8-based MFD model from ``weight``.

    A device whose string form starts with ``'npu'`` is converted to a
    ``torch.device`` first; any other device value is passed through as-is.
    """
    is_npu = str(device).startswith('npu')
    return YOLOv8MFDModel(weight, torch.device(device) if is_npu else device)


def mfr_model_init(weight_dir, cfg_path, device='cpu'):
    """Build the UniMERNet-based MFR model.

    The weight directory, config path and device are forwarded positionally to
    ``UnimernetModel``; no device conversion happens here.
    """
    return UnimernetModel(weight_dir, cfg_path, device)


def layout_model_init(weight, config_file, device):
    """Build the LayoutLMv3 layout predictor from its weights and config file."""
    return Layoutlmv3_Predictor(weight, config_file, device)


def doclayout_yolo_model_init(weight, device='cpu'):
    """Build the DocLayout-YOLO layout model from ``weight``.

    NPU devices (string form starting with ``'npu'``) are wrapped in
    ``torch.device``; every other device value is used unchanged.
    """
    if not str(device).startswith('npu'):
        return DocLayoutYOLOModel(weight, device)
    return DocLayoutYOLOModel(weight, torch.device(device))


93
def langdetect_model_init(langdetect_model_weight, device='cpu'):
    """Build the YOLOv11 language-detection model.

    As with the other YOLO-based initialisers, an ``'npu...'`` device is
    turned into a ``torch.device`` before the model is constructed.
    """
    resolved_device = torch.device(device) if str(device).startswith('npu') else device
    return YOLOv11LangDetModel(langdetect_model_weight, resolved_device)


100
101
102
103
104
105
def ocr_model_init(show_log: bool = False,
                   det_db_box_thresh=0.3,
                   lang=None,
                   use_dilation=True,
                   det_db_unclip_ratio=1.8,
                   ):
    """Create a ModifiedPaddleOCR instance with the given detection options.

    ``lang`` is forwarded only when it is neither ``None`` nor the empty
    string; otherwise ModifiedPaddleOCR is left to apply whatever language
    default it has.
    """
    paddle_kwargs = {
        'show_log': show_log,
        'det_db_box_thresh': det_db_box_thresh,
        'use_dilation': use_dilation,
        'det_db_unclip_ratio': det_db_unclip_ratio,
    }
    if lang is not None and lang != '':
        paddle_kwargs['lang'] = lang
    return ModifiedPaddleOCR(**paddle_kwargs)


class AtomModelSingleton:
    """Shared cache of initialised atomic models.

    ``AtomModelSingleton()`` always returns the same instance, and models are
    stored in the class-level ``_models`` dict, so all callers share one model
    per cache key.
    """

    _instance = None
    # Cache key (see get_atom_model) -> initialised model. Lives on the class
    # and is only mutated in place, so every caller sees the same entries.
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the cached model for these arguments, building it on a miss.

        The cache key is ``atom_model_name`` plus the keyword arguments that
        are baked into that model type at construction time. On a miss, all
        ``kwargs`` are forwarded to ``atom_model_init``.
        """
        lang = kwargs.get('lang', None)

        if atom_model_name == AtomicModel.OCR:
            # det_db_box_thresh is fixed when the OCR model is built, so
            # different thresholds must not share one cached instance.
            key = (atom_model_name, lang, kwargs.get('det_db_box_thresh', None))
        elif atom_model_name == AtomicModel.Layout:
            key = (atom_model_name, kwargs.get('layout_model_name', None))
        elif atom_model_name == AtomicModel.Table:
            # table_sub_model_name and ocr_engine are also fixed at build time.
            # NOTE(review): lang is used as a proxy for which ocr_engine was
            # passed (presumably a per-language OCR model) — confirm.
            key = (
                atom_model_name,
                kwargs.get('table_model_name', None),
                kwargs.get('table_sub_model_name', None),
                lang,
            )
        else:
            key = atom_model_name

        if key not in self._models:
            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
        return self._models[key]
151
152
153
154

def atom_model_init(model_name: str, **kwargs):
    """Instantiate the atomic model identified by ``model_name``.

    ``model_name`` is compared against the ``AtomicModel`` members handled
    below; the remaining keyword arguments supply weights, config paths and the
    device for the chosen model.

    Returns:
        The initialised model object.

    An unknown model name, an unknown layout / language-detection sub-model
    name, or a branch that produced ``None`` is logged and terminates the
    process with exit status 1.
    """
    atom_model = None
    if model_name == AtomicModel.Layout:
        layout_model_name = kwargs.get('layout_model_name')
        if layout_model_name == MODEL_NAME.LAYOUTLMv3:
            atom_model = layout_model_init(
                kwargs.get('layout_weights'),
                kwargs.get('layout_config_file'),
                kwargs.get('device'),
            )
        elif layout_model_name == MODEL_NAME.DocLayout_YOLO:
            atom_model = doclayout_yolo_model_init(
                kwargs.get('doclayout_yolo_weights'),
                kwargs.get('device'),
            )
        else:
            logger.error('layout model name not allow')
            sys.exit(1)
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get('mfd_weights'),
            kwargs.get('device'),
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get('mfr_weight_dir'),
            kwargs.get('mfr_cfg_path'),
            kwargs.get('device'),
        )
    elif model_name == AtomicModel.OCR:
        # Forward only the OCR options the caller actually supplied. Passing
        # kwargs.get(...) positionally would replace ocr_model_init's own
        # defaults (show_log=False, det_db_box_thresh=0.3) with None.
        ocr_options = {'lang': kwargs.get('lang')}
        if 'ocr_show_log' in kwargs:
            ocr_options['show_log'] = kwargs['ocr_show_log']
        if 'det_db_box_thresh' in kwargs:
            ocr_options['det_db_box_thresh'] = kwargs['det_db_box_thresh']
        atom_model = ocr_model_init(**ocr_options)
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get('table_model_name'),
            kwargs.get('table_model_path'),
            kwargs.get('table_max_time'),
            kwargs.get('device'),
            kwargs.get('ocr_engine'),
            kwargs.get('table_sub_model_name'),
        )
    elif model_name == AtomicModel.LangDetect:
        if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
            atom_model = langdetect_model_init(
                kwargs.get('langdetect_model_weight'),
                kwargs.get('device'),
            )
        else:
            logger.error('langdetect model name not allow')
            sys.exit(1)
    else:
        logger.error('model name not allow')
        sys.exit(1)

    if atom_model is None:
        logger.error('model init failed')
        # sys.exit rather than the interactive-shell exit() helper.
        sys.exit(1)
    return atom_model