doc_analyze_by_custom_model.py 12.2 KB
Newer Older
icecraft's avatar
icecraft committed
1
2
import concurrent.futures as fut
import multiprocessing as mp
3
import os
4
import time
icecraft's avatar
icecraft committed
5

icecraft's avatar
icecraft committed
6
import numpy as np
icecraft's avatar
icecraft committed
7
8
import torch

9
10
11
12
os.environ['FLAGS_npu_jit_compile'] = '0'  # disable Paddle's JIT compilation on NPU
os.environ['FLAGS_use_stride_kernel'] = '0'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # let unsupported MPS ops fall back to CPU
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # suppress the albumentations update check
icecraft's avatar
icecraft committed
13

14

icecraft's avatar
icecraft committed
15
16
from loguru import logger

17
from magic_pdf.model.sub_modules.model_utils import get_vram
18
from magic_pdf.config.enums import SupportedPdfParseMethod
19
20
import magic_pdf.model as model_config
from magic_pdf.data.dataset import Dataset
21
from magic_pdf.libs.clean_memory import clean_memory
22
23
24
25
from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                          get_layout_config,
                                          get_local_models_dir,
                                          get_table_recog_config)
26
from magic_pdf.model.model_list import MODEL
icecraft's avatar
icecraft committed
27

icecraft's avatar
icecraft committed
28
# from magic_pdf.operators.models import InferenceResult
赵小蒙's avatar
赵小蒙 committed
29

30
31
32
33
34
35
36
37
38
class ModelSingleton:
    """Process-wide singleton that caches initialised models per option set.

    The cache key is the full tuple of model options, so every distinct
    combination of (ocr, show_log, lang, layout_model, formula_enable,
    table_enable) gets its own model instance, built at most once.
    """

    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        # Lazily create the one shared instance on first construction.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        ocr: bool,
        show_log: bool,
        lang=None,
        layout_model=None,
        formula_enable=None,
        table_enable=None,
    ):
        """Return the cached model for this option combination, building it on first use."""
        cache_key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
        try:
            return self._models[cache_key]
        except KeyError:
            created = custom_model_init(
                ocr=ocr,
                show_log=show_log,
                lang=lang,
                layout_model=layout_model,
                formula_enable=formula_enable,
                table_enable=table_enable,
            )
            self._models[cache_key] = created
            return created


