doc_analyze_by_custom_model.py 12.3 KB
Newer Older
icecraft's avatar
icecraft committed
1
2
import concurrent.futures as fut
import multiprocessing as mp
3
import os
4
import time
icecraft's avatar
icecraft committed
5

icecraft's avatar
icecraft committed
6
import numpy as np
icecraft's avatar
icecraft committed
7
8
import torch

9
10
11
12
os.environ['FLAGS_npu_jit_compile'] = '0'  # 关闭paddle的jit编译
os.environ['FLAGS_use_stride_kernel'] = '0'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # 让mps可以fallback
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
icecraft's avatar
icecraft committed
13

14

icecraft's avatar
icecraft committed
15
16
from loguru import logger

17
18
from magic_pdf.model.sub_modules.model_utils import get_vram

19
20
import magic_pdf.model as model_config
from magic_pdf.data.dataset import Dataset
21
from magic_pdf.libs.clean_memory import clean_memory
22
23
24
25
from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                          get_layout_config,
                                          get_local_models_dir,
                                          get_table_recog_config)
26
from magic_pdf.model.model_list import MODEL
icecraft's avatar
icecraft committed
27

icecraft's avatar
icecraft committed
28
# from magic_pdf.operators.models import InferenceResult
赵小蒙's avatar
赵小蒙 committed
29

30
31
32
33
34
35
36
37
38
class ModelSingleton:
    """Process-wide cache of initialized analysis models.

    A classic ``__new__``-based singleton; model instances are memoized per
    full configuration tuple so each distinct configuration pays the
    (expensive) initialization cost exactly once.
    """

    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        # Create the single shared instance lazily, on first construction.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        ocr: bool,
        show_log: bool,
        lang=None,
        layout_model=None,
        formula_enable=None,
        table_enable=None,
    ):
        """Return the cached model for this configuration, building it on first use."""
        cache_key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
        try:
            return self._models[cache_key]
        except KeyError:
            # First request for this configuration: initialize and memoize.
            model = custom_model_init(
                ocr=ocr,
                show_log=show_log,
                lang=lang,
                layout_model=layout_model,
                formula_enable=formula_enable,
                table_enable=table_enable,
            )
            self._models[cache_key] = model
            return model


61
62
63
64
65
66
67
68
def _build_pek_model_input(ocr, show_log, lang, layout_model, formula_enable, table_enable):
    """Assemble CustomPEKModel constructor kwargs: config-file values plus per-call overrides."""
    # Read model-dir and device from the config file.
    local_models_dir = get_local_models_dir()
    device = get_device()

    layout_config = get_layout_config()
    if layout_model is not None:
        layout_config['model'] = layout_model

    formula_config = get_formula_config()
    if formula_enable is not None:
        formula_config['enable'] = formula_enable

    table_config = get_table_recog_config()
    if table_enable is not None:
        table_config['enable'] = table_enable

    return {
        'ocr': ocr,
        'show_log': show_log,
        'models_dir': local_models_dir,
        'device': device,
        'table_config': table_config,
        'layout_config': layout_config,
        'formula_config': formula_config,
        'lang': lang,
    }


