# model_init.py
1
import torch
2
3
from loguru import logger

4
from magic_pdf.config.constants import MODEL_NAME
5
from magic_pdf.model.model_list import AtomicModel
6
7
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
8
9
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
# try:
#     from magic_pdf_ascend_plugin.libs.license_verifier import (
#         LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
#         load_license)
#     from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
#     from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
#     license_key = load_license()
#     logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
#                 f' License expired at {license_key["payload"]["date"]["end_date"]}')
# except Exception as e:
#     if isinstance(e, ImportError):
#         pass
#     elif isinstance(e, LicenseFormatError):
#         logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
#     elif isinstance(e, LicenseSignatureError):
#         logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
#     elif isinstance(e, LicenseExpiredError):
#         logger.error('Ascend Plugin: License has expired. Please renew your license.')
#     elif isinstance(e, FileNotFoundError):
#         logger.error('Ascend Plugin: Not found License file.')
#     else:
#         logger.error(f'Ascend Plugin: {e}')
#     from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
#     # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
#     from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
37
38


39
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lang=None, table_sub_model_name=None):
    """Build the table model selected by ``table_model_type``.

    Args:
        table_model_type: Compared against ``MODEL_NAME.STRUCT_EQTABLE``,
            ``MODEL_NAME.TABLE_MASTER`` and ``MODEL_NAME.RAPID_TABLE`` to
            pick the backend.
        model_path: Passed to ``StructTableModel`` as its first argument, or
            to ``TableMasterPaddleModel`` as the ``'model_dir'`` config entry.
            Not used by the RapidTable branch.
        max_time: Forwarded to ``StructTableModel`` only.
        _device_: Placed in the ``'device'`` config entry of
            ``TableMasterPaddleModel`` only.
        lang: Forwarded to the OCR model lookup used by the RapidTable branch.
        table_sub_model_name: Second argument of ``RapidTableModel``.

    Returns:
        The constructed table model instance.

    Raises:
        SystemExit: With exit code 1, after logging an error, when
            ``table_model_type`` matches none of the supported backends.
    """
    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
        # Function-scope import: the module is only loaded when this backend is selected.
        from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel

        return StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)

    if table_model_type == MODEL_NAME.TABLE_MASTER:
        # Function-scope import: the module is only loaded when this backend is selected.
        from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel

        config = {
            'model_dir': model_path,
            'device': _device_,
        }
        return TableMasterPaddleModel(config)

    if table_model_type == MODEL_NAME.RAPID_TABLE:
        # RapidTable is handed an OCR engine obtained from the shared model cache.
        atom_model_manager = AtomModelSingleton()
        ocr_engine = atom_model_manager.get_atom_model(
            atom_model_name='ocr',
            ocr_show_log=False,
            det_db_box_thresh=0.5,
            det_db_unclip_ratio=1.6,
            lang=lang,
        )
        return RapidTableModel(ocr_engine, table_sub_model_name)

    logger.error('table model type not allow')
    # Raise SystemExit directly instead of calling the site-provided ``exit()``
    # helper, which is intended for the interactive shell and is not guaranteed
    # to exist (e.g. when the interpreter runs without the ``site`` module).
    raise SystemExit(1)


def mfd_model_init(weight, device='cpu'):
    """Instantiate a ``YOLOv8MFDModel``.

    Args:
        weight: Passed through as the first argument of ``YOLOv8MFDModel``.
        device: Device specifier. When its string form starts with ``'npu'``
            it is wrapped in ``torch.device``; any other value is passed
            through unchanged.

    Returns:
        The constructed ``YOLOv8MFDModel`` instance.
    """
    targets_npu = str(device).startswith('npu')
    resolved_device = torch.device(device) if targets_npu else device
    return YOLOv8MFDModel(weight, resolved_device)


def mfr_model_init(weight_dir, cfg_path, device='cpu'):
    """Instantiate a ``UnimernetModel``.

    All three arguments are passed positionally, in order, to the
    ``UnimernetModel`` constructor.

    Returns:
        The constructed ``UnimernetModel`` instance.
    """
    return UnimernetModel(weight_dir, cfg_path, device)


