model_init.py 7.84 KB
Newer Older
1
import torch
2
3
from loguru import logger

4
from magic_pdf.config.constants import MODEL_NAME
5
from magic_pdf.model.model_list import AtomicModel
6
7
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
8
9
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
# try:
#     from magic_pdf_ascend_plugin.libs.license_verifier import (
#         LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
#         load_license)
#     from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
#     from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
#     license_key = load_license()
#     logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
#                 f' License expired at {license_key["payload"]["date"]["end_date"]}')
# except Exception as e:
#     if isinstance(e, ImportError):
#         pass
#     elif isinstance(e, LicenseFormatError):
#         logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
#     elif isinstance(e, LicenseSignatureError):
#         logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
#     elif isinstance(e, LicenseExpiredError):
#         logger.error('Ascend Plugin: License has expired. Please renew your license.')
#     elif isinstance(e, FileNotFoundError):
#         logger.error('Ascend Plugin: Not found License file.')
#     else:
#         logger.error(f'Ascend Plugin: {e}')
#     from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
#     # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
#     from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
37
38


39
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
    """Build the table-recognition model selected by ``table_model_type``.

    Args:
        table_model_type: One of ``MODEL_NAME.STRUCT_EQTABLE``,
            ``MODEL_NAME.TABLE_MASTER`` or ``MODEL_NAME.RAPID_TABLE``.
        model_path: Weights location; used by the StructEqTable and
            TableMaster backends.
        max_time: Time budget forwarded to ``StructTableModel``.
        _device_: Device forwarded to ``TableMasterPaddleModel`` via its
            config dict.
        ocr_engine: OCR engine handed to ``RapidTableModel``.
        table_sub_model_name: Sub-model name handed to ``RapidTableModel``.

    Returns:
        The constructed table model instance.

    Raises:
        SystemExit: With status 1 when ``table_model_type`` is not supported.
    """
    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
        # Imported lazily, presumably so this backend's dependencies are only
        # needed when it is actually selected.
        from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
        table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
    elif table_model_type == MODEL_NAME.TABLE_MASTER:
        from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
        config = {
            'model_dir': model_path,
            'device': _device_
        }
        table_model = TableMasterPaddleModel(config)
    elif table_model_type == MODEL_NAME.RAPID_TABLE:
        table_model = RapidTableModel(ocr_engine, table_sub_model_name)
    else:
        logger.error('table model type not allow')
        # The bare ``exit`` builtin is injected by the ``site`` module for the
        # interactive shell and is absent under ``python -S`` / frozen apps.
        # ``raise SystemExit(1)`` is exactly what ``sys.exit(1)`` does.
        raise SystemExit(1)

    return table_model


def mfd_model_init(weight, device='cpu'):
    """Create the YOLOv8-based MFD model from ``weight``.

    A ``device`` whose string form starts with ``'npu'`` is wrapped in
    ``torch.device`` before use; any other value is passed through as-is.
    """
    target_device = torch.device(device) if str(device).startswith('npu') else device
    return YOLOv8MFDModel(weight, target_device)


def mfr_model_init(weight_dir, cfg_path, device='cpu'):
    """Return a ``UnimernetModel`` (the MFR sub-module) for the given weights, config and device."""
    return UnimernetModel(weight_dir, cfg_path, device)


def layout_model_init(weight, config_file, device):
    """Return a LayoutLMv3 layout predictor.

    The predictor class is imported inside the function, so its module is
    only loaded when this initializer is actually called.
    """
    from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor

    return Layoutlmv3_Predictor(weight, config_file, device)


def doclayout_yolo_model_init(weight, device='cpu'):
    """Create the DocLayout-YOLO layout model from ``weight``.

    NPU device strings (anything starting with ``'npu'``) are converted to a
    ``torch.device``; other device values are forwarded unchanged.
    """
    is_npu = str(device).startswith('npu')
    return DocLayoutYOLOModel(weight, torch.device(device) if is_npu else device)


84
def langdetect_model_init(langdetect_model_weight, device='cpu'):
    """Create the YOLOv11 language-detection model.

    When ``device`` names an NPU (its string form starts with ``'npu'``), it
    is wrapped in ``torch.device`` first.
    """
    resolved = device
    if str(resolved).startswith('npu'):
        resolved = torch.device(resolved)
    return YOLOv11LangDetModel(langdetect_model_weight, resolved)


91
92
93
94
95
96
def ocr_model_init(show_log: bool = False,
                   det_db_box_thresh=0.3,
                   lang=None,
                   use_dilation=True,
                   det_db_unclip_ratio=1.8,
                   ):
    """Create a ``PytorchPaddleOCR`` engine.

    ``lang`` is forwarded only when it is neither ``None`` nor ``''``.
    Otherwise no ``lang`` keyword is passed at all, and the engine
    presumably falls back to its own default.
    """
    # Insertion order mirrors the original keyword order of the call.
    ocr_kwargs = {
        'show_log': show_log,
        'det_db_box_thresh': det_db_box_thresh,
    }
    if lang is not None and lang != '':
        ocr_kwargs['lang'] = lang
    ocr_kwargs['use_dilation'] = use_dilation
    ocr_kwargs['det_db_unclip_ratio'] = det_db_unclip_ratio

    return PytorchPaddleOCR(**ocr_kwargs)


