doc_analyze_by_custom_model.py 11.7 KB
Newer Older
icecraft's avatar
icecraft committed
1
2
import concurrent.futures as fut
import multiprocessing as mp
3
import os
4
import time
icecraft's avatar
icecraft committed
5

icecraft's avatar
icecraft committed
6
import numpy as np
icecraft's avatar
icecraft committed
7
8
import torch

9
10
11
12
os.environ['FLAGS_npu_jit_compile'] = '0'  # 关闭paddle的jit编译
os.environ['FLAGS_use_stride_kernel'] = '0'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # 让mps可以fallback
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
icecraft's avatar
icecraft committed
13

14

icecraft's avatar
icecraft committed
15
16
from loguru import logger

17
18
from magic_pdf.model.sub_modules.model_utils import get_vram

19
20
21
22
23
24
25
try:
    import torchtext
    if torchtext.__version__ >= '0.18.0':
        torchtext.disable_torchtext_deprecation_warning()
except ImportError:
    pass

26
27
import magic_pdf.model as model_config
from magic_pdf.data.dataset import Dataset
28
from magic_pdf.libs.clean_memory import clean_memory
29
30
31
32
from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                          get_layout_config,
                                          get_local_models_dir,
                                          get_table_recog_config)
33
from magic_pdf.model.model_list import MODEL
icecraft's avatar
icecraft committed
34

icecraft's avatar
icecraft committed
35
# from magic_pdf.operators.models import InferenceResult
赵小蒙's avatar
赵小蒙 committed
36

37
38
39
40
41
42
43
44
45
class ModelSingleton:
    """Process-wide singleton that caches one model instance per configuration."""

    _instance = None
    _models = {}  # cache: config tuple -> initialized model

    def __new__(cls, *args, **kwargs):
        # Lazily create the single shared instance on first construction.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        ocr: bool,
        show_log: bool,
        lang=None,
        layout_model=None,
        formula_enable=None,
        table_enable=None,
    ):
        """Return the cached model for this configuration, initializing it on first use.

        The full argument tuple is the cache key, so distinct configurations
        get distinct model instances.
        """
        cache_key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
        try:
            return self._models[cache_key]
        except KeyError:
            model = custom_model_init(
                ocr=ocr,
                show_log=show_log,
                lang=lang,
                layout_model=layout_model,
                formula_enable=formula_enable,
                table_enable=table_enable,
            )
            self._models[cache_key] = model
            return model