def custom_model_init(
    ocr: bool = False,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    """Create a document-analysis model according to ``model_config``.

    ``model_config.__model_mode__`` selects the backend ('lite' -> Paddle,
    'full' -> PEK).  Exits the process with status 1 when
    ``__use_inside_model__`` is disabled or the mode maps to no known model.

    Returns the initialized model instance.
    """
    import sys  # local import, consistent with this module's lazy-import style

    model = None
    if model_config.__model_mode__ == 'lite':
        logger.warning(
            'The Lite mode is provided for developers to conduct testing only, and the output quality is '
            'not guaranteed to be reliable.'
        )
        model = MODEL.Paddle
    elif model_config.__model_mode__ == 'full':
        model = MODEL.PEK

    # Guard clause: inside-model usage must be enabled at all.
    if not model_config.__use_inside_model__:
        logger.error('use_inside_model is False, not allow to use inside model')
        sys.exit(1)  # sys.exit instead of the site-injected exit() builtin

    model_init_start = time.time()
    if model == MODEL.Paddle:
        # Imported lazily: paddle is heavy and only needed for this backend.
        from magic_pdf.model.pp_structure_v2 import CustomPaddleModel

        custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
    elif model == MODEL.PEK:
        from magic_pdf.model.pdf_extract_kit import CustomPEKModel

        custom_model = CustomPEKModel(
            **_build_pek_model_input(ocr, show_log, lang, layout_model, formula_enable, table_enable)
        )
    else:
        # __model_mode__ was neither 'lite' nor 'full'.
        logger.error('Not allow model_name!')
        sys.exit(1)
    model_init_cost = time.time() - model_init_start
    logger.info(f'model init cost: {model_init_cost}')

    return custom_model

127
128
129
130
131
132
133
134
135
136
def doc_analyze(
    dataset: Dataset,
    ocr: bool = False,
    show_log: bool = False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    """Run model inference over the pages of ``dataset``.

    Pages in [start_page_id, end_page_id] are analyzed (in batches when there
    are many); pages outside the range get empty results and zero dimensions.
    Returns an ``InferenceResult`` wrapping the per-page model JSON.
    """
    # Normalize end_page_id: None or negative means "through the last page".
    end_page_id = (
        end_page_id
        if end_page_id is not None and end_page_id >= 0
        else len(dataset) - 1
    )

    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))

    # Collect images (and their sizes) for the selected page range only.
    images = []
    page_wh_list = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))

    # Chunk into fixed-size batches only when there are enough pages to be worth it.
    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
        batch_size = MIN_BATCH_INFERENCE_SIZE
        batch_images = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
    else:
        batch_images = [images]

    results = []
    for sn, batch_image in enumerate(batch_images):
        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
        results.extend(result)

    # Reassemble per-page JSON; a cursor replaces the original O(n^2) pop(0) pattern.
    model_json = []
    cursor = 0
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            result = results[cursor]
            page_width, page_height = page_wh_list[cursor]
            cursor += 1
        else:
            # Page outside the analyzed range: empty detections, zero size.
            result = []
            page_height = 0
            page_width = 0

        page_info = {'page_no': index, 'width': page_width, 'height': page_height}
        page_dict = {'layout_dets': result, 'page_info': page_info}
        model_json.append(page_dict)

    from magic_pdf.operators.models import InferenceResult
    return InferenceResult(model_json, dataset)

def batch_doc_analyze(
    datasets: list[Dataset],
    ocr: bool = False,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    """Run model inference over every page of several datasets.

    Pages are grouped by language so each inference batch is monolingual,
    then results are scattered back into original page order.
    Returns one ``InferenceResult`` per dataset, in input order.
    """
    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
    batch_size = MIN_BATCH_INFERENCE_SIZE
    images = []
    page_wh_list = []
    lang_list = []
    lang_s = set()
    for dataset in datasets:
        for index in range(len(dataset)):
            # 'auto' (or no) lang falls back to the dataset's own language.
            # NOTE(review): reads Dataset's private `_lang` attribute — TODO confirm a public accessor.
            if lang is None or lang == 'auto':
                lang_list.append(dataset._lang)
            else:
                lang_list.append(lang)
            lang_s.add(lang_list[-1])
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))

    # Build monolingual batches of at most batch_size images each; img_idx_list
    # records, batch by batch, which global page index each image came from.
    batch_images = []
    img_idx_list = []
    for t_lang in lang_s:
        tmp_img_idx_list = [i for i, _lang in enumerate(lang_list) if _lang == t_lang]
        img_idx_list.extend(tmp_img_idx_list)
        for k in range(0, len(tmp_img_idx_list), batch_size):
            chunk = tmp_img_idx_list[k:k + batch_size]
            batch_images.append((t_lang, [images[j] for j in chunk]))

    unorder_results = []
    for sn, (_lang, batch_image) in enumerate(batch_images):
        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, _lang, layout_model, formula_enable, table_enable)
        unorder_results.extend(result)
    # Scatter language-grouped outputs back to original page order.
    results = [None] * len(img_idx_list)
    for i, idx in enumerate(img_idx_list):
        results[idx] = unorder_results[i]

    infer_results = []
    from magic_pdf.operators.models import InferenceResult
    # A cursor replaces the original O(n^2) pop(0) pattern.
    cursor = 0
    for dataset in datasets:
        model_json = []
        for i in range(len(dataset)):
            result = results[cursor]
            page_width, page_height = page_wh_list[cursor]
            cursor += 1
            page_info = {'page_no': i, 'width': page_width, 'height': page_height}
            page_dict = {'layout_dets': result, 'page_info': page_info}
            model_json.append(page_dict)
        infer_results.append(InferenceResult(model_json, dataset))
    return infer_results