class AtomModelSingleton:
    """Process-wide cache of initialized atomic models.

    ``__new__`` always returns the same instance, and ``_models`` is a class
    attribute, so every ``AtomModelSingleton()`` shares a single cache.
    """

    _instance = None
    # Cache key -> model returned by ``atom_model_init``; shared by all instances.
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the cached model for ``atom_model_name``, building it on first use.

        The cache key includes the keyword arguments that change which model
        is constructed for that kind. A request with different settings
        therefore gets its own model instead of silently reusing one built
        with other settings. Keyword arguments a caller does not pass occupy
        a ``None`` slot in the key.

        Args:
            atom_model_name: An ``AtomicModel`` member.
            **kwargs: Forwarded unchanged to ``atom_model_init`` on a cache miss.

        Returns:
            The (possibly cached) model instance.
        """
        lang = kwargs.get('lang', None)

        if atom_model_name in [AtomicModel.OCR]:
            # det_db_box_thresh is passed to the OCR engine's constructor, so
            # engines built with different thresholds must not share an entry.
            key = (atom_model_name, kwargs.get('det_db_box_thresh', None), lang)
        elif atom_model_name in [AtomicModel.Layout]:
            key = (atom_model_name, kwargs.get('layout_model_name', None))
        elif atom_model_name in [AtomicModel.Table]:
            # The table model is built with a caller-supplied ocr_engine and
            # sub-model name. Keying on lang (a proxy for the OCR engine the
            # caller hands in) and on the sub-model avoids reusing a table
            # model bound to another language's engine.
            key = (
                atom_model_name,
                kwargs.get('table_model_name', None),
                kwargs.get('table_sub_model_name', None),
                lang,
            )
        else:
            key = atom_model_name

        if key not in self._models:
            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
        return self._models[key]

def atom_model_init(model_name: str, **kwargs):
    """Instantiate the atomic model named ``model_name`` using settings in ``kwargs``.

    Args:
        model_name: An ``AtomicModel`` member: Layout, MFD, MFR, OCR, Table
            or LangDetect.
        **kwargs: Per-model settings. Each branch reads only its own keys,
            e.g. ``layout_model_name``/``layout_weights``/``layout_config_file``,
            ``doclayout_yolo_weights``, ``mfd_weights``, ``mfr_weight_dir``/
            ``mfr_cfg_path``, ``ocr_show_log``/``det_db_box_thresh``/``lang``,
            the ``table_*`` keys and ``ocr_engine``,
            ``langdetect_model_name``/``langdetect_model_weight``, and ``device``.

    Returns:
        The initialized model.

    Raises:
        SystemExit: With status 1 when the model kind or sub-type is not
            supported, or when the selected initializer returned ``None``.
    """
    atom_model = None
    if model_name == AtomicModel.Layout:
        layout_model_name = kwargs.get('layout_model_name')
        if layout_model_name == MODEL_NAME.LAYOUTLMv3:
            atom_model = layout_model_init(
                kwargs.get('layout_weights'),
                kwargs.get('layout_config_file'),
                kwargs.get('device')
            )
        elif layout_model_name == MODEL_NAME.DocLayout_YOLO:
            atom_model = doclayout_yolo_model_init(
                kwargs.get('doclayout_yolo_weights'),
                kwargs.get('device')
            )
        else:
            logger.error('layout model name not allow')
            raise SystemExit(1)
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get('mfd_weights'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get('mfr_weight_dir'),
            kwargs.get('mfr_cfg_path'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.OCR:
        atom_model = ocr_model_init(
            kwargs.get('ocr_show_log'),
            kwargs.get('det_db_box_thresh'),
            kwargs.get('lang'),
        )
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get('table_model_name'),
            kwargs.get('table_model_path'),
            kwargs.get('table_max_time'),
            kwargs.get('device'),
            kwargs.get('ocr_engine'),
            kwargs.get('table_sub_model_name')
        )
    elif model_name == AtomicModel.LangDetect:
        if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
            atom_model = langdetect_model_init(
                kwargs.get('langdetect_model_weight'),
                kwargs.get('device')
            )
        else:
            logger.error('langdetect model name not allow')
            raise SystemExit(1)
    else:
        logger.error('model name not allow')
        # ``exit`` is a site-module convenience for the interactive shell and
        # may be undefined in programs (``python -S``, frozen builds).
        # ``raise SystemExit(1)`` is the equivalent that always works.
        raise SystemExit(1)

    if atom_model is None:
        logger.error('model init failed')
        raise SystemExit(1)
    return atom_model