import concurrent.futures as fut
import multiprocessing as mp
import os
import time

import numpy as np
import torch

os.environ['FLAGS_npu_jit_compile'] = '0'  # disable Paddle's JIT compilation
os.environ['FLAGS_use_stride_kernel'] = '0'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # let MPS ops fall back to CPU when unsupported
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates


from loguru import logger

from magic_pdf.model.sub_modules.model_utils import get_vram

try:
    import torchtext
    # torchtext >= 0.18 emits a deprecation warning on import; silence it when the
    # hook exists (hasattr avoids the pitfalls of comparing version strings lexicographically).
    if hasattr(torchtext, 'disable_torchtext_deprecation_warning'):
        torchtext.disable_torchtext_deprecation_warning()
except ImportError:
    pass

import magic_pdf.model as model_config
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                          get_layout_config,
                                          get_local_models_dir,
                                          get_table_recog_config)
from magic_pdf.model.model_list import MODEL

# from magic_pdf.operators.models import InferenceResult  # imported lazily below, presumably to avoid a circular import

MIN_BATCH_INFERENCE_SIZE = 100  # minimum page count before batch/parallel inference is used

class ModelSingleton:
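    """Process-wide singleton that caches initialized models, keyed by configuration."""
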
    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        ocr: bool,
        show_log: bool,
        lang=None,
        layout_model=None,
        formula_enable=None,
        table_enable=None,
    ):
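        """Return the cached model for this configuration, initializing it on first use."""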
        key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
        if key not in self._models:
            self._models[key] = custom_model_init(
                ocr=ocr,
                show_log=show_log,
                lang=lang,
                layout_model=layout_model,
                formula_enable=formula_enable,
                table_enable=table_enable,
            )
        return self._models[key]


def custom_model_init(
    ocr: bool = False,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
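    """Build a model for the configured mode: 'lite' selects Paddle, 'full' selects PEK."""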
    model = None
    if model_config.__model_mode__ == 'lite':
        logger.warning(
            'The Lite mode is provided for developers to conduct testing only, and the output quality is '
            'not guaranteed to be reliable.'
        )
        model = MODEL.Paddle
    elif model_config.__model_mode__ == 'full':
        model = MODEL.PEK

    if model_config.__use_inside_model__:
        model_init_start = time.time()
        if model == MODEL.Paddle:
            from magic_pdf.model.pp_structure_v2 import CustomPaddleModel

            custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
        elif model == MODEL.PEK:
            from magic_pdf.model.pdf_extract_kit import CustomPEKModel

            # read the models directory and device from the config file
            local_models_dir = get_local_models_dir()
            device = get_device()

            layout_config = get_layout_config()
            if layout_model is not None:
                layout_config['model'] = layout_model

            formula_config = get_formula_config()
            if formula_enable is not None:
                formula_config['enable'] = formula_enable

            table_config = get_table_recog_config()
            if table_enable is not None:
                table_config['enable'] = table_enable

            model_input = {
                'ocr': ocr,
                'show_log': show_log,
                'models_dir': local_models_dir,
                'device': device,
                'table_config': table_config,
                'layout_config': layout_config,
                'formula_config': formula_config,
                'lang': lang,
            }

            custom_model = CustomPEKModel(**model_input)
        else:
            logger.error('Unsupported model name!')
            exit(1)
        model_init_cost = time.time() - model_init_start
        logger.info(f'model init cost: {round(model_init_cost, 2)}s')
    else:
        logger.error('use_inside_model is False, the inside model is not allowed to be used')
        exit(1)

    return custom_model

def doc_analyze(
    dataset: Dataset,
    ocr: bool = False,
    show_log: bool = False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
    one_shot: bool = True,
):
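    """Analyze the pages of ``dataset`` and return an InferenceResult.

    With one_shot set and at least MIN_BATCH_INFERENCE_SIZE pages in range,
    the page images are split across worker processes; otherwise they are
    analyzed in the current process.
    """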
    end_page_id = (
        end_page_id
        if end_page_id is not None and end_page_id >= 0
        else len(dataset) - 1
    )
    # allow the parallel worker count to be overridden via the environment
    parallel_count = None
    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])

    images = []
    page_wh_list = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))

    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
        if parallel_count is None:
            parallel_count = 2  # conservative default; ideally this would be derived from available GPU memory
        # split images into parallel_count batches
        if parallel_count > 1:
            batch_size = (len(images) + parallel_count - 1) // parallel_count
            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
        else:
            batch_images = [images]
        results = []
        parallel_count = len(batch_images)  # adjust to the real number of batches
        # Alternative implementation using concurrent.futures, kept for reference:
        """
        result_history = {}
        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
            futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
            for future in fut.as_completed(futures):
                sn, result = future.result()
                result_history[sn] = result

        for key in sorted(result_history.keys()):
            results.extend(result_history[key])
        """
        # analyze the batches in worker processes; starmap returns results in submission order
        with mp.Pool(processes=parallel_count) as pool:
            mapped_results = pool.starmap(
                may_batch_image_analyze,
                [
                    (batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
                    for sn, batch_image in enumerate(batch_images)
                ],
            )
        for _sn, result in mapped_results:
            results.extend(result)

    else:
        _, results = may_batch_image_analyze(
            images,
            0,
            ocr,
            show_log,
            lang, layout_model, formula_enable, table_enable)

    model_json = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            result = results.pop(0)
            page_width, page_height = page_wh_list.pop(0)
        else:
            result = []
            page_height = 0
            page_width = 0

        page_info = {'page_no': index, 'width': page_width, 'height': page_height}
        page_dict = {'layout_dets': result, 'page_info': page_info}
        model_json.append(page_dict)

    from magic_pdf.operators.models import InferenceResult  # deferred import; see the commented-out import at the top
    return InferenceResult(model_json, dataset)

def batch_doc_analyze(
    datasets: list[Dataset],
    ocr: bool = False,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
    one_shot: bool = True,
):
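    """Analyze several datasets in one pass and return one InferenceResult per dataset."""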
    # allow the parallel worker count to be overridden via the environment
    parallel_count = None
    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
    images = []
    page_wh_list = []
    for dataset in datasets:
        for index in range(len(dataset)):
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))

    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
        if parallel_count is None:
            parallel_count = 2  # conservative default; ideally this would be derived from available GPU memory
        # split images into parallel_count batches
        if parallel_count > 1:
            batch_size = (len(images) + parallel_count - 1) // parallel_count
            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
        else:
            batch_images = [images]
        results = []
        parallel_count = len(batch_images)  # adjust to the real number of batches
        # Alternative implementation using concurrent.futures, kept for reference:
        """
        result_history = {}
        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
            futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
            for future in fut.as_completed(futures):
                sn, result = future.result()
                result_history[sn] = result

        for key in sorted(result_history.keys()):
            results.extend(result_history[key])
        """
        # analyze the batches in worker processes; starmap returns results in submission order
        with mp.Pool(processes=parallel_count) as pool:
            mapped_results = pool.starmap(
                may_batch_image_analyze,
                [
                    (batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
                    for sn, batch_image in enumerate(batch_images)
                ],
            )
        for _sn, result in mapped_results:
            results.extend(result)
    else:
        _, results = may_batch_image_analyze(
            images,
            0,
            ocr,
            show_log,
            lang, layout_model, formula_enable, table_enable)
    infer_results = []

    from magic_pdf.operators.models import InferenceResult
    for index in range(len(datasets)):
        dataset = datasets[index]
        model_json = []
        for i in range(len(dataset)):
            result = results.pop(0)
            page_width, page_height = page_wh_list.pop(0)
            page_info = {'page_no': i, 'width': page_width, 'height': page_height}
            page_dict = {'layout_dets': result, 'page_info': page_info}
            model_json.append(page_dict)
        infer_results.append(InferenceResult(model_json, dataset))
    return infer_results


def may_batch_image_analyze(
        images: list[np.ndarray],
        idx: int,
        ocr: bool = False,
        show_log: bool = False,
        lang=None,
        layout_model=None,
        formula_enable=None,
        table_enable=None):
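    """Analyze a list of page images; returns (idx, results) so callers can restore batch order."""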
    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
    # disable Paddle's signal handling
    import paddle
    paddle.disable_signal_handler()
    from magic_pdf.model.batch_analyze import BatchAnalyze

    model_manager = ModelSingleton()
    custom_model = model_manager.get_model(
        ocr, show_log, lang, layout_model, formula_enable, table_enable
    )

    batch_analyze = False
    batch_ratio = 1
    device = get_device()

    npu_support = False
    if str(device).startswith('npu'):
        import torch_npu
        if torch_npu.npu.is_available():
            npu_support = True
            torch.npu.set_compile_mode(jit_compile=False)

    if (torch.cuda.is_available() and device != 'cpu') or npu_support:
        # VRAM budget in GB; VIRTUAL_VRAM_SIZE overrides the detected amount
        gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
        if gpu_memory >= 8:
            # scale the batch size with the available VRAM
            if gpu_memory >= 20:
                batch_ratio = 16
            elif gpu_memory >= 15:
                batch_ratio = 8
            elif gpu_memory >= 10:
                batch_ratio = 4
            else:
                batch_ratio = 2

            logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
            batch_analyze = True
    doc_analyze_start = time.time()

    if batch_analyze:
        # batch analysis: run all page images through the batched pipeline
        batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
        results = batch_model(images)
    else:
        # single-image analysis: run the pages through the model one at a time
        results = []
        for img_idx, img in enumerate(images):
            inference_start = time.time()
            result = custom_model(img)
            logger.info(f'-----image index : {img_idx}, image inference total time: {round(time.time() - inference_start, 2)}-----')
            results.append(result)

    gc_start = time.time()
    clean_memory(get_device())
    gc_time = round(time.time() - gc_start, 2)
    logger.info(f'gc time: {gc_time}')

    doc_analyze_time = round(time.time() - doc_analyze_start, 2)
    # guard against a zero rounded elapsed time on very fast runs
    doc_analyze_speed = round(len(images) / max(doc_analyze_time, 0.01), 2)
    logger.info(
        f'doc analyze time: {doc_analyze_time},'
        f' speed: {doc_analyze_speed} pages/second'
    )
    return (idx, results)
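

# Usage sketch (illustrative, not part of the original module): run doc_analyze on
# a PDF loaded through magic_pdf's dataset API. PymuDocDataset and its bytes-based
# constructor are assumed here from magic_pdf.data.dataset.
#
#     from magic_pdf.data.dataset import PymuDocDataset
#
#     with open('example.pdf', 'rb') as f:
#         ds = PymuDocDataset(f.read())
#     inference_result = doc_analyze(ds, ocr=True, lang='en')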