# doc_analyze_by_custom_model.py
import os
import sys
import time

import torch
os.environ['FLAGS_npu_jit_compile'] = '0'  # 关闭paddle的jit编译
os.environ['FLAGS_use_stride_kernel'] = '0'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # 让mps可以fallback
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
9
10
# 关闭paddle的信号处理
import paddle
11
12
paddle.disable_signal_handler()

from loguru import logger

from magic_pdf.model.batch_analyze import BatchAnalyze
from magic_pdf.model.sub_modules.model_utils import get_vram

try:
    import torchtext

    # Compare numeric version components instead of raw strings: a plain
    # string comparison would wrongly report e.g. '0.9.0' >= '0.18.0' and
    # then crash calling a function that does not exist on old releases.
    _torchtext_ver = tuple(
        int(part) for part in torchtext.__version__.split('+')[0].split('.')[:2]
    )
    if _torchtext_ver >= (0, 18):
        torchtext.disable_torchtext_deprecation_warning()
except (ImportError, ValueError):
    # torchtext absent, or an unparseable version string: nothing to disable.
    pass

import magic_pdf.model as model_config
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                          get_layout_config,
                                          get_local_models_dir,
                                          get_table_recog_config)
from magic_pdf.model.model_list import MODEL
from magic_pdf.operators.models import InferenceResult


class ModelSingleton:
    """Process-wide cache of initialized analysis models.

    Only one instance of this class ever exists; models are cached per
    configuration tuple so repeated requests with identical settings
    reuse the already-loaded model instead of re-initializing it.
    """

    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        # Create the single shared instance lazily on first construction.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        ocr: bool,
        show_log: bool,
        lang=None,
        layout_model=None,
        formula_enable=None,
        table_enable=None,
    ):
        """Return the cached model for this configuration, creating it on demand."""
        cache_key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
        try:
            return self._models[cache_key]
        except KeyError:
            # First request with this configuration: build and memoize.
            model = custom_model_init(
                ocr=ocr,
                show_log=show_log,
                lang=lang,
                layout_model=layout_model,
                formula_enable=formula_enable,
                table_enable=table_enable,
            )
            self._models[cache_key] = model
            return model