def layout_model_init(weight, config_file, device):
    """Instantiate a ``Layoutlmv3_Predictor``.

    The predictor class is imported inside the function, so its module is
    only loaded when this initialiser is actually called.

    Args:
        weight: First positional argument of ``Layoutlmv3_Predictor``.
        config_file: Second positional argument of ``Layoutlmv3_Predictor``.
        device: Third positional argument of ``Layoutlmv3_Predictor``.

    Returns:
        The constructed ``Layoutlmv3_Predictor`` instance.
    """
    from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor

    return Layoutlmv3_Predictor(weight, config_file, device)


def doclayout_yolo_model_init(weight, device='cpu'):
    """Instantiate a ``DocLayoutYOLOModel``.

    Args:
        weight: Passed through as the first argument of ``DocLayoutYOLOModel``.
        device: Device specifier. A value whose string form starts with
            ``'npu'`` is converted to a ``torch.device``; anything else is
            forwarded as given.

    Returns:
        The constructed ``DocLayoutYOLOModel`` instance.
    """
    if not str(device).startswith('npu'):
        return DocLayoutYOLOModel(weight, device)
    return DocLayoutYOLOModel(weight, torch.device(device))


92
def langdetect_model_init(langdetect_model_weight, device='cpu'):
    """Instantiate a ``YOLOv11LangDetModel``.

    Args:
        langdetect_model_weight: Passed through as the first argument of
            ``YOLOv11LangDetModel``.
        device: Device specifier. A value whose string form starts with
            ``'npu'`` is wrapped in ``torch.device`` first; other values are
            forwarded unchanged.

    Returns:
        The constructed ``YOLOv11LangDetModel`` instance.
    """
    use_torch_device = str(device).startswith('npu')
    if use_torch_device:
        device = torch.device(device)
    return YOLOv11LangDetModel(langdetect_model_weight, device)


99
100
101
102
103
104
def ocr_model_init(show_log: bool = False,
                   det_db_box_thresh=0.3,
                   lang=None,
                   use_dilation=True,
                   det_db_unclip_ratio=1.8,
                   ):
    """Instantiate a ``PytorchPaddleOCR`` engine.

    ``show_log``, ``det_db_box_thresh``, ``use_dilation`` and
    ``det_db_unclip_ratio`` are always forwarded as keyword arguments of the
    same name. ``lang`` is forwarded only when it is neither ``None`` nor the
    empty string; otherwise the keyword is omitted from the constructor call.

    Returns:
        The constructed ``PytorchPaddleOCR`` instance.
    """
    ocr_options = {
        'show_log': show_log,
        'det_db_box_thresh': det_db_box_thresh,
        'use_dilation': use_dilation,
        'det_db_unclip_ratio': det_db_unclip_ratio,
    }
    lang_missing = lang is None or lang == ''
    if not lang_missing:
        ocr_options['lang'] = lang
    return PytorchPaddleOCR(**ocr_options)


