import os
import time
import torch
import numpy as np
import multiprocessing as mp
import concurrent.futures as fut
os.environ['FLAGS_npu_jit_compile'] = '0'  # disable Paddle's JIT compilation
os.environ['FLAGS_use_stride_kernel'] = '0'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # let MPS ops fall back to CPU when unsupported
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates

from loguru import logger

from magic_pdf.model.sub_modules.model_utils import get_vram

try:
    import torchtext
    # feature-check instead of a lexicographic version-string comparison
    # (e.g. '0.9.0' >= '0.18.0' is True as strings, which would be wrong)
    if hasattr(torchtext, 'disable_torchtext_deprecation_warning'):
        torchtext.disable_torchtext_deprecation_warning()
except ImportError:
    pass

import magic_pdf.model as model_config
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                          get_layout_config,
                                          get_local_models_dir,
                                          get_table_recog_config)
from magic_pdf.model.model_list import MODEL
# from magic_pdf.operators.models import InferenceResult

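# below this many pages, spinning up parallel worker processes isn't worth the startup cost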
MIN_BATCH_INFERENCE_SIZE = 100

class ModelSingleton:
    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        ocr: bool,
        show_log: bool,
        lang=None,
        layout_model=None,
        formula_enable=None,
        table_enable=None,
    ):
        key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
        if key not in self._models:
            self._models[key] = custom_model_init(
                ocr=ocr,
                show_log=show_log,
                lang=lang,
                layout_model=layout_model,
                formula_enable=formula_enable,
                table_enable=table_enable,
            )
        return self._models[key]
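
# Usage sketch (illustrative; 'en' is an example language code, not a default):
# calls with the same (ocr, show_log, lang, ...) key return the same cached model.
#
#   manager = ModelSingleton()
#   m1 = manager.get_model(ocr=True, show_log=False, lang='en')
#   m2 = manager.get_model(ocr=True, show_log=False, lang='en')
#   assert m1 is m2  # same key -> same instance, model weights load only once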


def custom_model_init(
    ocr: bool = False,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    model = None
    if model_config.__model_mode__ == 'lite':
        logger.warning(
            'The Lite mode is provided for developers to conduct testing only, and the output quality is '
            'not guaranteed to be reliable.'
        )
        model = MODEL.Paddle
    elif model_config.__model_mode__ == 'full':
        model = MODEL.PEK

    if model_config.__use_inside_model__:
        model_init_start = time.time()
        if model == MODEL.Paddle:
            from magic_pdf.model.pp_structure_v2 import CustomPaddleModel

            custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
        elif model == MODEL.PEK:
            from magic_pdf.model.pdf_extract_kit import CustomPEKModel

            # read the models directory and device from the config file
            local_models_dir = get_local_models_dir()
            device = get_device()

            layout_config = get_layout_config()
            if layout_model is not None:
                layout_config['model'] = layout_model

            formula_config = get_formula_config()
            if formula_enable is not None:
                formula_config['enable'] = formula_enable

            table_config = get_table_recog_config()
            if table_enable is not None:
                table_config['enable'] = table_enable

            model_input = {
                'ocr': ocr,
                'show_log': show_log,
                'models_dir': local_models_dir,
                'device': device,
                'table_config': table_config,
                'layout_config': layout_config,
                'formula_config': formula_config,
                'lang': lang,
            }

            custom_model = CustomPEKModel(**model_input)
        else:
            logger.error('Unsupported model name!')
            exit(1)
        model_init_cost = time.time() - model_init_start
        logger.info(f'model init cost: {model_init_cost}')
    else:
        logger.error('use_inside_model is False, so the built-in model cannot be used')
        exit(1)

    return custom_model

def doc_analyze(
    dataset: Dataset,
    ocr: bool = False,
    show_log: bool = False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
    one_shot: bool = True,
):
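    # None or a negative end_page_id means "process through the last page"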
    end_page_id = (
        end_page_id
        if end_page_id is not None and end_page_id >= 0
        else len(dataset) - 1
    )
    parallel_count = None
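    # MINERU_PARALLEL_INFERENCE_COUNT (e.g. =4) overrides the number of worker processes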
    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])

    images = []
    page_wh_list = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))

    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
        if parallel_count is None:
            parallel_count = 2  # default; ideally this would be derived from available GPU memory
        # split images into parallel_count batches
        if parallel_count > 1:
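            # ceil division so every image lands in exactly one batch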
            batch_size = (len(images) + parallel_count - 1) // parallel_count
            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
        else:
            batch_images = [images]
        parallel_count = len(batch_images)  # adjust to the actual number of batches
        # Alternative kept for reference: the same fan-out via concurrent.futures.
        """
        result_history = {}
        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
            futures = [
                executor.submit(may_batch_image_analyze, batch_image, sn, ocr,
                                show_log, lang, layout_model, formula_enable, table_enable)
                for sn, batch_image in enumerate(batch_images)
            ]
            for future in fut.as_completed(futures):
                sn, result = future.result()
                result_history[sn] = result

        results = []
        for key in sorted(result_history.keys()):
            results.extend(result_history[key])
        """
        results = []
        with mp.Pool(processes=parallel_count) as pool:
            # starmap preserves input order, so batch results come back in sequence
            mapped_results = pool.starmap(
                may_batch_image_analyze,
                [
                    (batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
                    for sn, batch_image in enumerate(batch_images)
                ],
            )
        for sn, result in mapped_results:
            results.extend(result)

    else:
        _, results = may_batch_image_analyze(
            images, 0, ocr, show_log, lang, layout_model, formula_enable, table_enable)

    model_json = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            result = results.pop(0)
            page_width, page_height = page_wh_list.pop(0)
        else:
            result = []
            page_height = 0
            page_width = 0

        page_info = {'page_no': index, 'width': page_width, 'height': page_height}
        page_dict = {'layout_dets': result, 'page_info': page_info}
        model_json.append(page_dict)

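    # imported here instead of at module top (see the commented-out import above),
    # presumably to avoid a circular import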
    from magic_pdf.operators.models import InferenceResult
    return InferenceResult(model_json, dataset)
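
# Usage sketch (illustrative; PymuDocDataset lives in magic_pdf.data.dataset, but
# check your installed version for the exact constructor signature):
#
#   from magic_pdf.data.dataset import PymuDocDataset
#   with open('demo.pdf', 'rb') as f:
#       ds = PymuDocDataset(f.read())
#   inference_result = doc_analyze(ds, ocr=True)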

def batch_doc_analyze(
    datasets: list[Dataset],
    ocr: bool = False,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
    one_shot: bool = True,
):
    parallel_count = None
    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
    images = []
    page_wh_list = []
    for dataset in datasets:
        for index in range(len(dataset)):
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))
    
    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
        if parallel_count is None:
            parallel_count = 2  # default; ideally this would be derived from available GPU memory
        # split images into parallel_count batches
        if parallel_count > 1:
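            # ceil division so every image lands in exactly one batch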
            batch_size = (len(images) + parallel_count - 1) // parallel_count
            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
        else:
            batch_images = [images]
        parallel_count = len(batch_images)  # adjust to the actual number of batches
        # (the concurrent.futures alternative is kept for reference in doc_analyze above)
        results = []
        with mp.Pool(processes=parallel_count) as pool:
            mapped_results = pool.starmap(
                may_batch_image_analyze,
                [
                    (batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
                    for sn, batch_image in enumerate(batch_images)
                ],
            )
        for sn, result in mapped_results:
            results.extend(result)
    else:
        _, results = may_batch_image_analyze(
            images, 0, ocr, show_log, lang, layout_model, formula_enable, table_enable)
    infer_results = []

    from magic_pdf.operators.models import InferenceResult
    for index in range(len(datasets)):
        dataset = datasets[index]
        model_json = []
        for i in range(len(dataset)):
            result = results.pop(0)
            page_width, page_height = page_wh_list.pop(0)
            page_info = {'page_no': i, 'width': page_width, 'height': page_height}
            page_dict = {'layout_dets': result, 'page_info': page_info}
            model_json.append(page_dict)
        infer_results.append(InferenceResult(model_json, dataset))
    return infer_results
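
# Usage sketch (illustrative): several PDFs analyzed in one call so their pages
# share the batching and parallelism above; one InferenceResult per dataset.
#
#   datasets = [PymuDocDataset(open(p, 'rb').read()) for p in ('a.pdf', 'b.pdf')]
#   infer_results = batch_doc_analyze(datasets, ocr=True)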


def may_batch_image_analyze(
        images: list[np.ndarray],
        idx: int,
        ocr: bool = False,
        show_log: bool = False,
        lang=None,
        layout_model=None,
        formula_enable=None,
        table_enable=None):
    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
    # disable Paddle's signal handling
    import paddle
    paddle.disable_signal_handler()
    from magic_pdf.model.batch_analyze import BatchAnalyze
    
    model_manager = ModelSingleton()
    custom_model = model_manager.get_model(
        ocr, show_log, lang, layout_model, formula_enable, table_enable
    )

    batch_analyze = False
    batch_ratio = 1
    device = get_device()

    npu_support = False
    if str(device).startswith("npu"):
        import torch_npu
        if torch_npu.npu.is_available():
            npu_support = True
            torch.npu.set_compile_mode(jit_compile=False)

    if (torch.cuda.is_available() and device != 'cpu') or npu_support:
        gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
        if gpu_memory >= 8:  # int(...) above never yields None, so no None check needed
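            # scale the inference batch multiplier with available VRAM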
            if gpu_memory >= 20:
                batch_ratio = 16
            elif gpu_memory >= 15:
                batch_ratio = 8
            elif gpu_memory >= 10:
                batch_ratio = 4
            else:
                batch_ratio = 2

            logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
            batch_analyze = True
    doc_analyze_start = time.time()

    if batch_analyze:
        """# batch analyze
337
        images = []
338
        page_wh_list = []
339
340
341
342
343
        for index in range(len(dataset)):
            if start_page_id <= index <= end_page_id:
                page_data = dataset.get_page(index)
                img_dict = page_data.get_image()
                images.append(img_dict['img'])
344
                page_wh_list.append((img_dict['width'], img_dict['height']))
icecraft's avatar
icecraft committed
345
        """
        batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
        results = batch_model(images)
        """
349
350
351
        for index in range(len(dataset)):
            if start_page_id <= index <= end_page_id:
                result = analyze_result.pop(0)
352
                page_width, page_height = page_wh_list.pop(0)
353
354
            else:
                result = []
355
356
                page_height = 0
                page_width = 0
357

358
            page_info = {'page_no': index, 'width': page_width, 'height': page_height}
359
360
            page_dict = {'layout_dets': result, 'page_info': page_info}
            model_json.append(page_dict)
icecraft's avatar
icecraft committed
361
        """
    else:
        # single analyze
        """
365
366
367
368
369
370
371
372
373
374
375
376
377
        for index in range(len(dataset)):
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            img = img_dict['img']
            page_width = img_dict['width']
            page_height = img_dict['height']
            if start_page_id <= index <= end_page_id:
                page_start = time.time()
                result = custom_model(img)
                logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
            else:
                result = []

378
            page_info = {'page_no': index, 'width': page_width, 'height': page_height}
379
380
            page_dict = {'layout_dets': result, 'page_info': page_info}
            model_json.append(page_dict)
icecraft's avatar
icecraft committed
381
382
383
384
385
386
387
        """
        results = []
        for img_idx, img in enumerate(images):
            inference_start = time.time()
            result = custom_model(img)
            logger.info(f'-----image index : {img_idx}, image inference total time: {round(time.time() - inference_start, 2)}-----')
            results.append(result)

    gc_start = time.time()
    clean_memory(get_device())
    gc_time = round(time.time() - gc_start, 2)
    logger.info(f'gc time: {gc_time}')

    doc_analyze_time = round(time.time() - doc_analyze_start, 2)
    doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
    logger.info(
        f'doc analyze time: {doc_analyze_time},'
        f' speed: {doc_analyze_speed} pages/second'
    )
    return (idx, results)