def may_batch_image_analyze(
        images: list[np.ndarray],
        idx: int,
        ocr: bool = False,
        show_log: bool = False,
        lang=None,
        layout_model=None,
        formula_enable=None,
        table_enable=None):
    """Analyze a list of page images, batched when the device allows it.

    On CUDA/NPU the batch ratio is scaled with available VRAM; on MPS batching
    is enabled at ratio 1; otherwise images are analyzed one at a time.

    Returns ``(idx, results)`` — the caller's batch index and one inference
    result per input image, in order.
    """
    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
    # Disable paddle's signal handler so it does not interfere with the host process.
    import paddle
    paddle.disable_signal_handler()
    from magic_pdf.model.batch_analyze import BatchAnalyze

    model_manager = ModelSingleton()
    custom_model = model_manager.get_model(
        ocr, show_log, lang, layout_model, formula_enable, table_enable
    )

    batch_analyze = False
    batch_ratio = 1
    device = get_device()

    if str(device).startswith('npu'):
        import torch_npu
        if torch_npu.npu.is_available():
            torch.npu.set_compile_mode(jit_compile=False)

    if str(device).startswith('npu') or str(device).startswith('cuda'):
        # Scale batch ratio with available VRAM; VIRTUAL_VRAM_SIZE overrides detection.
        # (Dead `is not None` check removed: int(...) never yields None.)
        gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
        if gpu_memory >= 16:
            batch_ratio = 16
        elif gpu_memory >= 12:
            batch_ratio = 8
        elif gpu_memory >= 8:
            batch_ratio = 4
        elif gpu_memory >= 6:
            batch_ratio = 2
        else:
            batch_ratio = 1
        logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
        batch_analyze = True
    elif str(device).startswith('mps'):
        batch_analyze = True

    doc_analyze_start = time.time()

    if batch_analyze:
        # Batched path: let BatchAnalyze group images internally.
        batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
        results = batch_model(images)
    else:
        # Single-image fallback (CPU / unknown device): analyze sequentially.
        results = []
        for img_idx, img in enumerate(images):
            inference_start = time.time()
            result = custom_model(img)
            logger.info(f'-----image index : {img_idx}, image inference total time: {round(time.time() - inference_start, 2)}-----')
            results.append(result)

    gc_start = time.time()
    clean_memory(get_device())
    gc_time = round(time.time() - gc_start, 2)
    logger.info(f'gc time: {gc_time}')

    doc_analyze_time = round(time.time() - doc_analyze_start, 2)
    # Guard against a zero elapsed time (e.g. empty `images`) to avoid ZeroDivisionError.
    doc_analyze_speed = round(len(images) / doc_analyze_time, 2) if doc_analyze_time > 0 else 0
    logger.info(
        f'doc analyze time: {doc_analyze_time},'
        f' speed: {doc_analyze_speed} pages/second'
    )
    return (idx, results)