class AtomModelSingleton:
    """Singleton holding a cache of already-initialised atomic models.

    ``__new__`` always returns the same instance, and the cache lives in the
    class attribute ``_models``, so every ``AtomModelSingleton()`` call sees
    the same models. No locking is performed around cache access.
    """

    # The one shared instance, created on first construction.
    _instance = None
    # Cache mapping a per-model key (see ``get_atom_model``) to the model object.
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the cached model for this request, creating it on first use.

        Args:
            atom_model_name: Compared against ``AtomicModel.OCR``,
                ``AtomicModel.Layout`` and ``AtomicModel.Table`` to decide
                how the cache key is built; any other name is used as the
                key by itself.
            **kwargs: Forwarded unchanged to ``atom_model_init`` when the
                model has to be created. ``lang``, ``layout_model_name``,
                ``table_model_name``, ``det_db_box_thresh`` and
                ``det_db_unclip_ratio`` are additionally read to build the
                cache key.

        Returns:
            The cached (or newly created) model object.
        """
        lang = kwargs.get('lang', None)
        layout_model_name = kwargs.get('layout_model_name', None)
        table_model_name = kwargs.get('table_model_name', None)

        if atom_model_name == AtomicModel.OCR:
            # The detector settings are part of the key so that a request with
            # different thresholds is not served a model that was built with
            # another caller's settings. The raw kwargs values are used (no
            # defaults substituted) so the key mirrors exactly what was asked.
            key = (
                atom_model_name,
                lang,
                kwargs.get('det_db_box_thresh', None),
                kwargs.get('det_db_unclip_ratio', None),
            )
        elif atom_model_name == AtomicModel.Layout:
            key = (atom_model_name, layout_model_name)
        elif atom_model_name == AtomicModel.Table:
            # NOTE(review): ``table_sub_model_name`` is not part of this key;
            # confirm whether two sub-models may be requested in one process.
            key = (atom_model_name, table_model_name, lang)
        else:
            key = atom_model_name

        if key not in self._models:
            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
        return self._models[key]
152
153
154
155

def atom_model_init(model_name: str, **kwargs):
    """Create the atomic model identified by ``model_name``.

    Args:
        model_name: Compared against the ``AtomicModel`` members ``Layout``,
            ``MFD``, ``MFR``, ``OCR``, ``Table`` and ``LangDetect`` to choose
            which initialiser to call.
        **kwargs: Options read per branch:

            * Layout: ``layout_model_name`` selects between
              ``MODEL_NAME.LAYOUTLMv3`` (``layout_weights``,
              ``layout_config_file``, ``device``) and
              ``MODEL_NAME.DocLayout_YOLO`` (``doclayout_yolo_weights``,
              ``device``).
            * MFD: ``mfd_weights``, ``device``.
            * MFR: ``mfr_weight_dir``, ``mfr_cfg_path``, ``device``.
            * OCR: ``ocr_show_log``, ``det_db_box_thresh``, ``lang``,
              ``det_db_unclip_ratio``. Only the options actually supplied
              are forwarded, so ``ocr_model_init``'s own defaults apply to
              the rest.
            * Table: ``table_model_name``, ``table_model_path``,
              ``table_max_time``, ``device``, ``lang``,
              ``table_sub_model_name``.
            * LangDetect: ``langdetect_model_name`` must equal
              ``MODEL_NAME.YOLO_V11_LangDetect``; then
              ``langdetect_model_weight`` and ``device`` are used.

    Returns:
        The model object produced by the selected initialiser.

    Raises:
        SystemExit: With exit code 1, after logging an error, when the model
            name (or the layout / langdetect sub-model name) is not
            supported, or when the initialiser returned ``None``.
    """
    atom_model = None

    if model_name == AtomicModel.Layout:
        layout_model_name = kwargs.get('layout_model_name')
        if layout_model_name == MODEL_NAME.LAYOUTLMv3:
            atom_model = layout_model_init(
                kwargs.get('layout_weights'),
                kwargs.get('layout_config_file'),
                kwargs.get('device'),
            )
        elif layout_model_name == MODEL_NAME.DocLayout_YOLO:
            atom_model = doclayout_yolo_model_init(
                kwargs.get('doclayout_yolo_weights'),
                kwargs.get('device'),
            )
        else:
            logger.error('layout model name not allow')
            raise SystemExit(1)
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get('mfd_weights'),
            kwargs.get('device'),
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get('mfr_weight_dir'),
            kwargs.get('mfr_cfg_path'),
            kwargs.get('device'),
        )
    elif model_name == AtomicModel.OCR:
        # (caller-facing key, ocr_model_init parameter) pairs. Forwarding by
        # keyword, and only when the caller supplied the key, keeps
        # ocr_model_init's defaults in effect instead of overriding them with
        # None, and lets det_db_unclip_ratio reach the OCR engine.
        forwarded_options = (
            ('ocr_show_log', 'show_log'),
            ('det_db_box_thresh', 'det_db_box_thresh'),
            ('lang', 'lang'),
            ('det_db_unclip_ratio', 'det_db_unclip_ratio'),
        )
        ocr_kwargs = {
            param: kwargs[option]
            for option, param in forwarded_options
            if option in kwargs
        }
        atom_model = ocr_model_init(**ocr_kwargs)
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get('table_model_name'),
            kwargs.get('table_model_path'),
            kwargs.get('table_max_time'),
            kwargs.get('device'),
            kwargs.get('lang'),
            kwargs.get('table_sub_model_name'),
        )
    elif model_name == AtomicModel.LangDetect:
        if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
            atom_model = langdetect_model_init(
                kwargs.get('langdetect_model_weight'),
                kwargs.get('device'),
            )
        else:
            logger.error('langdetect model name not allow')
            raise SystemExit(1)
    else:
        logger.error('model name not allow')
        raise SystemExit(1)

    if atom_model is None:
        logger.error('model init failed')
        # SystemExit is raised directly rather than via the site-provided
        # ``exit()`` helper, which is meant for interactive sessions.
        raise SystemExit(1)
    return atom_model