doc_analyze_by_custom_model.py 9.94 KB
Newer Older
1
import os
2
import time
icecraft's avatar
icecraft committed
3

icecraft's avatar
icecraft committed
4
import numpy as np
icecraft's avatar
icecraft committed
5
6
import torch

7
8
9
10
# Process-wide environment switches for the ML backends used below.
os.environ.update({
    'FLAGS_npu_jit_compile': '0',  # disable Paddle's JIT compilation
    'FLAGS_use_stride_kernel': '0',
    'PYTORCH_ENABLE_MPS_FALLBACK': '1',  # allow the MPS backend to fall back
    'NO_ALBUMENTATIONS_UPDATE': '1',  # stop albumentations from checking for updates
})
icecraft's avatar
icecraft committed
11

12

icecraft's avatar
icecraft committed
13
14
from loguru import logger

15
from magic_pdf.model.sub_modules.model_utils import get_vram
16
from magic_pdf.config.enums import SupportedPdfParseMethod
17
18
import magic_pdf.model as model_config
from magic_pdf.data.dataset import Dataset
19
from magic_pdf.libs.clean_memory import clean_memory
20
21
22
23
from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                          get_layout_config,
                                          get_local_models_dir,
                                          get_table_recog_config)
24
from magic_pdf.model.model_list import MODEL
icecraft's avatar
icecraft committed
25

26
27
28
29
30
31
32
33
34
class ModelSingleton:
    """Shared cache of initialised analysis models.

    Every instantiation hands back the same object. Models are memoised by
    the full tuple of construction options, so each distinct configuration
    is built only once (via ``custom_model_init``) and reused afterwards.
    """

    # The single shared instance returned by ``__new__``.
    _instance = None
    # (ocr, show_log, lang, layout_model, formula_enable, table_enable) -> model
    _models = {}

    def __new__(cls, *args, **kwargs):
        instance = cls._instance
        if instance is None:
            instance = super().__new__(cls)
            cls._instance = instance
        return instance

    def get_model(
        self,
        ocr: bool,
        show_log: bool,
        lang=None,
        layout_model=None,
        formula_enable=None,
        table_enable=None,
    ):
        """Return the model for these options, building and caching it on first use."""
        cache_key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
        try:
            return self._models[cache_key]
        except KeyError:
            pass

        built = custom_model_init(
            ocr=ocr,
            show_log=show_log,
            lang=lang,
            layout_model=layout_model,
            formula_enable=formula_enable,
            table_enable=table_enable,
        )
        self._models[cache_key] = built
        return built


