model_init.py 5.32 KB
Newer Older
1
2
from loguru import logger

3
from magic_pdf.config.constants import MODEL_NAME
4
from magic_pdf.model.model_list import AtomicModel
5
6
7
8
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
    DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
    Layoutlmv3_Predictor
9
10
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
11
12
13
14
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
    ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
    RapidTableModel
15
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
16
17
18
19
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
    StructTableModel
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
    TableMasterPaddleModel
20
21
22
23
24
25
26


def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
    """Build the table model selected by ``table_model_type``.

    Args:
        table_model_type: Compared against ``MODEL_NAME.STRUCT_EQTABLE``,
            ``MODEL_NAME.TABLE_MASTER`` and ``MODEL_NAME.RAPID_TABLE``.
        model_path: First argument of ``StructTableModel``, or the
            ``'model_dir'`` config entry of ``TableMasterPaddleModel``.
            Not used for ``RapidTableModel``.
        max_time: Forwarded to ``StructTableModel`` as ``max_time``; the
            other two branches do not use it.
        _device_: The ``'device'`` config entry of
            ``TableMasterPaddleModel``; the other two branches do not use it.

    Returns:
        The constructed table model instance.

    Raises:
        SystemExit: With status 1 (after logging an error) when
            ``table_model_type`` matches none of the supported names.
    """
    # ``sys.exit`` is always available; the bare ``exit`` builtin is injected
    # by the ``site`` module for interactive use and may be missing
    # (``python -S``, frozen or embedded interpreters).
    import sys

    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
        table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
    elif table_model_type == MODEL_NAME.TABLE_MASTER:
        config = {
            'model_dir': model_path,
            'device': _device_
        }
        table_model = TableMasterPaddleModel(config)
    elif table_model_type == MODEL_NAME.RAPID_TABLE:
        table_model = RapidTableModel()
    else:
        logger.error('table model type not allow')
        sys.exit(1)

    return table_model


def mfd_model_init(weight, device='cpu'):
    """Construct and return a ``YOLOv8MFDModel``.

    Args:
        weight: Forwarded as the first positional argument.
        device: Forwarded as the second positional argument.
    """
    return YOLOv8MFDModel(weight, device)


def mfr_model_init(weight_dir, cfg_path, device='cpu'):
    """Construct and return a ``UnimernetModel``.

    Args:
        weight_dir: Forwarded as the first positional argument.
        cfg_path: Forwarded as the second positional argument.
        device: Forwarded as the third positional argument.
    """
    return UnimernetModel(weight_dir, cfg_path, device)


def layout_model_init(weight, config_file, device):
    """Construct and return a ``Layoutlmv3_Predictor``.

    Args:
        weight: Forwarded as the first positional argument.
        config_file: Forwarded as the second positional argument.
        device: Forwarded as the third positional argument (no default).
    """
    return Layoutlmv3_Predictor(weight, config_file, device)


def doclayout_yolo_model_init(weight, device='cpu'):
    """Construct and return a ``DocLayoutYOLOModel``.

    Args:
        weight: Forwarded as the first positional argument.
        device: Forwarded as the second positional argument.
    """
    return DocLayoutYOLOModel(weight, device)


def ocr_model_init(show_log: bool = False,
                   det_db_box_thresh=0.3,
                   lang=None,
                   use_dilation=True,
                   det_db_unclip_ratio=1.8,
                   ):
    """Construct and return a ``ModifiedPaddleOCR`` instance.

    ``show_log``, ``det_db_box_thresh``, ``use_dilation`` and
    ``det_db_unclip_ratio`` are always forwarded as keyword arguments.
    ``lang`` is forwarded only when it is neither ``None`` nor the empty
    string; otherwise the keyword is omitted entirely so the OCR class falls
    back to whatever language it uses by default.
    """
    lang_given = lang is not None and lang != ''
    # An empty mapping unpacks to "no keyword at all", which matches the
    # original branch that simply left ``lang`` out of the call.
    optional_lang = {'lang': lang} if lang_given else {}

    return ModifiedPaddleOCR(
        show_log=show_log,
        det_db_box_thresh=det_db_box_thresh,
        **optional_lang,
        use_dilation=use_dilation,
        det_db_unclip_ratio=det_db_unclip_ratio,
    )


85
86
from threading import Lock