68
69
70
71
72
73
74
75
def custom_model_init(
    ocr: bool = False,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    """Instantiate a document-analysis model according to the global model mode.

    Args:
        ocr: whether OCR is enabled for the model.
        show_log: whether the underlying model should emit verbose logs.
        lang: optional language hint forwarded to the model.
        layout_model: optional override for the configured layout model.
        formula_enable: optional override for the formula-recognition switch.
        table_enable: optional override for the table-recognition switch.

    Returns:
        The initialized model object (CustomPaddleModel or CustomPEKModel).

    Note:
        Terminates the process (exit code 1) on an unknown model mode or when
        use of the inside model is disabled.
    """
    mode = model_config.__model_mode__
    model = None
    if mode == 'lite':
        logger.warning(
            'The Lite mode is provided for developers to conduct testing only, and the output quality is '
            'not guaranteed to be reliable.'
        )
        model = MODEL.Paddle
    elif mode == 'full':
        model = MODEL.PEK

    if model_config.__use_inside_model__:
        model_init_start = time.time()
        if model == MODEL.Paddle:
            from magic_pdf.model.pp_structure_v2 import CustomPaddleModel

            custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
        elif model == MODEL.PEK:
            from magic_pdf.model.pdf_extract_kit import CustomPEKModel

            # Resolve the model directory and compute device from the config file.
            local_models_dir = get_local_models_dir()
            device = get_device()

            # Start from the configured settings and apply per-call overrides.
            layout_config = get_layout_config()
            if layout_model is not None:
                layout_config['model'] = layout_model

            formula_config = get_formula_config()
            if formula_enable is not None:
                formula_config['enable'] = formula_enable

            table_config = get_table_recog_config()
            if table_enable is not None:
                table_config['enable'] = table_enable

            custom_model = CustomPEKModel(
                ocr=ocr,
                show_log=show_log,
                models_dir=local_models_dir,
                device=device,
                table_config=table_config,
                layout_config=layout_config,
                formula_config=formula_config,
                lang=lang,
            )
        else:
            logger.error('Not allow model_name!')
            exit(1)
        model_init_cost = time.time() - model_init_start
        logger.info(f'model init cost: {model_init_cost}')
    else:
        logger.error('use_inside_model is False, not allow to use inside model')
        exit(1)

    return custom_model

134
135
136
137
138
139
140
141
142
143
def doc_analyze(
    dataset: Dataset,
    ocr: bool = False,
    show_log: bool = False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    """Run model inference over a page range of *dataset* and wrap the output.

    Args:
        dataset: the document to analyze.
        ocr: whether OCR is enabled.
        show_log: whether the model emits verbose logs.
        start_page_id: first page index (inclusive) to analyze.
        end_page_id: last page index (inclusive); None or a negative value
            means "through the last page".
        lang/layout_model/formula_enable/table_enable: forwarded to the model
            (see ModelSingleton.get_model).

    Returns:
        InferenceResult holding one entry per page of the dataset; pages
        outside [start_page_id, end_page_id] get empty detections and
        zero width/height.
    """
    end_page_id = (
        end_page_id
        if end_page_id is not None and end_page_id >= 0
        else len(dataset) - 1
    )

    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))

    # Render the selected pages once, remembering each page's pixel size.
    images = []
    page_wh_list = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))

    # Chunk into fixed-size batches only when there are enough pages to
    # make batching worthwhile; otherwise process everything in one go.
    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
        batch_size = MIN_BATCH_INFERENCE_SIZE
        batch_images = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
    else:
        batch_images = [images]

    results = []
    for sn, batch_image in enumerate(batch_images):
        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
        results.extend(result)

    # Stitch per-page results back into a whole-document model json.
    # Iterators replace the previous list.pop(0) calls, which made this
    # loop O(n^2) for large documents.
    result_iter = iter(results)
    wh_iter = iter(page_wh_list)
    model_json = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            result = next(result_iter)
            page_width, page_height = next(wh_iter)
        else:
            result = []
            page_height = 0
            page_width = 0

        page_info = {'page_no': index, 'width': page_width, 'height': page_height}
        page_dict = {'layout_dets': result, 'page_info': page_info}
        model_json.append(page_dict)

    from magic_pdf.operators.models import InferenceResult
    return InferenceResult(model_json, dataset)

def batch_doc_analyze(
    datasets: list[Dataset],
    ocr: bool = False,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    """Run model inference over every page of every dataset.

    Args:
        datasets: the documents to analyze (all pages of each are processed).
        ocr: whether OCR is enabled.
        show_log: whether the model emits verbose logs.
        lang/layout_model/formula_enable/table_enable: forwarded to the model
            (see ModelSingleton.get_model).

    Returns:
        list[InferenceResult], one per input dataset, in the same order.
    """
    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))

    # Flatten all pages of all datasets into one image list, remembering
    # each page's pixel size so results can be mapped back afterwards.
    images = []
    page_wh_list = []
    for dataset in datasets:
        for index in range(len(dataset)):
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))

    # Chunk into fixed-size batches only when there are enough pages to
    # make batching worthwhile; otherwise process everything in one go.
    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
        batch_size = MIN_BATCH_INFERENCE_SIZE
        batch_images = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
    else:
        batch_images = [images]

    results = []
    for sn, batch_image in enumerate(batch_images):
        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
        results.extend(result)

    # Distribute the flat result list back onto the datasets in order.
    # Iterators replace the previous list.pop(0) calls, which made this
    # loop O(n^2) for large page counts.
    result_iter = iter(results)
    wh_iter = iter(page_wh_list)
    infer_results = []

    from magic_pdf.operators.models import InferenceResult
    for dataset in datasets:
        model_json = []
        for i in range(len(dataset)):
            result = next(result_iter)
            page_width, page_height = next(wh_iter)
            page_info = {'page_no': i, 'width': page_width, 'height': page_height}
            page_dict = {'layout_dets': result, 'page_info': page_info}
            model_json.append(page_dict)
        infer_results.append(InferenceResult(model_json, dataset))
    return infer_results