61
62
63
64
65
66
67
68
def custom_model_init(
    ocr: bool = False,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    """Build and return the analysis model selected by the global model config.

    ``model_config.__model_mode__`` chooses the backend ('lite' -> Paddle,
    'full' -> PEK); any other mode, or ``__use_inside_model__`` being false,
    logs an error and terminates the process via ``exit(1)``.
    """
    model = None
    if model_config.__model_mode__ == 'lite':
        logger.warning(
            'The Lite mode is provided for developers to conduct testing only, and the output quality is '
            'not guaranteed to be reliable.'
        )
        model = MODEL.Paddle
    elif model_config.__model_mode__ == 'full':
        model = MODEL.PEK

    # Guard clause: inside models must be explicitly enabled.
    if not model_config.__use_inside_model__:
        logger.error('use_inside_model is False, not allow to use inside model')
        exit(1)

    model_init_start = time.time()

    if model == MODEL.Paddle:
        from magic_pdf.model.pp_structure_v2 import CustomPaddleModel

        custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
    elif model == MODEL.PEK:
        from magic_pdf.model.pdf_extract_kit import CustomPEKModel

        # Read model directory and device from the configuration file.
        local_models_dir = get_local_models_dir()
        device = get_device()

        # Caller-supplied overrides win over the configured defaults.
        layout_config = get_layout_config()
        if layout_model is not None:
            layout_config['model'] = layout_model

        formula_config = get_formula_config()
        if formula_enable is not None:
            formula_config['enable'] = formula_enable

        table_config = get_table_recog_config()
        if table_enable is not None:
            table_config['enable'] = table_enable

        custom_model = CustomPEKModel(
            ocr=ocr,
            show_log=show_log,
            models_dir=local_models_dir,
            device=device,
            table_config=table_config,
            layout_config=layout_config,
            formula_config=formula_config,
            lang=lang,
        )
    else:
        logger.error('Not allow model_name!')
        exit(1)

    model_init_cost = time.time() - model_init_start
    logger.info(f'model init cost: {model_init_cost}')

    return custom_model

127
128
129
130
131
132
133
134
135
136
def doc_analyze(
    dataset: Dataset,
    ocr: bool = False,
    show_log: bool = False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    """Run layout/OCR inference over the pages of a single dataset.

    Pages outside ``[start_page_id, end_page_id]`` are not analysed; they
    still appear in the result with an empty ``layout_dets`` list and zero
    width/height.

    Args:
        dataset: source document (must support ``len()`` and ``get_page()``).
        ocr: force the OCR pipeline for every analysed page.
        show_log: forwarded to model initialisation for verbose logging.
        start_page_id: first page index (inclusive) to analyse.
        end_page_id: last page index (inclusive); ``None`` or a negative
            value means "through the last page".
        lang: language hint forwarded to the model.
        layout_model, formula_enable, table_enable: optional overrides for
            the configured layout model / formula / table recognition.

    Returns:
        ``InferenceResult`` wrapping one page dict per page of *dataset*.
    """
    end_page_id = (
        end_page_id
        if end_page_id is not None and end_page_id >= 0
        else len(dataset) - 1
    )

    # Images are processed in batches of at most this many pages.
    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))

    # Collect only the images of the pages that were actually requested.
    images = []
    page_wh_list = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))

    # BUGFIX: iterate over the collected images, not over the whole dataset.
    # When start/end page ids restrict the range, ``images`` is shorter than
    # ``dataset`` and indexing it by dataset length raised IndexError.
    images_with_extra_info = [(image, ocr, dataset._lang) for image in images]

    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
        batch_size = MIN_BATCH_INFERENCE_SIZE
        batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
    else:
        batch_images = [images_with_extra_info]

    results = []
    for sn, batch_image in enumerate(batch_images):
        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
        results.extend(result)

    # Reassemble per-page output in page order; skipped pages get an empty
    # detection list and zero dimensions.
    model_json = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            result = results.pop(0)
            page_width, page_height = page_wh_list.pop(0)
        else:
            result = []
            page_height = 0
            page_width = 0

        page_info = {'page_no': index, 'width': page_width, 'height': page_height}
        page_dict = {'layout_dets': result, 'page_info': page_info}
        model_json.append(page_dict)

    # NOTE(review): imported here rather than at module top — presumably to
    # avoid a circular import; confirm before moving.
    from magic_pdf.operators.models import InferenceResult
    return InferenceResult(model_json, dataset)

