batch_analyze.py 16.8 KB
Newer Older
1
2
3
import time
import cv2
from loguru import logger
4
from tqdm import tqdm
5
6
from collections import defaultdict
import numpy as np
7
8

from magic_pdf.config.constants import MODEL_NAME
9
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
10
from magic_pdf.model.sub_modules.model_utils import (
11
    clean_vram, crop_img, get_res_list_from_layout_res, get_coords_and_area)
12
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
icecraft's avatar
icecraft committed
13
    get_adjusted_mfdetrec_res, get_ocr_result_list)
14

15
# Base (unscaled) batch sizes for the sub-models.  Where batching is
# enabled, the effective batch size is ``batch_ratio * BASE`` (see
# BatchAnalyze below); layout and MFD currently run with the base value.
YOLO_LAYOUT_BASE_BATCH_SIZE = 1
MFD_BASE_BATCH_SIZE = 1
MFR_BASE_BATCH_SIZE = 16


class BatchAnalyze:
    """Batched page-analysis pipeline.

    For a list of page images, runs (in order): layout detection, formula
    detection + recognition, OCR text detection, table recognition, and
    OCR text recognition, merging all results into per-page layout result
    dicts.

    NOTE(review): page images appear to be RGB numpy arrays — both OCR
    detection paths convert crops with ``cv2.COLOR_RGB2BGR`` before
    detection; confirm against callers.
    """

    def __init__(self, model_manager, batch_ratio: int, show_log, layout_model, formula_enable, table_enable, enable_ocr_det_batch=False):
        """Store configuration only; models are resolved lazily in ``__call__``.

        Args:
            model_manager: object exposing ``get_model(...)`` that returns the
                composite model bundle used by the pipeline.
            batch_ratio: multiplier applied to base batch sizes (currently
                effective for MFR and for batched OCR detection; the layout
                and MFD calls use the unscaled base size — see the
                commented-out alternatives in ``__call__``).
            show_log: forwarded to the model factory.
            layout_model: layout model selector, forwarded to the factory.
            formula_enable: forwarded to the factory.
            table_enable: forwarded to the factory.
            enable_ocr_det_batch: when True, OCR text detection runs in
                batches grouped by language and (padded) resolution instead
                of one crop at a time.
        """
        self.model_manager = model_manager
        self.batch_ratio = batch_ratio
        self.show_log = show_log
        self.layout_model = layout_model
        self.formula_enable = formula_enable
        self.table_enable = table_enable
        self.enable_ocr_det_batch = enable_ocr_det_batch

    def __call__(self, images_with_extra_info: list) -> list:
        """Analyze a batch of pages.

        Args:
            images_with_extra_info: list of ``(image, ocr_enable, lang)``
                tuples, one per page.

        Returns:
            A list with one layout-result list per input page.  Table HTML
            and OCR text/score are merged into the result dicts in place.
        """
        if len(images_with_extra_info) == 0:
            return []

        images_layout_res = []
        layout_start_time = time.time()
        # Resolve the composite model bundle lazily, per call.
        self.model = self.model_manager.get_model(
            ocr=True,
            show_log=self.show_log,
            lang=None,
            layout_model=self.layout_model,
            formula_enable=self.formula_enable,
            table_enable=self.table_enable,
        )

        images = [image for image, _, _ in images_with_extra_info]

        # ---- Layout detection -------------------------------------------
        if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            # layoutlmv3: no batch API — predict page by page.
            for image in images:
                layout_res = self.model.layout_model(image, ignore_catids=[])
                images_layout_res.append(layout_res)
        elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            # doclayout_yolo: batch prediction over all pages.
            layout_images = []
            for image_index, image in enumerate(images):
                layout_images.append(image)

            images_layout_res += self.model.layout_model.batch_predict(
                # layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
                layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
            )

        # logger.info(
        #     f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
        # )

        # ---- Formula detection + recognition ----------------------------
        if self.model.apply_formula:
            # Formula detection (MFD).
            mfd_start_time = time.time()
            images_mfd_res = self.model.mfd_model.batch_predict(
                # images, self.batch_ratio * MFD_BASE_BATCH_SIZE
                images, MFD_BASE_BATCH_SIZE
            )
            # logger.info(
            #     f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
            # )

            # Formula recognition (MFR); results are appended to each
            # page's layout results.
            mfr_start_time = time.time()
            images_formula_list = self.model.mfr_model.batch_predict(
                images_mfd_res,
                images,
                batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
            )
            mfr_count = 0
            for image_index in range(len(images)):
                images_layout_res[image_index] += images_formula_list[image_index]
                mfr_count += len(images_formula_list[image_index])
            # logger.info(
            #     f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
            # )

        # Free GPU memory (currently disabled).
        # clean_vram(self.model.device, vram_threshold=8)

        # ---- Split layout results into OCR and table work items ---------
        ocr_res_list_all_page = []
        table_res_list_all_page = []
        for index in range(len(images)):
            _, ocr_enable, _lang = images_with_extra_info[index]
            layout_res = images_layout_res[index]
            np_array_img = images[index]

            ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                get_res_list_from_layout_res(layout_res)
            )

            ocr_res_list_all_page.append({
                'ocr_res_list': ocr_res_list,
                'lang': _lang,
                'ocr_enable': ocr_enable,
                'np_array_img': np_array_img,
                'single_page_mfdetrec_res': single_page_mfdetrec_res,
                'layout_res': layout_res,
            })

            for table_res in table_res_list:
                table_img, _ = crop_img(table_res, np_array_img)
                table_res_list_all_page.append({
                    'table_res': table_res,
                    'lang': _lang,
                    'table_img': table_img,
                })

        # ---- OCR text detection ------------------------------------------
        if self.enable_ocr_det_batch:
            # Batch mode: group crops by language, then by resolution.
            # First collect every crop that needs OCR detection.
            all_cropped_images_info = []

            for ocr_res_list_dict in ocr_res_list_all_page:
                _lang = ocr_res_list_dict['lang']

                for res in ocr_res_list_dict['ocr_res_list']:
                    new_image, useful_list = crop_img(
                        res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
                    )
                    adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                        ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                    )

                    # RGB -> BGR for the OCR detector.
                    new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)

                    all_cropped_images_info.append((
                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
                    ))

            # Group by language (one OCR model per language).
            lang_groups = defaultdict(list)
            for crop_info in all_cropped_images_info:
                lang = crop_info[5]
                lang_groups[lang].append(crop_info)

            # For each language: group by resolution and batch-predict.
            for lang, lang_crop_list in lang_groups.items():
                if not lang_crop_list:
                    continue

                # logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")

                # Fetch the language-specific OCR model.
                atom_model_manager = AtomModelSingleton()
                ocr_model = atom_model_manager.get_atom_model(
                    atom_model_name='ocr',
                    ocr_show_log=False,
                    det_db_box_thresh=0.3,
                    lang=lang
                )

                # Group by resolution (rounded up to multiples of 32) so
                # that padding waste per batch stays small while keeping
                # the number of groups low.
                resolution_groups = defaultdict(list)
                for crop_info in lang_crop_list:
                    cropped_img = crop_info[0]
                    h, w = cropped_img.shape[:2]
                    normalized_h = ((h + 32) // 32) * 32  # round up to multiple of 32
                    normalized_w = ((w + 32) // 32) * 32
                    group_key = (normalized_h, normalized_w)
                    resolution_groups[group_key].append(crop_info)

                # Run batched detection per resolution group.
                for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
                    raw_images = [crop_info[0] for crop_info in group_crops]

                    # Target size = max size within the group, rounded up
                    # to a multiple of 32.
                    max_h = max(img.shape[0] for img in raw_images)
                    max_w = max(img.shape[1] for img in raw_images)
                    target_h = ((max_h + 32 - 1) // 32) * 32
                    target_w = ((max_w + 32 - 1) // 32) * 32

                    # Pad every image to the common target size with a
                    # white background, pasting the original top-left.
                    batch_images = []
                    for img in raw_images:
                        h, w = img.shape[:2]
                        padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
                        padded_img[:h, :w] = img
                        batch_images.append(padded_img)

                    # Batched detection; batch size scales with batch_ratio.
                    batch_size = min(len(batch_images), self.batch_ratio * 16)
                    # logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
                    batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)

                    # Fold detection boxes back into each page's layout results.
                    for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info

                        if dt_boxes is not None:
                            # Each detected box is a list of 4 points.
                            ocr_res = [box.tolist() for box in dt_boxes]

                            if ocr_res:
                                ocr_result_list = get_ocr_result_list(
                                    ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
                                )

                                if res["category_id"] == 3:
                                    # Total area of all detected text bboxes
                                    # inside this region.
                                    ocr_res_area = sum(get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
                                    # If text covers more than 25% of the
                                    # region's area, reclassify category 3
                                    # as category 1; otherwise drop the
                                    # detections for this region.
                                    res_area = get_coords_and_area(res)[4]
                                    if res_area > 0:
                                        ratio = ocr_res_area / res_area
                                        if ratio > 0.25:
                                            res["category_id"] = 1
                                        else:
                                            continue

                                ocr_res_list_dict['layout_res'].extend(ocr_result_list)
        else:
            # Original single-crop mode: detect one cropped region at a time.
            for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
                # Process each area that requires OCR processing
                _lang = ocr_res_list_dict['lang']
                # Get OCR results for this language's images
                atom_model_manager = AtomModelSingleton()
                ocr_model = atom_model_manager.get_atom_model(
                    atom_model_name='ocr',
                    ocr_show_log=False,
                    det_db_box_thresh=0.3,
                    lang=_lang
                )
                for res in ocr_res_list_dict['ocr_res_list']:
                    new_image, useful_list = crop_img(
                        res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
                    )
                    adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                        ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                    )

                    # OCR-det (RGB -> BGR, detection only, rec=False).
                    new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
                    ocr_res = ocr_model.ocr(
                        new_image, mfd_res=adjusted_mfdetrec_res, rec=False
                    )[0]

                    # Integration results
                    if ocr_res:
                        ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang)

                        if res["category_id"] == 3:
                            # Total area of all detected text bboxes inside
                            # this region.
                            ocr_res_area = sum(get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
                            # Same 25% coverage rule as the batch path.
                            res_area = get_coords_and_area(res)[4]
                            if res_area > 0:
                                ratio = ocr_res_area / res_area
                                if ratio > 0.25:
                                    res["category_id"] = 1
                                else:
                                    continue

                        ocr_res_list_dict['layout_res'].extend(ocr_result_list)

            # det_count += len(ocr_res_list_dict['ocr_res_list'])
        # logger.info(f'ocr-det time: {round(time.time()-det_start, 2)}, image num: {det_count}')

        # ---- Table recognition -------------------------------------------
        if self.model.apply_table:
            table_start = time.time()
            # for table_res_list_dict in table_res_list_all_page:
            for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
                _lang = table_res_dict['lang']
                atom_model_manager = AtomModelSingleton()
                table_model = atom_model_manager.get_atom_model(
                    atom_model_name='table',
                    table_model_name='rapid_table',
                    table_model_path='',
                    table_max_time=400,
                    device='cpu',
                    lang=_lang,
                    table_sub_model_name='slanet_plus'
                )
                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict['table_img'])
                # Accept the HTML only when it ends with a proper closing tag.
                if html_code:
                    expected_ending = html_code.strip().endswith(
                        '</html>'
                    ) or html_code.strip().endswith('</table>')
                    if expected_ending:
                        table_res_dict['table_res']['html'] = html_code
                    else:
                        logger.warning(
                            'table recognition processing fails, not found expected HTML table end'
                        )
                else:
                    logger.warning(
                        'table recognition processing fails, not get html return'
                    )
            # logger.info(f'table time: {round(time.time() - table_start, 2)}, image num: {len(table_res_list_all_page)}')

        # ---- OCR text recognition ----------------------------------------
        # Create dictionaries to store items by language
        need_ocr_lists_by_lang = {}  # Dict of lists for each language
        img_crop_lists_by_lang = {}  # Dict of lists for each language

        for layout_res in images_layout_res:
            for layout_res_item in layout_res:
                # category_id 15: items carrying a text crop awaiting recognition.
                if layout_res_item['category_id'] in [15]:
                    if 'np_img' in layout_res_item and 'lang' in layout_res_item:
                        lang = layout_res_item['lang']

                        # Initialize lists for this language if not exist
                        if lang not in need_ocr_lists_by_lang:
                            need_ocr_lists_by_lang[lang] = []
                            img_crop_lists_by_lang[lang] = []

                        # Add to the appropriate language-specific lists
                        need_ocr_lists_by_lang[lang].append(layout_res_item)
                        img_crop_lists_by_lang[lang].append(layout_res_item['np_img'])

                        # Remove the fields after adding to lists
                        layout_res_item.pop('np_img')
                        layout_res_item.pop('lang')

        if len(img_crop_lists_by_lang) > 0:

            # Process OCR by language
            rec_time = 0
            rec_start = time.time()
            total_processed = 0

            # Process each language separately
            for lang, img_crop_list in img_crop_lists_by_lang.items():
                if len(img_crop_list) > 0:
                    # Get OCR results for this language's images
                    atom_model_manager = AtomModelSingleton()
                    ocr_model = atom_model_manager.get_atom_model(
                        atom_model_name='ocr',
                        ocr_show_log=False,
                        det_db_box_thresh=0.3,
                        lang=lang
                    )
                    # Recognition only (det=False) over the collected crops.
                    ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]

                    # Verify we have matching counts
                    assert len(ocr_res_list) == len(
                        need_ocr_lists_by_lang[lang]), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_lists_by_lang[lang])} for lang: {lang}'

                    # Process OCR results for this language
                    for index, layout_res_item in enumerate(need_ocr_lists_by_lang[lang]):
                        ocr_text, ocr_score = ocr_res_list[index]
                        layout_res_item['text'] = ocr_text
                        # Round the score to 3 decimal places.
                        layout_res_item['score'] = float(f"{ocr_score:.3f}")

                    total_processed += len(img_crop_list)

            rec_time += time.time() - rec_start
            # logger.info(f'ocr-rec time: {round(rec_time, 2)}, total images processed: {total_processed}')

        return images_layout_res