87
88
89
class AtomModelSingleton:
    """Singleton that caches atomic model instances.

    ``__new__`` always hands back the same object, and ``_models`` is a
    class-level dict, so every caller shares one cache. Models are created on
    first request through ``atom_model_init`` and reused afterwards.
    """

    _instance = None
    # Cache of initialised models, keyed by
    # (atom_model_name, layout_model_name, table_model_name, lang).
    _models = {}
    # Serialises creation/lookup of OCR models (see ``get_atom_model``).
    _lock = Lock()

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the cached model for the request, creating it if needed.

        Args:
            atom_model_name: Model family; compared against
                ``AtomicModel.OCR`` here and forwarded to ``atom_model_init``
                as ``model_name``.
            **kwargs: Forwarded unchanged to ``atom_model_init``. The values
                of ``lang``, ``layout_model_name`` and ``table_model_name``
                (``None`` when absent) also form part of the cache key.

        Returns:
            The model instance stored under the computed key. The instance is
            returned on the first call as well as on later ones.
        """
        lang = kwargs.get('lang', None)
        layout_model_name = kwargs.get('layout_model_name', None)
        # The table model name picks a different model class in
        # ``atom_model_init``, so it must separate cache entries just like
        # the layout model name does.
        table_model_name = kwargs.get('table_model_name', None)
        key = (atom_model_name, layout_model_name, table_model_name, lang)

        if atom_model_name == AtomicModel.OCR:
            # Only the OCR path is guarded by the lock, as in the original.
            with self._lock:
                return self._get_or_create(key, atom_model_name, kwargs)
        return self._get_or_create(key, atom_model_name, kwargs)

    def _get_or_create(self, key, atom_model_name, init_kwargs):
        """Return ``_models[key]``, initialising the entry first if missing."""
        if key not in self._models:
            self._models[key] = atom_model_init(model_name=atom_model_name, **init_kwargs)
        return self._models[key]
112
113
114
115
116


def atom_model_init(model_name: str, **kwargs):
    """Create the atomic model identified by ``model_name``.

    Keyword arguments are read with ``kwargs.get`` and passed positionally to
    the matching ``*_model_init`` helper. A key that is absent is therefore
    forwarded as ``None`` and replaces that helper's own default value.

    Keys read per model family:
        * ``AtomicModel.Layout``: ``layout_model_name`` selects the backend;
          ``MODEL_NAME.LAYOUTLMv3`` uses ``layout_weights``,
          ``layout_config_file`` and ``device``; ``MODEL_NAME.DocLayout_YOLO``
          uses ``doclayout_yolo_weights`` and ``device``.
        * ``AtomicModel.MFD``: ``mfd_weights``, ``device``.
        * ``AtomicModel.MFR``: ``mfr_weight_dir``, ``mfr_cfg_path``,
          ``device``.
        * ``AtomicModel.OCR``: ``ocr_show_log``, ``det_db_box_thresh``,
          ``lang``.
        * ``AtomicModel.Table``: ``table_model_name``, ``table_model_path``,
          ``table_max_time``, ``device``.

    Returns:
        The initialised model object.

    Raises:
        SystemExit: With status 1 (after logging an error) when
            ``model_name`` is not recognised, or when no model was produced,
            e.g. an unrecognised ``layout_model_name``.
    """
    # ``sys.exit`` is always available; the bare ``exit`` builtin is injected
    # by the ``site`` module for interactive use and may be missing
    # (``python -S``, frozen or embedded interpreters).
    import sys

    atom_model = None
    if model_name == AtomicModel.Layout:
        layout_model_name = kwargs.get('layout_model_name')
        if layout_model_name == MODEL_NAME.LAYOUTLMv3:
            atom_model = layout_model_init(
                kwargs.get('layout_weights'),
                kwargs.get('layout_config_file'),
                kwargs.get('device')
            )
        elif layout_model_name == MODEL_NAME.DocLayout_YOLO:
            atom_model = doclayout_yolo_model_init(
                kwargs.get('doclayout_yolo_weights'),
                kwargs.get('device')
            )
        # Any other layout name leaves ``atom_model`` as None and is reported
        # by the final check below.
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get('mfd_weights'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get('mfr_weight_dir'),
            kwargs.get('mfr_cfg_path'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.OCR:
        atom_model = ocr_model_init(
            kwargs.get('ocr_show_log'),
            kwargs.get('det_db_box_thresh'),
            kwargs.get('lang')
        )
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get('table_model_name'),
            kwargs.get('table_model_path'),
            kwargs.get('table_max_time'),
            kwargs.get('device')
        )
    else:
        logger.error('model name not allow')
        sys.exit(1)

    if atom_model is None:
        logger.error('model init failed')
        sys.exit(1)

    return atom_model