import time
import cv2
from loguru import logger
from tqdm import tqdm

from collections import defaultdict
import numpy as np

from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import (
    clean_vram, crop_img, get_res_list_from_layout_res, get_coords_and_area)
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
    get_adjusted_mfdetrec_res, get_ocr_result_list)
YOLO_LAYOUT_BASE_BATCH_SIZE = 1
16
17
18
19
20
MFD_BASE_BATCH_SIZE = 1
MFR_BASE_BATCH_SIZE = 16


class BatchAnalyze:
    """Batched page-analysis pipeline.

    Runs layout detection, formula detection/recognition, OCR text detection,
    table recognition and OCR text recognition over a batch of page images,
    accumulating all results into per-page layout-result lists.
    """

    def __init__(self, model_manager, batch_ratio: int, show_log, layout_model, formula_enable, table_enable, enable_ocr_det_batch=False):
        """Store configuration; models are resolved lazily in ``__call__``.

        Args:
            model_manager: provider with a ``get_model(...)`` factory for the
                combined layout/formula/table model bundle.
            batch_ratio: multiplier applied to base batch sizes (currently
                used by MFR and by the batched OCR-det path).
            show_log: forwarded to the model factory's logging switch.
            layout_model: layout model selector forwarded to the factory.
            formula_enable: whether formula detection/recognition is enabled.
            table_enable: whether table recognition is enabled.
            enable_ocr_det_batch: if True, OCR text detection runs in the
                grouped/padded batch mode instead of one crop at a time.
        """
        self.model_manager = model_manager
        self.batch_ratio = batch_ratio
        self.show_log = show_log
        self.layout_model = layout_model
        self.formula_enable = formula_enable
        self.table_enable = table_enable
        self.enable_ocr_det_batch = enable_ocr_det_batch

    def __call__(self, images_with_extra_info: list) -> list:
        """Analyze a batch of pages.

        Args:
            images_with_extra_info: list of ``(image, ocr_enable, lang)``
                tuples, one per page.  ``image`` is assumed to be an RGB
                numpy array (crops are converted RGB->BGR before OCR) —
                TODO confirm against callers.

        Returns:
            A list (same order as the input) of per-page layout-result
            lists; each item is a dict with at least ``category_id`` plus
            model-specific fields (``html`` for tables, ``text``/``score``
            for OCR spans, ...).
        """
        if len(images_with_extra_info) == 0:
            return []

        images_layout_res = []
        layout_start_time = time.time()
        # NOTE: stored on self (not a local) — the model bundle is reused by
        # later stages in this call via self.model.
        self.model = self.model_manager.get_model(
            ocr=True,
            show_log=self.show_log,
            lang=None,
            layout_model=self.layout_model,
            formula_enable=self.formula_enable,
            table_enable=self.table_enable,
        )

        images = [image for image, _, _ in images_with_extra_info]

        # ---- Stage 1: layout detection -------------------------------------
        if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            # layoutlmv3: no batch API, predict one page at a time
            for image in images:
                layout_res = self.model.layout_model(image, ignore_catids=[])
                images_layout_res.append(layout_res)
        elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            # doclayout_yolo: batched prediction
            layout_images = []
            for image_index, image in enumerate(images):
                layout_images.append(image)

            images_layout_res += self.model.layout_model.batch_predict(
                # layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
                layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
            )

        # logger.info(
        #     f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
        # )

        # ---- Stage 2: formula detection + recognition ----------------------
        if self.model.apply_formula:
            # formula detection (MFD)
            mfd_start_time = time.time()
            images_mfd_res = self.model.mfd_model.batch_predict(
                # images, self.batch_ratio * MFD_BASE_BATCH_SIZE
                images, MFD_BASE_BATCH_SIZE
            )
            # logger.info(
            #     f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
            # )

            # formula recognition (MFR); results are appended to the page's
            # layout results so later stages see formulas as layout items
            mfr_start_time = time.time()
            images_formula_list = self.model.mfr_model.batch_predict(
                images_mfd_res,
                images,
                batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
            )
            mfr_count = 0
            for image_index in range(len(images)):
                images_layout_res[image_index] += images_formula_list[image_index]
                mfr_count += len(images_formula_list[image_index])
            # logger.info(
            #     f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
            # )

        # Free GPU memory between stages (currently disabled)
        # clean_vram(self.model.device, vram_threshold=8)

        # ---- Split each page's layout results into OCR / table work lists --
        ocr_res_list_all_page = []
        table_res_list_all_page = []
        for index in range(len(images)):
            _, ocr_enable, _lang = images_with_extra_info[index]
            layout_res = images_layout_res[index]
            np_array_img = images[index]

            ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                get_res_list_from_layout_res(layout_res)
            )

            ocr_res_list_all_page.append({
                'ocr_res_list': ocr_res_list,
                'lang': _lang,
                'ocr_enable': ocr_enable,
                'np_array_img': np_array_img,
                'single_page_mfdetrec_res': single_page_mfdetrec_res,
                'layout_res': layout_res,
            })

            for table_res in table_res_list:
                table_img, _ = crop_img(table_res, np_array_img)
                table_res_list_all_page.append({
                    'table_res': table_res,
                    'lang': _lang,
                    'table_img': table_img,
                })

        # ---- Stage 3: OCR text detection -----------------------------------
        if self.enable_ocr_det_batch:
            # Batch mode — group crops by language and resolution.
            # First collect every cropped region that needs OCR detection.
            all_cropped_images_info = []

            for ocr_res_list_dict in ocr_res_list_all_page:
                _lang = ocr_res_list_dict['lang']

                for res in ocr_res_list_dict['ocr_res_list']:
                    new_image, useful_list = crop_img(
                        res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
                    )
                    # Remap the page-level formula boxes into crop coordinates
                    adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                        ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                    )

                    # RGB -> BGR for the OCR detector
                    new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)

                    all_cropped_images_info.append((
                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
                    ))

            # Group crops by language (each language uses its own OCR model)
            lang_groups = defaultdict(list)
            for crop_info in all_cropped_images_info:
                lang = crop_info[5]
                lang_groups[lang].append(crop_info)

            # For each language: group by resolution, then batch-predict
            for lang, lang_crop_list in lang_groups.items():
                if not lang_crop_list:
                    continue

                # logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")

                # Fetch (or lazily build) the OCR model for this language
                atom_model_manager = AtomModelSingleton()
                ocr_model = atom_model_manager.get_atom_model(
                    atom_model_name='ocr',
                    ocr_show_log=False,
                    det_db_box_thresh=0.3,
                    lang=lang
                )

                # Group by (coarse) resolution so crops in one batch can be
                # padded to a shared size
                resolution_groups = defaultdict(list)
                for crop_info in lang_crop_list:
                    cropped_img = crop_info[0]
                    h, w = cropped_img.shape[:2]
                    # Use a generous grouping tolerance to keep the number of
                    # groups small: snap dimensions to multiples of 32
                    normalized_h = ((h + 32) // 32) * 32  # round up to a multiple of 32
                    normalized_w = ((w + 32) // 32) * 32
                    group_key = (normalized_h, normalized_w)
                    resolution_groups[group_key].append(crop_info)

                # Batch-process each resolution group
                for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
                    # Target size = max size in the group, rounded up to a
                    # multiple of 32 (detector-friendly stride)
                    max_h = max(crop_info[0].shape[0] for crop_info in group_crops)
                    max_w = max(crop_info[0].shape[1] for crop_info in group_crops)
                    target_h = ((max_h + 32 - 1) // 32) * 32
                    target_w = ((max_w + 32 - 1) // 32) * 32

                    # Pad every crop to the shared target size
                    batch_images = []
                    for crop_info in group_crops:
                        img = crop_info[0]
                        h, w = img.shape[:2]
                        # White canvas of the target size
                        padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
                        # Paste the original crop at the top-left corner
                        padded_img[:h, :w] = img
                        batch_images.append(padded_img)

                    # Batched detection
                    batch_size = min(len(batch_images), self.batch_ratio * 16)  # larger batches for throughput
                    # logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
                    batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)

                    # Post-process each crop's detection result
                    for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info

                        if dt_boxes is not None and len(dt_boxes) > 0:
                            # Replicate the key post-processing steps of the
                            # non-batched OCR pipeline
                            from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
                                merge_det_boxes, update_det_boxes, sorted_boxes
                            )

                            # 1. Sort detection boxes (reading order)
                            if len(dt_boxes) > 0:
                                dt_boxes_sorted = sorted_boxes(dt_boxes)
                            else:
                                dt_boxes_sorted = []

                            # 2. Merge adjacent detection boxes
                            if dt_boxes_sorted:
                                dt_boxes_merged = merge_det_boxes(dt_boxes_sorted)
                            else:
                                dt_boxes_merged = []

                            # 3. Adjust boxes around formula regions (key step!)
                            if dt_boxes_merged and adjusted_mfdetrec_res:
                                dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
                            else:
                                dt_boxes_final = dt_boxes_merged

                            # Convert to the plain-list OCR result format
                            ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]

                            if ocr_res:
                                ocr_result_list = get_ocr_result_list(
                                    ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
                                )

                                if res["category_id"] == 3:
                                    # Total area of all detected text bboxes
                                    ocr_res_area = sum(get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
                                    # Ratio of text area to the region's own area;
                                    # a text-dense category-3 region is reclassified
                                    # as category 1, otherwise its text is dropped
                                    res_area = get_coords_and_area(res)[4]
                                    if res_area > 0:
                                        ratio = ocr_res_area / res_area
                                        if ratio > 0.25:
                                            res["category_id"] = 1
                                        else:
                                            continue

                                ocr_res_list_dict['layout_res'].extend(ocr_result_list)
        else:
            # Original single-crop processing mode
            for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
                # Process each area that requires OCR processing
                _lang = ocr_res_list_dict['lang']
                # Get OCR results for this language's images
                atom_model_manager = AtomModelSingleton()
                ocr_model = atom_model_manager.get_atom_model(
                    atom_model_name='ocr',
                    ocr_show_log=False,
                    det_db_box_thresh=0.3,
                    lang=_lang
                )
                for res in ocr_res_list_dict['ocr_res_list']:
                    new_image, useful_list = crop_img(
                        res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
                    )
                    adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                        ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                    )

                    # OCR-det (detection only, rec=False)
                    new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
                    ocr_res = ocr_model.ocr(
                        new_image, mfd_res=adjusted_mfdetrec_res, rec=False
                    )[0]

                    # Integration results
                    if ocr_res:
                        ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang)

                        if res["category_id"] == 3:
                            # Total area of all detected text bboxes
                            ocr_res_area = sum(get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
                            # Ratio of text area to the region's own area
                            res_area = get_coords_and_area(res)[4]
                            if res_area > 0:
                                ratio = ocr_res_area / res_area
                                if ratio > 0.25:
                                    res["category_id"] = 1
                                else:
                                    continue

                        ocr_res_list_dict['layout_res'].extend(ocr_result_list)

            # det_count += len(ocr_res_list_dict['ocr_res_list'])
        # logger.info(f'ocr-det time: {round(time.time()-det_start, 2)}, image num: {det_count}')


        # ---- Stage 4: table recognition ------------------------------------
        if self.model.apply_table:
            table_start = time.time()
            # for table_res_list_dict in table_res_list_all_page:
            for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
                _lang = table_res_dict['lang']
                atom_model_manager = AtomModelSingleton()
                table_model = atom_model_manager.get_atom_model(
                    atom_model_name='table',
                    table_model_name='rapid_table',
                    table_model_path='',
                    table_max_time=400,
                    device='cpu',
                    lang=_lang,
                    table_sub_model_name='slanet_plus'
                )
                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict['table_img'])
                # Validate the returned HTML before accepting it
                if html_code:
                    expected_ending = html_code.strip().endswith(
                        '</html>'
                    ) or html_code.strip().endswith('</table>')
                    if expected_ending:
                        table_res_dict['table_res']['html'] = html_code
                    else:
                        logger.warning(
                            'table recognition processing fails, not found expected HTML table end'
                        )
                else:
                    logger.warning(
                        'table recognition processing fails, not get html return'
                    )
            # logger.info(f'table time: {round(time.time() - table_start, 2)}, image num: {len(table_res_list_all_page)}')

        # ---- Stage 5: OCR text recognition (per language) ------------------
        # Create dictionaries to store items by language
        need_ocr_lists_by_lang = {}  # Dict of lists for each language
        img_crop_lists_by_lang = {}  # Dict of lists for each language

        # category_id 15 marks text spans awaiting recognition — TODO confirm
        for layout_res in images_layout_res:
            for layout_res_item in layout_res:
                if layout_res_item['category_id'] in [15]:
                    if 'np_img' in layout_res_item and 'lang' in layout_res_item:
                        lang = layout_res_item['lang']

                        # Initialize lists for this language if not exist
                        if lang not in need_ocr_lists_by_lang:
                            need_ocr_lists_by_lang[lang] = []
                            img_crop_lists_by_lang[lang] = []

                        # Add to the appropriate language-specific lists
                        need_ocr_lists_by_lang[lang].append(layout_res_item)
                        img_crop_lists_by_lang[lang].append(layout_res_item['np_img'])

                        # Remove the fields after adding to lists (the dicts in
                        # need_ocr_lists_by_lang alias the layout items, so the
                        # text/score written below still lands in the output)
                        layout_res_item.pop('np_img')
                        layout_res_item.pop('lang')


        if len(img_crop_lists_by_lang) > 0:

            # Process OCR by language
            rec_time = 0
            rec_start = time.time()
            total_processed = 0

            # Process each language separately
            for lang, img_crop_list in img_crop_lists_by_lang.items():
                if len(img_crop_list) > 0:
                    # Get OCR results for this language's images
                    atom_model_manager = AtomModelSingleton()
                    ocr_model = atom_model_manager.get_atom_model(
                        atom_model_name='ocr',
                        ocr_show_log=False,
                        det_db_box_thresh=0.3,
                        lang=lang
                    )
                    # Recognition only (det=False) over all crops at once
                    ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]

                    # Verify we have matching counts
                    assert len(ocr_res_list) == len(
                        need_ocr_lists_by_lang[lang]), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_lists_by_lang[lang])} for lang: {lang}'

                    # Process OCR results for this language
                    for index, layout_res_item in enumerate(need_ocr_lists_by_lang[lang]):
                        ocr_text, ocr_score = ocr_res_list[index]
                        layout_res_item['text'] = ocr_text
                        # Round the confidence to 3 decimals for the output
                        layout_res_item['score'] = float(f"{ocr_score:.3f}")

                    total_processed += len(img_crop_list)

            rec_time += time.time() - rec_start
            # logger.info(f'ocr-rec time: {round(rec_time, 2)}, total images processed: {total_processed}')

        return images_layout_res