57
58
59
60
61
62
63
64
def custom_model_init(
    ocr: bool = False,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    """Build a fresh analysis model according to the configured model mode.

    ``model_config.__model_mode__`` selects the backend: ``'lite'`` builds a
    ``CustomPaddleModel`` (and logs a quality warning), ``'full'`` builds a
    ``CustomPEKModel``. For the PEK model the models directory, device and the
    layout / formula / table configs come from the config reader; each override
    argument below that is not ``None`` replaces the matching config entry.

    Args:
        ocr: Passed through to the model constructor.
        show_log: Passed through to the model constructor.
        lang: Passed through to the model constructor.
        layout_model: If not ``None``, overrides ``layout_config['model']``.
        formula_enable: If not ``None``, overrides ``formula_config['enable']``.
        table_enable: If not ``None``, overrides ``table_config['enable']``.

    Returns:
        The initialised ``CustomPaddleModel`` or ``CustomPEKModel``.

    Raises:
        SystemExit: (code 1) if ``__use_inside_model__`` is false or the model
            mode is not recognised.
    """
    # sys.exit instead of the builtin exit(): the latter is injected by the
    # `site` module for interactive use and is not guaranteed to exist.
    import sys

    model = None
    if model_config.__model_mode__ == 'lite':
        logger.warning(
            'The Lite mode is provided for developers to conduct testing only, and the output quality is '
            'not guaranteed to be reliable.'
        )
        model = MODEL.Paddle
    elif model_config.__model_mode__ == 'full':
        model = MODEL.PEK

    if not model_config.__use_inside_model__:
        logger.error('use_inside_model is False, not allow to use inside model')
        sys.exit(1)

    model_init_start = time.time()
    if model == MODEL.Paddle:
        from magic_pdf.model.pp_structure_v2 import CustomPaddleModel

        custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
    elif model == MODEL.PEK:
        from magic_pdf.model.pdf_extract_kit import CustomPEKModel

        # Read the models directory and device from the config file.
        local_models_dir = get_local_models_dir()
        device = get_device()

        layout_config = get_layout_config()
        if layout_model is not None:
            layout_config['model'] = layout_model

        formula_config = get_formula_config()
        if formula_enable is not None:
            formula_config['enable'] = formula_enable

        table_config = get_table_recog_config()
        if table_enable is not None:
            table_config['enable'] = table_enable

        model_input = {
            'ocr': ocr,
            'show_log': show_log,
            'models_dir': local_models_dir,
            'device': device,
            'table_config': table_config,
            'layout_config': layout_config,
            'formula_config': formula_config,
            'lang': lang,
        }
        custom_model = CustomPEKModel(**model_input)
    else:
        logger.error('Not allow model_name!')
        sys.exit(1)

    model_init_cost = time.time() - model_init_start
    logger.info(f'model init cost: {model_init_cost}')
    return custom_model

123
124
125
126
127
128
129
130
131
132
def doc_analyze(
    dataset: Dataset,
    ocr: bool = False,
    show_log: bool = False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    """Run model inference over a page range of a single dataset.

    Pages whose index lies in ``[start_page_id, end_page_id]`` are rendered via
    ``dataset.get_page(i).get_image()`` and analysed in batches of
    ``MINERU_MIN_BATCH_INFERENCE_SIZE`` pages (env var, default 200). Pages
    outside the range still get an entry, with empty ``layout_dets`` and zero
    width/height, so the output has exactly one entry per dataset page.

    Args:
        dataset: Source document.
        ocr: OCR flag attached to every analysed page.
        show_log: Forwarded to the batch analyser.
        start_page_id: First page index to analyse (inclusive).
        end_page_id: Last page index to analyse (inclusive); ``None`` or a
            negative value means the last page.
        lang: Language for every page; ``None`` or ``'auto'`` uses
            ``dataset._lang``.
        layout_model, formula_enable, table_enable: Forwarded to the batch
            analyser.

    Returns:
        ``InferenceResult`` over a list of ``{'layout_dets', 'page_info'}`` dicts.
    """
    end_page_id = (
        end_page_id
        if end_page_id is not None and end_page_id >= 0
        else len(dataset) - 1
    )

    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))

    # 'auto' (or no language) defers to the language recorded on the dataset.
    page_lang = dataset._lang if lang is None or lang == 'auto' else lang

    # Only the selected pages are rendered; both lists are indexed in lockstep
    # by position within the selected range (NOT by dataset page index).
    images_with_extra_info = []
    page_wh_list = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            img_dict = dataset.get_page(index).get_image()
            images_with_extra_info.append((img_dict['img'], ocr, page_lang))
            page_wh_list.append((img_dict['width'], img_dict['height']))

    # Slicing yields a single batch when there are fewer pages than the batch
    # size, and no batch at all for an empty selection.
    batch_size = MIN_BATCH_INFERENCE_SIZE
    batch_images = [
        images_with_extra_info[i:i + batch_size]
        for i in range(0, len(images_with_extra_info), batch_size)
    ]

    results = []
    for sn, batch_image in enumerate(batch_images):
        _, result = may_batch_image_analyze(
            batch_image, sn, ocr, show_log, layout_model, formula_enable, table_enable
        )
        results.extend(result)

    # Walk selected results with a cursor instead of pop(0), which is O(n) per call.
    cursor = 0
    model_json = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            result = results[cursor]
            page_width, page_height = page_wh_list[cursor]
            cursor += 1
        else:
            result = []
            page_height = 0
            page_width = 0

        page_info = {'page_no': index, 'width': page_width, 'height': page_height}
        page_dict = {'layout_dets': result, 'page_info': page_info}
        model_json.append(page_dict)

    from magic_pdf.operators.models import InferenceResult
    return InferenceResult(model_json, dataset)