def batch_doc_analyze(
    datasets: list[Dataset],
    parse_method: str,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    """Run inference over every page of several datasets in fixed-size batches.

    Args:
        datasets: documents to analyse; every page of every dataset is used.
        parse_method: 'auto' classifies each dataset to decide between OCR
            and text extraction; 'ocr' forces OCR; anything else disables it.
        show_log: forwarded to model initialisation for verbose logging.
        lang: language hint; ``None`` or 'auto' falls back to each dataset's
            own language.
        layout_model, formula_enable, table_enable: optional overrides for
            the configured layout model / formula / table recognition.

    Returns:
        One ``InferenceResult`` per input dataset, in the same order.
    """
    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
    batch_size = MIN_BATCH_INFERENCE_SIZE
    page_wh_list = []
    images_with_extra_info = []

    for dataset in datasets:
        page_count = len(dataset)
        if page_count == 0:
            # Nothing to collect; also avoids classifying an empty document.
            continue

        # Hoisted out of the per-page loop: the language and (for 'auto')
        # the OCR/text classification are constant for a whole dataset, and
        # dataset.classify() was previously re-evaluated for every page.
        if lang is None or lang == 'auto':
            _lang = dataset._lang
        else:
            _lang = lang
        if parse_method == 'auto':
            use_ocr = dataset.classify() == SupportedPdfParseMethod.OCR
        else:
            use_ocr = parse_method == 'ocr'

        for index in range(page_count):
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            page_wh_list.append((img_dict['width'], img_dict['height']))
            images_with_extra_info.append((img_dict['img'], use_ocr, _lang))

    batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
    results = []
    for sn, batch_image in enumerate(batch_images):
        # ocr=True is forwarded to model initialisation; the effective
        # per-image OCR flag travels inside each images_with_extra_info tuple.
        _, result = may_batch_image_analyze(batch_image, sn, True, show_log, lang, layout_model, formula_enable, table_enable)
        results.extend(result)

    infer_results = []
    from magic_pdf.operators.models import InferenceResult
    for dataset in datasets:
        model_json = []
        for i in range(len(dataset)):
            result = results.pop(0)
            page_width, page_height = page_wh_list.pop(0)
            page_info = {'page_no': i, 'width': page_width, 'height': page_height}
            page_dict = {'layout_dets': result, 'page_info': page_info}
            model_json.append(page_dict)
        infer_results.append(InferenceResult(model_json, dataset))
    return infer_results


def may_batch_image_analyze(
        images_with_extra_info: list[tuple[np.ndarray, bool, str]],
        idx: int,
        ocr: bool,
        show_log: bool = False,
        lang=None,
        layout_model=None,
        formula_enable=None,
        table_enable=None):
    """Analyze one batch of page images, using batched inference when the device allows.

    On npu/cuda devices the batch ratio is scaled by available VRAM (or the
    ``VIRTUAL_VRAM_SIZE`` override); mps also batches; everything else falls
    back to one-image-at-a-time inference.

    Args:
        images_with_extra_info: (image, per_image_ocr_flag, lang) triples.
        idx: batch sequence number, echoed back so the caller can match
            results to batches.
        ocr: forwarded to model initialisation.
        show_log, lang, layout_model, formula_enable, table_enable:
            forwarded to model initialisation / the batch analyzer.

    Returns:
        ``(idx, results)`` with one result entry per input image.
    """
    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)

    # Disable Paddle's signal handler so it does not hijack the host
    # process's signal handling.
    import paddle
    paddle.disable_signal_handler()
    from magic_pdf.model.batch_analyze import BatchAnalyze

    model_manager = ModelSingleton()
    custom_model = model_manager.get_model(
        ocr, show_log, lang, layout_model, formula_enable, table_enable
    )

    images = [image for image, _, _ in images_with_extra_info]
    batch_analyze = False
    batch_ratio = 1
    device = get_device()
    device_name = str(device)

    if device_name.startswith('npu'):
        import torch_npu
        if torch_npu.npu.is_available():
            torch.npu.set_compile_mode(jit_compile=False)

    if device_name.startswith('npu') or device_name.startswith('cuda'):
        # Only probe the hardware when no override is set: the previous
        # os.getenv(key, round(get_vram(device))) evaluated its default
        # eagerly, querying VRAM even when VIRTUAL_VRAM_SIZE was provided.
        vram_override = os.getenv('VIRTUAL_VRAM_SIZE')
        if vram_override is not None:
            gpu_memory = int(vram_override)
        else:
            gpu_memory = int(round(get_vram(device)))
        # int() never yields None, so the old `if gpu_memory is not None`
        # guard was dead code and has been removed.
        if gpu_memory >= 16:
            batch_ratio = 16
        elif gpu_memory >= 12:
            batch_ratio = 8
        elif gpu_memory >= 8:
            batch_ratio = 4
        elif gpu_memory >= 6:
            batch_ratio = 2
        else:
            batch_ratio = 1
        logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
        batch_analyze = True
    elif device_name.startswith('mps'):
        batch_analyze = True

    doc_analyze_start = time.time()

    if batch_analyze:
        batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
        results = batch_model(images_with_extra_info)
    else:
        # Single-image fallback for devices without batch support.
        results = []
        for img_idx, img in enumerate(images):
            inference_start = time.time()
            result = custom_model(img)
            logger.info(f'-----image index : {img_idx}, image inference total time: {round(time.time() - inference_start, 2)}-----')
            results.append(result)

    gc_start = time.time()
    clean_memory(device)  # device already resolved above; no need to re-read config
    gc_time = round(time.time() - gc_start, 2)
    logger.info(f'gc time: {gc_time}')

    doc_analyze_time = round(time.time() - doc_analyze_start, 2)
    # Guard against ZeroDivisionError when the elapsed time rounds to 0
    # for very small batches; also log the same value used for the speed.
    doc_analyze_speed = round(len(images) / doc_analyze_time, 2) if doc_analyze_time else 0.0
    logger.info(
        f'doc analyze time: {doc_analyze_time},'
        f' speed: {doc_analyze_speed} pages/second'
    )
    return (idx, results)