def may_batch_image_analyze(
        images: list[np.ndarray],
        idx: int,
        ocr: bool = False,
        show_log: bool = False,
        lang=None,
        layout_model=None,
        formula_enable=None,
        table_enable=None):
    """Analyze a list of page images, batched when the device supports it.

    Args:
        images: page images to run inference on.
        idx: ordinal of this batch; returned unchanged so callers can
            match results to requests.
        ocr/show_log/lang/layout_model/formula_enable/table_enable:
            forwarded to the shared model (see ModelSingleton.get_model).

    Returns:
        Tuple (idx, results) where results[i] is the inference output
        for images[i].
    """
    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
    # Disable paddle's signal handler so it does not interfere with the
    # host process's own signal handling.
    import paddle
    paddle.disable_signal_handler()
    from magic_pdf.model.batch_analyze import BatchAnalyze

    model_manager = ModelSingleton()
    custom_model = model_manager.get_model(
        ocr, show_log, lang, layout_model, formula_enable, table_enable
    )

    batch_analyze = False
    batch_ratio = 1
    device = get_device()

    if str(device).startswith('npu'):
        import torch_npu
        if torch_npu.npu.is_available():
            # Disable JIT compilation on NPU (mirrors FLAGS_npu_jit_compile).
            torch.npu.set_compile_mode(jit_compile=False)

    if str(device).startswith('npu') or str(device).startswith('cuda'):
        # Only probe the device for real VRAM when no virtual size is
        # configured; os.getenv's default argument would otherwise be
        # evaluated (and the device queried) even when the override is set.
        vram_override = os.getenv('VIRTUAL_VRAM_SIZE')
        if vram_override is not None:
            gpu_memory = int(vram_override)
        else:
            gpu_memory = round(get_vram(device))
        # Scale the batch ratio with available VRAM (GB).
        if gpu_memory >= 20:
            batch_ratio = 16
        elif gpu_memory >= 15:
            batch_ratio = 8
        elif gpu_memory >= 10:
            batch_ratio = 4
        elif gpu_memory >= 7:
            batch_ratio = 2
        else:
            batch_ratio = 1
        logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
        batch_analyze = True
    elif str(device).startswith('mps'):
        batch_analyze = True

    doc_analyze_start = time.time()

    if batch_analyze:
        # Batch inference path for GPU-class devices.
        batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
        results = batch_model(images)
    else:
        # Single-image fallback for CPU-like devices.
        results = []
        for img_idx, img in enumerate(images):
            inference_start = time.time()
            result = custom_model(img)
            logger.info(f'-----image index : {img_idx}, image inference total time: {round(time.time() - inference_start, 2)}-----')
            results.append(result)

    gc_start = time.time()
    clean_memory(get_device())
    gc_time = round(time.time() - gc_start, 2)
    logger.info(f'gc time: {gc_time}')

    doc_analyze_time = round(time.time() - doc_analyze_start, 2)
    # Guard against division by zero when the whole run completes in < 5 ms
    # (doc_analyze_time rounds to 0.0 for an empty image list).
    if doc_analyze_time > 0:
        doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
    else:
        doc_analyze_speed = 0
    logger.info(
        f'doc analyze time: {doc_analyze_time},'
        f' speed: {doc_analyze_speed} pages/second'
    )
    return (idx, results)