def batch_doc_analyze(
    datasets: list[Dataset],
    parse_method: str,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    """Analyse every page of several datasets, batching across dataset boundaries.

    Pages from all datasets are concatenated and sent to the model in batches
    of ``MINERU_MIN_BATCH_INFERENCE_SIZE`` pages (env var, default 100). The
    results are then split back per dataset.

    Args:
        datasets: Documents to analyse.
        parse_method: ``'auto'`` sets the OCR flag per dataset from
            ``dataset.classify() == SupportedPdfParseMethod.OCR``; ``'ocr'``
            enables OCR; any other value disables it.
        show_log: Forwarded to the batch analyser.
        lang: Language for every page; ``None`` or ``'auto'`` uses each
            dataset's ``_lang``.
        layout_model, formula_enable, table_enable: Forwarded to the batch
            analyser.

    Returns:
        One ``InferenceResult`` per input dataset, in input order.
    """
    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
    batch_size = MIN_BATCH_INFERENCE_SIZE

    page_wh_list = []
    images_with_extra_info = []
    for dataset in datasets:
        page_count = len(dataset)
        if page_count == 0:
            # No pages: nothing to render (and no need to classify).
            continue

        # Per-dataset values, hoisted out of the page loop. classify() takes no
        # page argument, so it is presumably a per-dataset result; it was
        # previously re-evaluated for every page.
        _lang = dataset._lang if lang is None or lang == 'auto' else lang
        if parse_method == 'auto':
            use_ocr = dataset.classify() == SupportedPdfParseMethod.OCR
        else:
            use_ocr = parse_method == 'ocr'

        for index in range(page_count):
            img_dict = dataset.get_page(index).get_image()
            page_wh_list.append((img_dict['width'], img_dict['height']))
            images_with_extra_info.append((img_dict['img'], use_ocr, _lang))

    batch_images = [
        images_with_extra_info[i:i + batch_size]
        for i in range(0, len(images_with_extra_info), batch_size)
    ]
    results = []
    for sn, batch_image in enumerate(batch_images):
        # The per-page OCR flag travels inside each tuple.
        _, result = may_batch_image_analyze(
            batch_image, sn, True, show_log, layout_model, formula_enable, table_enable
        )
        results.extend(result)

    # Split the flat result list back per dataset; a cursor avoids O(n) pop(0).
    infer_results = []
    from magic_pdf.operators.models import InferenceResult
    cursor = 0
    for dataset in datasets:
        model_json = []
        for i in range(len(dataset)):
            result = results[cursor]
            page_width, page_height = page_wh_list[cursor]
            cursor += 1
            page_info = {'page_no': i, 'width': page_width, 'height': page_height}
            page_dict = {'layout_dets': result, 'page_info': page_info}
            model_json.append(page_dict)
        infer_results.append(InferenceResult(model_json, dataset))
    return infer_results


def may_batch_image_analyze(
        images_with_extra_info: list[tuple[np.ndarray, bool, str]],
        idx: int,
        ocr: bool,
        show_log: bool = False,
        layout_model=None,
        formula_enable=None,
        table_enable=None):
    """Analyse one batch of page images with the shared model manager.

    On NPU/CUDA devices the ``BatchAnalyze`` batch ratio is derived from the
    available VRAM in GB. The value comes from the ``VIRTUAL_VRAM_SIZE`` env
    var if set, otherwise from ``get_vram``. Memory is cleaned after inference.

    Args:
        images_with_extra_info: ``(image, ocr_flag, lang)`` tuples, one per page.
        idx: Batch sequence number, returned unchanged alongside the results.
        ocr: Not used by this function. The per-page flag inside each tuple
            is what reaches ``BatchAnalyze``.
        show_log, layout_model, formula_enable, table_enable: Forwarded to
            ``BatchAnalyze``.

    Returns:
        ``(idx, results)`` where ``results`` is what ``BatchAnalyze`` returned.
    """
    from magic_pdf.model.batch_analyze import BatchAnalyze

    model_manager = ModelSingleton()
    device = get_device()
    batch_ratio = 1

    if str(device).startswith('npu'):
        import torch_npu
        if torch_npu.npu.is_available():
            torch.npu.set_compile_mode(jit_compile=False)

    if str(device).startswith(('npu', 'cuda')):
        # Only query the hardware when no override is given. get_vram may
        # return None, which must not reach round().
        vram_override = os.getenv('VIRTUAL_VRAM_SIZE')
        if vram_override is not None:
            gpu_memory = int(vram_override)
        else:
            detected_vram = get_vram(device)
            gpu_memory = round(detected_vram) if detected_vram is not None else None

        if gpu_memory is not None:
            if gpu_memory >= 16:
                batch_ratio = 16
            elif gpu_memory >= 12:
                batch_ratio = 8
            elif gpu_memory >= 8:
                batch_ratio = 4
            elif gpu_memory >= 6:
                batch_ratio = 2
            else:
                batch_ratio = 1
            logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
    # Other devices (e.g. mps, cpu) keep batch_ratio = 1.

    doc_analyze_start = time.time()

    batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
    results = batch_model(images_with_extra_info)

    gc_start = time.time()
    clean_memory(device)
    gc_time = round(time.time() - gc_start, 2)
    logger.info(f'gc time: {gc_time}')

    elapsed = time.time() - doc_analyze_start
    doc_analyze_time = round(elapsed, 2)
    # A very fast (or empty) batch can take ~0 s; avoid ZeroDivisionError.
    doc_analyze_speed = (
        round(len(images_with_extra_info) / elapsed, 2) if elapsed > 0 else float('inf')
    )
    logger.info(
        f'doc analyze time: {doc_analyze_time},'
        f' speed: {doc_analyze_speed} pages/second'
    )
    return (idx, results)