def custom_model_init(
    ocr: bool = False,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    """Initialize and return the backend document-analysis model.

    The backend is selected from ``model_config.__model_mode__``:
    ``'lite'`` maps to the Paddle pipeline and ``'full'`` to the
    PDF-Extract-Kit (PEK) pipeline.  The process exits when no valid
    backend is configured or inside models are disabled.

    Args:
        ocr: enable OCR during analysis.
        show_log: enable verbose model logging.
        lang: optional language hint passed to the model.
        layout_model: optional override for the configured layout model.
        formula_enable: optional override for formula recognition.
        table_enable: optional override for table recognition.

    Returns:
        The initialized model instance (``CustomPaddleModel`` or
        ``CustomPEKModel``).
    """
    model = None
    if model_config.__model_mode__ == 'lite':
        logger.warning(
            'The Lite mode is provided for developers to conduct testing only, and the output quality is '
            'not guaranteed to be reliable.'
        )
        model = MODEL.Paddle
    elif model_config.__model_mode__ == 'full':
        model = MODEL.PEK

    if model_config.__use_inside_model__:
        model_init_start = time.time()
        if model == MODEL.Paddle:
            from magic_pdf.model.pp_structure_v2 import CustomPaddleModel

            custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
        elif model == MODEL.PEK:
            from magic_pdf.model.pdf_extract_kit import CustomPEKModel

            # Read the model directory and target device from the config file.
            local_models_dir = get_local_models_dir()
            device = get_device()

            layout_config = get_layout_config()
            if layout_model is not None:
                layout_config['model'] = layout_model

            formula_config = get_formula_config()
            if formula_enable is not None:
                formula_config['enable'] = formula_enable

            table_config = get_table_recog_config()
            if table_enable is not None:
                table_config['enable'] = table_enable

            model_input = {
                'ocr': ocr,
                'show_log': show_log,
                'models_dir': local_models_dir,
                'device': device,
                'table_config': table_config,
                'layout_config': layout_config,
                'formula_config': formula_config,
                'lang': lang,
            }

            custom_model = CustomPEKModel(**model_input)
        else:
            logger.error('Not allow model_name!')
            # sys.exit instead of the site-injected exit(): exit() is not
            # guaranteed to exist when Python runs without the site module.
            sys.exit(1)
        model_init_cost = time.time() - model_init_start
        logger.info(f'model init cost: {model_init_cost}')
    else:
        logger.error('use_inside_model is False, not allow to use inside model')
        sys.exit(1)

    return custom_model


def doc_analyze(
    dataset: Dataset,
    ocr: bool = False,
    show_log: bool = False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
) -> InferenceResult:
    """Run model inference over the pages of *dataset* and collect results.

    Pages inside ``[start_page_id, end_page_id]`` are analyzed (batched when
    enough GPU/NPU memory is available, otherwise one page at a time); pages
    outside the range get an empty ``layout_dets`` list so the output always
    has one entry per page.

    Args:
        dataset: the source document pages.
        ocr: enable OCR mode in the model.
        show_log: enable verbose model logging.
        start_page_id: first page index to analyze (inclusive).
        end_page_id: last page index (inclusive); ``None`` or a negative
            value means "through the last page".
        lang: optional language hint.
        layout_model: optional layout-model override.
        formula_enable: optional formula-recognition override.
        table_enable: optional table-recognition override.

    Returns:
        An ``InferenceResult`` wrapping the per-page model output and the
        dataset.
    """
    # Normalize the inclusive end page; None or negative selects the last page.
    end_page_id = (
        end_page_id
        if end_page_id is not None and end_page_id >= 0
        else len(dataset) - 1
    )

    model_manager = ModelSingleton()
    custom_model = model_manager.get_model(
        ocr, show_log, lang, layout_model, formula_enable, table_enable
    )

    batch_analyze = False
    batch_model = None
    device = get_device()

    npu_support = False
    if str(device).startswith("npu"):
        import torch_npu
        if torch_npu.npu.is_available():
            npu_support = True

    if (torch.cuda.is_available() and device != 'cpu') or npu_support:
        # Probe actual VRAM only when the override env var is absent; the
        # previous os.getenv(..., round(get_vram(device))) form evaluated
        # get_vram() unconditionally because Python evaluates default
        # arguments eagerly.
        vram_override = os.getenv("VIRTUAL_VRAM_SIZE")
        gpu_memory = (
            int(vram_override)
            if vram_override is not None
            else round(get_vram(device))
        )
        if gpu_memory >= 8:
            # Scale the inference batch size with available VRAM (GB).
            if gpu_memory >= 40:
                batch_ratio = 32
            elif gpu_memory >= 20:
                batch_ratio = 16
            elif gpu_memory >= 16:
                batch_ratio = 8
            elif gpu_memory >= 10:
                batch_ratio = 4
            else:
                batch_ratio = 2

            logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
            batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
            batch_analyze = True

    model_json = []
    doc_analyze_start = time.time()

    if batch_analyze:
        # Batch path: collect images for the selected page range, run one
        # batched inference, then stitch results back into per-page dicts.
        images = []
        for index in range(len(dataset)):
            if start_page_id <= index <= end_page_id:
                page_data = dataset.get_page(index)
                img_dict = page_data.get_image()
                images.append(img_dict['img'])
        analyze_result = batch_model(images)

        for index in range(len(dataset)):
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            page_width = img_dict['width']
            page_height = img_dict['height']
            if start_page_id <= index <= end_page_id:
                # Results come back in page order, so pop from the front.
                result = analyze_result.pop(0)
            else:
                result = []

            page_info = {'page_no': index, 'height': page_height, 'width': page_width}
            page_dict = {'layout_dets': result, 'page_info': page_info}
            model_json.append(page_dict)
    else:
        # Single-page path: run the model one page at a time.
        for index in range(len(dataset)):
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            img = img_dict['img']
            page_width = img_dict['width']
            page_height = img_dict['height']
            if start_page_id <= index <= end_page_id:
                page_start = time.time()
                result = custom_model(img)
                logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
            else:
                result = []

            page_info = {'page_no': index, 'height': page_height, 'width': page_width}
            page_dict = {'layout_dets': result, 'page_info': page_info}
            model_json.append(page_dict)

    gc_start = time.time()
    clean_memory(get_device())
    gc_time = round(time.time() - gc_start, 2)
    logger.info(f'gc time: {gc_time}')

    doc_analyze_time = round(time.time() - doc_analyze_start, 2)
    # Guard the denominator: on very small documents the rounded elapsed
    # time can be 0.0, which previously raised ZeroDivisionError.
    doc_analyze_speed = round(
        (end_page_id + 1 - start_page_id) / max(doc_analyze_time, 0.01), 2
    )
    logger.info(
        f'doc analyze time: {doc_analyze_time},'
        f' speed: {doc_analyze_speed} pages/second'
    )

    return InferenceResult(model_json, dataset)